diff --git a/lib/crewai/src/crewai/flow/runtime.py b/lib/crewai/src/crewai/flow/runtime.py index 65efb2900..31567ed1a 100644 --- a/lib/crewai/src/crewai/flow/runtime.py +++ b/lib/crewai/src/crewai/flow/runtime.py @@ -918,6 +918,52 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): with self._or_listeners_lock: self._fired_or_listeners.discard(listener_name) + def _rearm_or_listeners_for_trigger( + self, + trigger: FlowMethodName, + rearmable: set[FlowMethodName] | None = None, + ) -> None: + """Re-arm fired OR listeners whose condition includes ``trigger``. + + Called when a router emits a fresh signal so cyclic flows can re-fire + multi-source ``or_`` listeners. Listeners whose condition does not + reference the trigger are left fired. + + Args: + trigger: The signal/method name a router just emitted. + rearmable: Optional set restricting which listeners may be re-armed. + When provided, listeners outside this set are skipped, and any + listener re-armed is removed from it. + """ + with self._or_listeners_lock: + if not self._fired_or_listeners: + return + candidates: set[FlowMethodName] = ( + self._fired_or_listeners & rearmable + if rearmable is not None + else set(self._fired_or_listeners) + ) + if not candidates: + return + trigger_str = str(trigger) + to_discard: list[FlowMethodName] = [] + for listener_name in candidates: + condition_data = self._listeners.get(listener_name) + if condition_data is None: + continue + if is_simple_flow_condition(condition_data): + _, methods = condition_data + if trigger in methods or trigger_str in {str(m) for m in methods}: + to_discard.append(listener_name) + elif is_flow_condition_dict(condition_data): + all_methods = _extract_all_methods_recursive(condition_data) + if trigger_str in {str(m) for m in all_methods}: + to_discard.append(listener_name) + for listener_name in to_discard: + self._fired_or_listeners.discard(listener_name) + if rearmable is not None: + rearmable.discard(listener_name) + def _build_racing_groups(self) -> dict[frozenset[FlowMethodName], FlowMethodName]: """Identify groups of methods that race for the same OR listener. @@ -2488,20 +2534,22 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): else str(router_result) ) if router_result is not None - else FlowMethodName("") # Update for next iteration of router chain + else FlowMethodName("") ) - # Now execute normal listeners for all router results and the original trigger all_triggers = [trigger_method, *router_results] - for current_trigger in all_triggers: - if current_trigger: # Skip None results + with self._or_listeners_lock: + rearmable: set[FlowMethodName] = set(self._fired_or_listeners) + + for idx, current_trigger in enumerate(all_triggers): + if current_trigger: + if idx > 0 and rearmable: + self._rearm_or_listeners_for_trigger(current_trigger, rearmable) listeners_triggered = self._find_triggered_methods( current_trigger, router_only=False ) if listeners_triggered: - # Determine what result to pass to listeners - # For router outcomes, pass the HumanFeedbackResult if available listener_result = router_result_to_feedback.get( str(current_trigger), result ) diff --git a/lib/crewai/tests/test_flow.py b/lib/crewai/tests/test_flow.py index 56c14b85e..bc9a4ab87 100644 --- a/lib/crewai/tests/test_flow.py +++ b/lib/crewai/tests/test_flow.py @@ -161,6 +161,87 @@ def test_flow_with_or_condition(): ) +def test_or_listener_fires_once_across_parallel_starts(): + """Parallel ``@start`` paths feeding ``or_`` must not double-fire the listener.""" + fire_count = 0 + + class ParallelOrFlow(Flow): + @start() + async def fast_start(self): + return "fast" + + @start() + async def slow_start(self): + await asyncio.sleep(0.2) + return "slow" + + @listen(or_(fast_start, slow_start)) + def handler(self): + nonlocal fire_count + fire_count += 1 + + asyncio.run(ParallelOrFlow().kickoff_async()) + + assert fire_count == 1 + + +def test_or_listener_re_arms_across_router_loop(): + """Regression for #5972: multi-source ``or_`` re-fires on each router emission.""" + fire_count = 0 + + class CyclicOrFlow(Flow): + iteration = 0 + + @start() + def kick(self): + return "kick" + + @router(kick) + def initial_router(self): + return "SignalA" + + @listen(or_("SignalA", "SignalB")) + def handler(self): + nonlocal fire_count + fire_count += 1 + + @router(handler) + def loop_router(self): + self.iteration += 1 + return "stop" if self.iteration >= 3 else "SignalB" + + CyclicOrFlow().kickoff() + + assert fire_count == 3 + + +def test_or_listener_does_not_double_fire_across_chained_routers(): + """Chained routers within one dispatch wave must not re-fire the same ``or_`` listener.""" + fire_count = 0 + + class ChainedRouterOrFlow(Flow): + @start() + def kick(self): + return "kick" + + @router(kick) + def router_a(self): + return "SignalA" + + @router("SignalA") + def router_b(self): + return "SignalB" + + @listen(or_("SignalA", "SignalB")) + def handler(self): + nonlocal fire_count + fire_count += 1 + + ChainedRouterOrFlow().kickoff() + + assert fire_count == 1 + + def test_flow_with_router(): """Test a flow that uses a router method to determine the next step.""" execution_order = []