From 461fed1c5c81dca33c1cf2b027139278890d6bcf Mon Sep 17 00:00:00 2001 From: Vinicius Brasil Date: Tue, 9 Jun 2026 21:14:58 -0700 Subject: [PATCH] Simplify flow condition evaluation to be stateless per event Re-evaluate the whole `@listen`/`@router` condition tree against the set of events seen so far, instead of tracking which AND sub-branches remain pending. Net effect: * Fixes a regression where `or_()` short-circuited at the first satisfied branch, leaving a sibling `and_()` half-complete so a later trigger could spuriously re-fire the listener * Removes the fragile per-branch pending state and `id()`-based keys * Shrinks the evaluator to one readable predicate --- .../experimental/conversational_mixin.py | 4 +- lib/crewai/src/crewai/flow/runtime.py | 126 ++++++++---------- lib/crewai/src/crewai/flow/types.py | 2 +- .../tests/agents/test_agent_executor.py | 2 +- lib/crewai/tests/test_flow.py | 69 ++++++---- 5 files changed, 104 insertions(+), 99 deletions(-) diff --git a/lib/crewai/src/crewai/experimental/conversational_mixin.py b/lib/crewai/src/crewai/experimental/conversational_mixin.py index 3801d0570..862706a88 100644 --- a/lib/crewai/src/crewai/experimental/conversational_mixin.py +++ b/lib/crewai/src/crewai/experimental/conversational_mixin.py @@ -84,7 +84,7 @@ class _ConversationalMixin: name: str | None _completed_methods: set[Any] _method_outputs: list[Any] - _pending_and_listeners: dict[Any, Any] + _pending_events: dict[Any, Any] _method_call_counts: dict[Any, int] _is_execution_resuming: bool _pending_user_message: str | dict[str, Any] | None @@ -581,7 +581,7 @@ class _ConversationalMixin: """Clear per-execution tracking so the next turn re-runs the graph.""" self._completed_methods.clear() self._method_outputs.clear() - self._pending_and_listeners.clear() + self._pending_events.clear() self._method_call_counts.clear() self._clear_or_listeners() self._is_execution_resuming = False diff --git a/lib/crewai/src/crewai/flow/runtime.py b/lib/crewai/src/crewai/flow/runtime.py index ee3fc1e18..4be128c6f 100644 --- a/lib/crewai/src/crewai/flow/runtime.py +++ b/lib/crewai/src/crewai/flow/runtime.py @@ -154,14 +154,42 @@ ExecutionContext = Any # type: ignore[assignment,misc] logger = logging.getLogger(__name__) +def _condition_branches( + condition: dict[str, Any], +) -> tuple[Literal["and", "or"], list[FlowDefinitionCondition]]: + if "and" in condition: + return "and", condition["and"] + return "or", condition["or"] + + +def _condition_satisfied(condition: FlowDefinitionCondition, events: set[str]) -> bool: + if isinstance(condition, str): + return condition in events + operator, branches = _condition_branches(condition) + combine = all if operator == "and" else any + return combine(_condition_satisfied(branch, events) for branch in branches) + + def _iter_condition_events(condition: FlowDefinitionCondition) -> Iterator[str]: if isinstance(condition, str): yield condition return - sub_conditions = condition["and"] if "and" in condition else condition["or"] - for sub_condition in sub_conditions: - yield from _iter_condition_events(sub_condition) + _, branches = _condition_branches(condition) + for branch in branches: + yield from _iter_condition_events(branch) + + +def _or_alternative_events(condition: FlowDefinitionCondition) -> Iterator[str]: + if isinstance(condition, str): + yield condition + return + + operator, branches = _condition_branches(condition) + if operator != "or": + return + for branch in branches: + yield from _or_alternative_events(branch) def _is_multi_event_or( @@ -170,7 +198,8 @@ def _is_multi_event_or( if isinstance(condition, str): return False - return "or" in condition and len(condition["or"]) > 1 + operator, branches = _condition_branches(condition) + return operator == "or" and len(branches) > 1 def _resolve_persistence(value: Any) -> Any: @@ -864,7 +893,7 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta): _method_execution_counts: dict[FlowMethodName, int] = PrivateAttr( default_factory=dict ) - _pending_and_listeners: dict[PendingListenerKey, set[int]] = PrivateAttr( + _pending_events: dict[PendingListenerKey, set[str]] = PrivateAttr( default_factory=dict ) _fired_or_listeners: set[FlowMethodName] = PrivateAttr(default_factory=set) @@ -1027,11 +1056,8 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta): condition = type(self)._start_condition(method_name) if condition is None: return False - return self._evaluate_condition( - condition, - trigger, - method_name, - pending_key_prefix=f"start:{method_name}", + return self._condition_met( + condition, trigger, PendingListenerKey(f"start:{method_name}") ) def _rearm_or_listeners_for_trigger( @@ -1071,6 +1097,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta): # Only events that EXCLUSIVELY feed one OR listener race; an event that # also feeds another listener (e.g. an AND) is left alone when a sibling # wins. e.g. @listen(or_(a, b)) on handler -> {frozenset({a, b}): handler}. + # Events nested under an and_() branch (e.g. or_(and_(a, b), c)) are not + # alternatives and never race -- cancelling one would make the AND + # unsatisfiable. racing_groups: dict[frozenset[FlowMethodName], FlowMethodName] = {} listener_conditions: dict[FlowMethodName, FlowDefinitionCondition] = { listener_name: condition @@ -1093,14 +1122,14 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta): for listener_name, condition in listener_conditions.items(): if not isinstance(condition, dict): continue - events = events_by_listener[listener_name] - if "or" not in condition or len(events) <= 1: + alternatives = set(_or_alternative_events(condition)) + if len(alternatives) <= 1: continue exclusive_events = { event - for event in events - if listeners_by_event.get(event, set()) == {listener_name} + for event in alternatives + if listeners_by_event[event] == {listener_name} } if len(exclusive_events) > 1: # Racing only applies to method-completion events: each member is @@ -2028,7 +2057,7 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta): # Clear completed methods and outputs for a fresh start self._completed_methods.clear() self._method_outputs.clear() - self._pending_and_listeners.clear() + self._pending_events.clear() self._clear_or_listeners() self._method_call_counts.clear() else: @@ -2725,63 +2754,18 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta): else: await self._execute_start_method(method_name) - def _evaluate_condition( + def _condition_met( self, condition: FlowDefinitionCondition, trigger_method: FlowMethodName, - listener_name: FlowMethodName, - pending_key_prefix: str | None = None, + subscription_key: PendingListenerKey, ) -> bool: - if isinstance(condition, str): - return condition == str(trigger_method) - - def _sub_prefix(index: int) -> str | None: - if pending_key_prefix is None: - return None - return f"{pending_key_prefix}:{index}" - - if "or" in condition: - # Evaluate every sub-condition (no short-circuit): a nested and_() - # branch needs the chance to clear its pending state in - # _pending_and_listeners even when an earlier branch already matched. - any_matched = False - for index, sub_condition in enumerate(condition["or"]): - if self._evaluate_condition( - sub_condition, - trigger_method, - listener_name, - pending_key_prefix=_sub_prefix(index), - ): - any_matched = True - return any_matched - - sub_conditions = condition["and"] - pending_key = PendingListenerKey( - pending_key_prefix - if pending_key_prefix is not None - else f"{listener_name}:{id(condition)}" - ) - - if pending_key not in self._pending_and_listeners: - self._pending_and_listeners[pending_key] = set(range(len(sub_conditions))) - - pending_conditions = self._pending_and_listeners[pending_key] - for index, sub_condition in enumerate(sub_conditions): - if index not in pending_conditions: - continue - if self._evaluate_condition( - sub_condition, - trigger_method, - listener_name, - pending_key_prefix=_sub_prefix(index), - ): - pending_conditions.discard(index) - - if not pending_conditions: - self._pending_and_listeners.pop(pending_key, None) - return True - - return False + seen = self._pending_events.setdefault(subscription_key, set()) + seen.add(str(trigger_method)) + if not _condition_satisfied(condition, seen): + return False + del self._pending_events[subscription_key] + return True def _find_triggered_methods( self, trigger_method: FlowMethodName, router_only: bool @@ -2799,10 +2783,8 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta): if should_check_fired and listener_name in self._fired_or_listeners: continue - if self._evaluate_condition( - condition, - trigger_method, - listener_name, + if self._condition_met( + condition, trigger_method, PendingListenerKey(str(listener_name)) ): triggered.append(listener_name) if should_check_fired: diff --git a/lib/crewai/src/crewai/flow/types.py b/lib/crewai/src/crewai/flow/types.py index 6230dd49e..d77c777bc 100644 --- a/lib/crewai/src/crewai/flow/types.py +++ b/lib/crewai/src/crewai/flow/types.py @@ -16,7 +16,7 @@ R = TypeVar("R", covariant=True) FlowMethodName = NewType("FlowMethodName", str) PendingListenerKey = NewType( "PendingListenerKey", - Annotated[str, "nested flow conditions use 'listener_name:object_id'"], + Annotated[str, "listener method name, or 'start:' for conditional starts"], ) diff --git a/lib/crewai/tests/agents/test_agent_executor.py b/lib/crewai/tests/agents/test_agent_executor.py index 5868a7ce2..b22bee401 100644 --- a/lib/crewai/tests/agents/test_agent_executor.py +++ b/lib/crewai/tests/agents/test_agent_executor.py @@ -32,7 +32,7 @@ def _build_executor(**kwargs: Any) -> AgentExecutor: executor._method_outputs = [] executor._completed_methods = set() executor._fired_or_listeners = set() - executor._pending_and_listeners = {} + executor._pending_events = {} executor._method_execution_counts = {} executor._method_call_counts = {} executor._event_futures = [] diff --git a/lib/crewai/tests/test_flow.py b/lib/crewai/tests/test_flow.py index 9e061f813..27a62f5a2 100644 --- a/lib/crewai/tests/test_flow.py +++ b/lib/crewai/tests/test_flow.py @@ -1542,40 +1542,63 @@ def test_deeply_nested_conditions(): def test_or_branch_does_not_leave_stale_and_state(): - """or_() over nested and_() branches must not leave stale pending AND state. - - Regression: evaluating an or_() condition stopped at the first branch that was - satisfied, so a later and_() branch that the *same* trigger would have completed - never cleared its pending state. On the next cycle that trigger alone then - spuriously re-satisfied the whole condition. Both branches share the final - event ``x`` here, so the shared trigger that completes branch ``(a AND x)`` also - completes branch ``(c AND x)`` and both must be cleared together. - """ + fired = [] class StaleStateFlow(Flow): @start() def begin(self): pass - @listen(or_(and_("a", "x"), and_("c", "x"))) - def joined(self): + @listen(begin) + def a(self): pass - flow = StaleStateFlow() - condition = type(flow)._listen_condition("joined") + @listen(begin) + def c(self): + pass - def fires(trigger): - return flow._evaluate_condition(condition, trigger, "joined") + @listen(and_(a, c)) + def x(self): + pass - # First cycle: "a" then "c" arrive, then the shared "x" completes (a AND x). - assert fires("a") is False - assert fires("c") is False - assert fires("x") is True + @listen(or_(and_("a", "x"), and_("c", "y"))) + def joined(self): + fired.append("joined") - # Next cycle: "x" alone must NOT re-satisfy the condition. The "c" from the - # previous cycle was consumed when "joined" fired, so neither branch is half - # complete and "x" by itself is insufficient. - assert fires("x") is False + @router(joined) + def emit_y(self): + return "y" + + StaleStateFlow().kickoff() + + assert fired == ["joined"] + + +def test_and_branch_inside_or_does_not_race(): + execution_order = [] + + class DiamondWithFallbackFlow(Flow): + @start() + def go(self): + execution_order.append("go") + + @listen(go) + def a(self): + execution_order.append("a") + + @listen(go) + def b(self): + execution_order.append("b") + + @listen(or_(and_(a, b), "fallback")) + def done(self): + execution_order.append("done") + + DiamondWithFallbackFlow().kickoff() + + assert "done" in execution_order + assert execution_order.index("done") > execution_order.index("a") + assert execution_order.index("done") > execution_order.index("b") def test_mixed_sync_async_execution_order():