fix(flow): re-arm multi-source or_ listeners across router-driven cycles
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
Nightly Canary Release / Check for new commits (push) Has been cancelled
Nightly Canary Release / Build nightly packages (push) Has been cancelled
Nightly Canary Release / Publish nightly to PyPI (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

The previous discard-after-body approach cleared the gate mid-wave, so
a slow parallel @start finishing after the listener body could re-fire
the same multi-source or_ listener. Re-arm only when a router emits a
signal that matches the listener's condition; parallel @start paths
never reach that branch and the race gate keeps protecting them.

Closes #5972
This commit is contained in:
Greyson LaLonde
2026-06-01 15:24:58 -07:00
committed by GitHub
parent 1aba9fe415
commit e53a676c04
2 changed files with 135 additions and 6 deletions

View File

@@ -918,6 +918,52 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
with self._or_listeners_lock:
self._fired_or_listeners.discard(listener_name)
def _rearm_or_listeners_for_trigger(
self,
trigger: FlowMethodName,
rearmable: set[FlowMethodName] | None = None,
) -> None:
"""Re-arm fired OR listeners whose condition includes ``trigger``.
Called when a router emits a fresh signal so cyclic flows can re-fire
multi-source ``or_`` listeners. Listeners whose condition does not
reference the trigger are left fired.
Args:
trigger: The signal/method name a router just emitted.
rearmable: Optional set restricting which listeners may be re-armed.
When provided, listeners outside this set are skipped, and any
listener re-armed is removed from it.
"""
with self._or_listeners_lock:
if not self._fired_or_listeners:
return
candidates: set[FlowMethodName] = (
self._fired_or_listeners & rearmable
if rearmable is not None
else set(self._fired_or_listeners)
)
if not candidates:
return
trigger_str = str(trigger)
to_discard: list[FlowMethodName] = []
for listener_name in candidates:
condition_data = self._listeners.get(listener_name)
if condition_data is None:
continue
if is_simple_flow_condition(condition_data):
_, methods = condition_data
if trigger in methods or trigger_str in {str(m) for m in methods}:
to_discard.append(listener_name)
elif is_flow_condition_dict(condition_data):
all_methods = _extract_all_methods_recursive(condition_data)
if trigger_str in {str(m) for m in all_methods}:
to_discard.append(listener_name)
for listener_name in to_discard:
self._fired_or_listeners.discard(listener_name)
if rearmable is not None:
rearmable.discard(listener_name)
def _build_racing_groups(self) -> dict[frozenset[FlowMethodName], FlowMethodName]:
"""Identify groups of methods that race for the same OR listener.
@@ -2488,20 +2534,22 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
else str(router_result)
)
if router_result is not None
else FlowMethodName("") # Update for next iteration of router chain
else FlowMethodName("")
)
# Now execute normal listeners for all router results and the original trigger
all_triggers = [trigger_method, *router_results]
for current_trigger in all_triggers:
if current_trigger: # Skip None results
with self._or_listeners_lock:
rearmable: set[FlowMethodName] = set(self._fired_or_listeners)
for idx, current_trigger in enumerate(all_triggers):
if current_trigger:
if idx > 0 and rearmable:
self._rearm_or_listeners_for_trigger(current_trigger, rearmable)
listeners_triggered = self._find_triggered_methods(
current_trigger, router_only=False
)
if listeners_triggered:
# Determine what result to pass to listeners
# For router outcomes, pass the HumanFeedbackResult if available
listener_result = router_result_to_feedback.get(
str(current_trigger), result
)

View File

@@ -161,6 +161,87 @@ def test_flow_with_or_condition():
)
def test_or_listener_fires_once_across_parallel_starts():
"""Parallel ``@start`` paths feeding ``or_`` must not double-fire the listener."""
fire_count = 0
class ParallelOrFlow(Flow):
@start()
async def fast_start(self):
return "fast"
@start()
async def slow_start(self):
await asyncio.sleep(0.2)
return "slow"
@listen(or_(fast_start, slow_start))
def handler(self):
nonlocal fire_count
fire_count += 1
asyncio.run(ParallelOrFlow().kickoff_async())
assert fire_count == 1
def test_or_listener_re_arms_across_router_loop():
"""Regression for #5972: multi-source ``or_`` re-fires on each router emission."""
fire_count = 0
class CyclicOrFlow(Flow):
iteration = 0
@start()
def kick(self):
return "kick"
@router(kick)
def initial_router(self):
return "SignalA"
@listen(or_("SignalA", "SignalB"))
def handler(self):
nonlocal fire_count
fire_count += 1
@router(handler)
def loop_router(self):
self.iteration += 1
return "stop" if self.iteration >= 3 else "SignalB"
CyclicOrFlow().kickoff()
assert fire_count == 3
def test_or_listener_does_not_double_fire_across_chained_routers():
"""Chained routers within one dispatch wave must not re-fire the same ``or_`` listener."""
fire_count = 0
class ChainedRouterOrFlow(Flow):
@start()
def kick(self):
return "kick"
@router(kick)
def router_a(self):
return "SignalA"
@router("SignalA")
def router_b(self):
return "SignalB"
@listen(or_("SignalA", "SignalB"))
def handler(self):
nonlocal fire_count
fire_count += 1
ChainedRouterOrFlow().kickoff()
assert fire_count == 1
def test_flow_with_router():
"""Test a flow that uses a router method to determine the next step."""
execution_order = []