mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 06:08:15 +00:00
fix(flow): re-arm multi-source or_ listeners across router-driven cycles
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
Nightly Canary Release / Check for new commits (push) Has been cancelled
Nightly Canary Release / Build nightly packages (push) Has been cancelled
Nightly Canary Release / Publish nightly to PyPI (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
Nightly Canary Release / Check for new commits (push) Has been cancelled
Nightly Canary Release / Build nightly packages (push) Has been cancelled
Nightly Canary Release / Publish nightly to PyPI (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
The previous discard-after-body approach cleared the gate mid-wave, so a slow parallel @start finishing after the listener body could re-fire the same multi-source or_ listener. Re-arm only when a router emits a signal that matches the listener's condition; parallel @start paths never reach that branch and the race gate keeps protecting them. Closes #5972
This commit is contained in:
@@ -918,6 +918,52 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
with self._or_listeners_lock:
|
||||
self._fired_or_listeners.discard(listener_name)
|
||||
|
||||
def _rearm_or_listeners_for_trigger(
|
||||
self,
|
||||
trigger: FlowMethodName,
|
||||
rearmable: set[FlowMethodName] | None = None,
|
||||
) -> None:
|
||||
"""Re-arm fired OR listeners whose condition includes ``trigger``.
|
||||
|
||||
Called when a router emits a fresh signal so cyclic flows can re-fire
|
||||
multi-source ``or_`` listeners. Listeners whose condition does not
|
||||
reference the trigger are left fired.
|
||||
|
||||
Args:
|
||||
trigger: The signal/method name a router just emitted.
|
||||
rearmable: Optional set restricting which listeners may be re-armed.
|
||||
When provided, listeners outside this set are skipped, and any
|
||||
listener re-armed is removed from it.
|
||||
"""
|
||||
with self._or_listeners_lock:
|
||||
if not self._fired_or_listeners:
|
||||
return
|
||||
candidates: set[FlowMethodName] = (
|
||||
self._fired_or_listeners & rearmable
|
||||
if rearmable is not None
|
||||
else set(self._fired_or_listeners)
|
||||
)
|
||||
if not candidates:
|
||||
return
|
||||
trigger_str = str(trigger)
|
||||
to_discard: list[FlowMethodName] = []
|
||||
for listener_name in candidates:
|
||||
condition_data = self._listeners.get(listener_name)
|
||||
if condition_data is None:
|
||||
continue
|
||||
if is_simple_flow_condition(condition_data):
|
||||
_, methods = condition_data
|
||||
if trigger in methods or trigger_str in {str(m) for m in methods}:
|
||||
to_discard.append(listener_name)
|
||||
elif is_flow_condition_dict(condition_data):
|
||||
all_methods = _extract_all_methods_recursive(condition_data)
|
||||
if trigger_str in {str(m) for m in all_methods}:
|
||||
to_discard.append(listener_name)
|
||||
for listener_name in to_discard:
|
||||
self._fired_or_listeners.discard(listener_name)
|
||||
if rearmable is not None:
|
||||
rearmable.discard(listener_name)
|
||||
|
||||
def _build_racing_groups(self) -> dict[frozenset[FlowMethodName], FlowMethodName]:
|
||||
"""Identify groups of methods that race for the same OR listener.
|
||||
|
||||
@@ -2488,20 +2534,22 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
else str(router_result)
|
||||
)
|
||||
if router_result is not None
|
||||
else FlowMethodName("") # Update for next iteration of router chain
|
||||
else FlowMethodName("")
|
||||
)
|
||||
|
||||
# Now execute normal listeners for all router results and the original trigger
|
||||
all_triggers = [trigger_method, *router_results]
|
||||
|
||||
for current_trigger in all_triggers:
|
||||
if current_trigger: # Skip None results
|
||||
with self._or_listeners_lock:
|
||||
rearmable: set[FlowMethodName] = set(self._fired_or_listeners)
|
||||
|
||||
for idx, current_trigger in enumerate(all_triggers):
|
||||
if current_trigger:
|
||||
if idx > 0 and rearmable:
|
||||
self._rearm_or_listeners_for_trigger(current_trigger, rearmable)
|
||||
listeners_triggered = self._find_triggered_methods(
|
||||
current_trigger, router_only=False
|
||||
)
|
||||
if listeners_triggered:
|
||||
# Determine what result to pass to listeners
|
||||
# For router outcomes, pass the HumanFeedbackResult if available
|
||||
listener_result = router_result_to_feedback.get(
|
||||
str(current_trigger), result
|
||||
)
|
||||
|
||||
@@ -161,6 +161,87 @@ def test_flow_with_or_condition():
|
||||
)
|
||||
|
||||
|
||||
def test_or_listener_fires_once_across_parallel_starts():
|
||||
"""Parallel ``@start`` paths feeding ``or_`` must not double-fire the listener."""
|
||||
fire_count = 0
|
||||
|
||||
class ParallelOrFlow(Flow):
|
||||
@start()
|
||||
async def fast_start(self):
|
||||
return "fast"
|
||||
|
||||
@start()
|
||||
async def slow_start(self):
|
||||
await asyncio.sleep(0.2)
|
||||
return "slow"
|
||||
|
||||
@listen(or_(fast_start, slow_start))
|
||||
def handler(self):
|
||||
nonlocal fire_count
|
||||
fire_count += 1
|
||||
|
||||
asyncio.run(ParallelOrFlow().kickoff_async())
|
||||
|
||||
assert fire_count == 1
|
||||
|
||||
|
||||
def test_or_listener_re_arms_across_router_loop():
|
||||
"""Regression for #5972: multi-source ``or_`` re-fires on each router emission."""
|
||||
fire_count = 0
|
||||
|
||||
class CyclicOrFlow(Flow):
|
||||
iteration = 0
|
||||
|
||||
@start()
|
||||
def kick(self):
|
||||
return "kick"
|
||||
|
||||
@router(kick)
|
||||
def initial_router(self):
|
||||
return "SignalA"
|
||||
|
||||
@listen(or_("SignalA", "SignalB"))
|
||||
def handler(self):
|
||||
nonlocal fire_count
|
||||
fire_count += 1
|
||||
|
||||
@router(handler)
|
||||
def loop_router(self):
|
||||
self.iteration += 1
|
||||
return "stop" if self.iteration >= 3 else "SignalB"
|
||||
|
||||
CyclicOrFlow().kickoff()
|
||||
|
||||
assert fire_count == 3
|
||||
|
||||
|
||||
def test_or_listener_does_not_double_fire_across_chained_routers():
|
||||
"""Chained routers within one dispatch wave must not re-fire the same ``or_`` listener."""
|
||||
fire_count = 0
|
||||
|
||||
class ChainedRouterOrFlow(Flow):
|
||||
@start()
|
||||
def kick(self):
|
||||
return "kick"
|
||||
|
||||
@router(kick)
|
||||
def router_a(self):
|
||||
return "SignalA"
|
||||
|
||||
@router("SignalA")
|
||||
def router_b(self):
|
||||
return "SignalB"
|
||||
|
||||
@listen(or_("SignalA", "SignalB"))
|
||||
def handler(self):
|
||||
nonlocal fire_count
|
||||
fire_count += 1
|
||||
|
||||
ChainedRouterOrFlow().kickoff()
|
||||
|
||||
assert fire_count == 1
|
||||
|
||||
|
||||
def test_flow_with_router():
|
||||
"""Test a flow that uses a router method to determine the next step."""
|
||||
execution_order = []
|
||||
|
||||
Reference in New Issue
Block a user