mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 14:09:24 +00:00
Simplify flow condition evaluation to be stateless per event (#6097)
Re-evaluate the whole `@listen`/`@router` condition tree against the set of events seen so far, instead of tracking which AND sub-branches remain pending. Net effect: * Fixes a regression where `or_()` short-circuited at the first satisfied branch, leaving a sibling `and_()` half-complete so a later trigger could spuriously re-fire the listener * Removes the fragile per-branch pending state and `id()`-based keys * Shrinks the evaluator to one readable predicate
This commit is contained in:
@@ -84,7 +84,7 @@ class _ConversationalMixin:
|
||||
name: str | None
|
||||
_completed_methods: set[Any]
|
||||
_method_outputs: list[Any]
|
||||
_pending_and_listeners: dict[Any, Any]
|
||||
_pending_events: dict[Any, Any]
|
||||
_method_call_counts: dict[Any, int]
|
||||
_is_execution_resuming: bool
|
||||
_pending_user_message: str | dict[str, Any] | None
|
||||
@@ -581,7 +581,7 @@ class _ConversationalMixin:
|
||||
"""Clear per-execution tracking so the next turn re-runs the graph."""
|
||||
self._completed_methods.clear()
|
||||
self._method_outputs.clear()
|
||||
self._pending_and_listeners.clear()
|
||||
self._pending_events.clear()
|
||||
self._method_call_counts.clear()
|
||||
self._clear_or_listeners()
|
||||
self._is_execution_resuming = False
|
||||
|
||||
@@ -154,14 +154,42 @@ ExecutionContext = Any # type: ignore[assignment,misc]
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _condition_branches(
|
||||
condition: dict[str, Any],
|
||||
) -> tuple[Literal["and", "or"], list[FlowDefinitionCondition]]:
|
||||
if "and" in condition:
|
||||
return "and", condition["and"]
|
||||
return "or", condition["or"]
|
||||
|
||||
|
||||
def _condition_satisfied(condition: FlowDefinitionCondition, events: set[str]) -> bool:
|
||||
if isinstance(condition, str):
|
||||
return condition in events
|
||||
operator, branches = _condition_branches(condition)
|
||||
combine = all if operator == "and" else any
|
||||
return combine(_condition_satisfied(branch, events) for branch in branches)
|
||||
|
||||
|
||||
def _iter_condition_events(condition: FlowDefinitionCondition) -> Iterator[str]:
|
||||
if isinstance(condition, str):
|
||||
yield condition
|
||||
return
|
||||
|
||||
sub_conditions = condition["and"] if "and" in condition else condition["or"]
|
||||
for sub_condition in sub_conditions:
|
||||
yield from _iter_condition_events(sub_condition)
|
||||
_, branches = _condition_branches(condition)
|
||||
for branch in branches:
|
||||
yield from _iter_condition_events(branch)
|
||||
|
||||
|
||||
def _or_alternative_events(condition: FlowDefinitionCondition) -> Iterator[str]:
|
||||
if isinstance(condition, str):
|
||||
yield condition
|
||||
return
|
||||
|
||||
operator, branches = _condition_branches(condition)
|
||||
if operator != "or":
|
||||
return
|
||||
for branch in branches:
|
||||
yield from _or_alternative_events(branch)
|
||||
|
||||
|
||||
def _is_multi_event_or(
|
||||
@@ -170,7 +198,8 @@ def _is_multi_event_or(
|
||||
if isinstance(condition, str):
|
||||
return False
|
||||
|
||||
return "or" in condition and len(condition["or"]) > 1
|
||||
operator, branches = _condition_branches(condition)
|
||||
return operator == "or" and len(branches) > 1
|
||||
|
||||
|
||||
def _resolve_persistence(value: Any) -> Any:
|
||||
@@ -864,7 +893,7 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
_method_execution_counts: dict[FlowMethodName, int] = PrivateAttr(
|
||||
default_factory=dict
|
||||
)
|
||||
_pending_and_listeners: dict[PendingListenerKey, set[int]] = PrivateAttr(
|
||||
_pending_events: dict[PendingListenerKey, set[str]] = PrivateAttr(
|
||||
default_factory=dict
|
||||
)
|
||||
_fired_or_listeners: set[FlowMethodName] = PrivateAttr(default_factory=set)
|
||||
@@ -1027,11 +1056,8 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
condition = type(self)._start_condition(method_name)
|
||||
if condition is None:
|
||||
return False
|
||||
return self._evaluate_condition(
|
||||
condition,
|
||||
trigger,
|
||||
method_name,
|
||||
pending_key_prefix=f"start:{method_name}",
|
||||
return self._condition_met(
|
||||
condition, trigger, PendingListenerKey(f"start:{method_name}")
|
||||
)
|
||||
|
||||
def _rearm_or_listeners_for_trigger(
|
||||
@@ -1071,6 +1097,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
# Only events that EXCLUSIVELY feed one OR listener race; an event that
|
||||
# also feeds another listener (e.g. an AND) is left alone when a sibling
|
||||
# wins. e.g. @listen(or_(a, b)) on handler -> {frozenset({a, b}): handler}.
|
||||
# Events nested under an and_() branch (e.g. or_(and_(a, b), c)) are not
|
||||
# alternatives and never race -- cancelling one would make the AND
|
||||
# unsatisfiable.
|
||||
racing_groups: dict[frozenset[FlowMethodName], FlowMethodName] = {}
|
||||
listener_conditions: dict[FlowMethodName, FlowDefinitionCondition] = {
|
||||
listener_name: condition
|
||||
@@ -1093,14 +1122,14 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
for listener_name, condition in listener_conditions.items():
|
||||
if not isinstance(condition, dict):
|
||||
continue
|
||||
events = events_by_listener[listener_name]
|
||||
if "or" not in condition or len(events) <= 1:
|
||||
alternatives = set(_or_alternative_events(condition))
|
||||
if len(alternatives) <= 1:
|
||||
continue
|
||||
|
||||
exclusive_events = {
|
||||
event
|
||||
for event in events
|
||||
if listeners_by_event.get(event, set()) == {listener_name}
|
||||
for event in alternatives
|
||||
if listeners_by_event[event] == {listener_name}
|
||||
}
|
||||
if len(exclusive_events) > 1:
|
||||
# Racing only applies to method-completion events: each member is
|
||||
@@ -2028,7 +2057,7 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
# Clear completed methods and outputs for a fresh start
|
||||
self._completed_methods.clear()
|
||||
self._method_outputs.clear()
|
||||
self._pending_and_listeners.clear()
|
||||
self._pending_events.clear()
|
||||
self._clear_or_listeners()
|
||||
self._method_call_counts.clear()
|
||||
else:
|
||||
@@ -2725,63 +2754,18 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
else:
|
||||
await self._execute_start_method(method_name)
|
||||
|
||||
def _evaluate_condition(
|
||||
def _condition_met(
|
||||
self,
|
||||
condition: FlowDefinitionCondition,
|
||||
trigger_method: FlowMethodName,
|
||||
listener_name: FlowMethodName,
|
||||
pending_key_prefix: str | None = None,
|
||||
subscription_key: PendingListenerKey,
|
||||
) -> bool:
|
||||
if isinstance(condition, str):
|
||||
return condition == str(trigger_method)
|
||||
|
||||
def _sub_prefix(index: int) -> str | None:
|
||||
if pending_key_prefix is None:
|
||||
return None
|
||||
return f"{pending_key_prefix}:{index}"
|
||||
|
||||
if "or" in condition:
|
||||
# Evaluate every sub-condition (no short-circuit): a nested and_()
|
||||
# branch needs the chance to clear its pending state in
|
||||
# _pending_and_listeners even when an earlier branch already matched.
|
||||
any_matched = False
|
||||
for index, sub_condition in enumerate(condition["or"]):
|
||||
if self._evaluate_condition(
|
||||
sub_condition,
|
||||
trigger_method,
|
||||
listener_name,
|
||||
pending_key_prefix=_sub_prefix(index),
|
||||
):
|
||||
any_matched = True
|
||||
return any_matched
|
||||
|
||||
sub_conditions = condition["and"]
|
||||
pending_key = PendingListenerKey(
|
||||
pending_key_prefix
|
||||
if pending_key_prefix is not None
|
||||
else f"{listener_name}:{id(condition)}"
|
||||
)
|
||||
|
||||
if pending_key not in self._pending_and_listeners:
|
||||
self._pending_and_listeners[pending_key] = set(range(len(sub_conditions)))
|
||||
|
||||
pending_conditions = self._pending_and_listeners[pending_key]
|
||||
for index, sub_condition in enumerate(sub_conditions):
|
||||
if index not in pending_conditions:
|
||||
continue
|
||||
if self._evaluate_condition(
|
||||
sub_condition,
|
||||
trigger_method,
|
||||
listener_name,
|
||||
pending_key_prefix=_sub_prefix(index),
|
||||
):
|
||||
pending_conditions.discard(index)
|
||||
|
||||
if not pending_conditions:
|
||||
self._pending_and_listeners.pop(pending_key, None)
|
||||
return True
|
||||
|
||||
return False
|
||||
seen = self._pending_events.setdefault(subscription_key, set())
|
||||
seen.add(str(trigger_method))
|
||||
if not _condition_satisfied(condition, seen):
|
||||
return False
|
||||
del self._pending_events[subscription_key]
|
||||
return True
|
||||
|
||||
def _find_triggered_methods(
|
||||
self, trigger_method: FlowMethodName, router_only: bool
|
||||
@@ -2799,10 +2783,8 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
if should_check_fired and listener_name in self._fired_or_listeners:
|
||||
continue
|
||||
|
||||
if self._evaluate_condition(
|
||||
condition,
|
||||
trigger_method,
|
||||
listener_name,
|
||||
if self._condition_met(
|
||||
condition, trigger_method, PendingListenerKey(str(listener_name))
|
||||
):
|
||||
triggered.append(listener_name)
|
||||
if should_check_fired:
|
||||
|
||||
@@ -16,7 +16,7 @@ R = TypeVar("R", covariant=True)
|
||||
FlowMethodName = NewType("FlowMethodName", str)
|
||||
PendingListenerKey = NewType(
|
||||
"PendingListenerKey",
|
||||
Annotated[str, "nested flow conditions use 'listener_name:object_id'"],
|
||||
Annotated[str, "listener method name, or 'start:<method>' for conditional starts"],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ def _build_executor(**kwargs: Any) -> AgentExecutor:
|
||||
executor._method_outputs = []
|
||||
executor._completed_methods = set()
|
||||
executor._fired_or_listeners = set()
|
||||
executor._pending_and_listeners = {}
|
||||
executor._pending_events = {}
|
||||
executor._method_execution_counts = {}
|
||||
executor._method_call_counts = {}
|
||||
executor._event_futures = []
|
||||
|
||||
@@ -1542,40 +1542,63 @@ def test_deeply_nested_conditions():
|
||||
|
||||
|
||||
def test_or_branch_does_not_leave_stale_and_state():
|
||||
"""or_() over nested and_() branches must not leave stale pending AND state.
|
||||
|
||||
Regression: evaluating an or_() condition stopped at the first branch that was
|
||||
satisfied, so a later and_() branch that the *same* trigger would have completed
|
||||
never cleared its pending state. On the next cycle that trigger alone then
|
||||
spuriously re-satisfied the whole condition. Both branches share the final
|
||||
event ``x`` here, so the shared trigger that completes branch ``(a AND x)`` also
|
||||
completes branch ``(c AND x)`` and both must be cleared together.
|
||||
"""
|
||||
fired = []
|
||||
|
||||
class StaleStateFlow(Flow):
|
||||
@start()
|
||||
def begin(self):
|
||||
pass
|
||||
|
||||
@listen(or_(and_("a", "x"), and_("c", "x")))
|
||||
def joined(self):
|
||||
@listen(begin)
|
||||
def a(self):
|
||||
pass
|
||||
|
||||
flow = StaleStateFlow()
|
||||
condition = type(flow)._listen_condition("joined")
|
||||
@listen(begin)
|
||||
def c(self):
|
||||
pass
|
||||
|
||||
def fires(trigger):
|
||||
return flow._evaluate_condition(condition, trigger, "joined")
|
||||
@listen(and_(a, c))
|
||||
def x(self):
|
||||
pass
|
||||
|
||||
# First cycle: "a" then "c" arrive, then the shared "x" completes (a AND x).
|
||||
assert fires("a") is False
|
||||
assert fires("c") is False
|
||||
assert fires("x") is True
|
||||
@listen(or_(and_("a", "x"), and_("c", "y")))
|
||||
def joined(self):
|
||||
fired.append("joined")
|
||||
|
||||
# Next cycle: "x" alone must NOT re-satisfy the condition. The "c" from the
|
||||
# previous cycle was consumed when "joined" fired, so neither branch is half
|
||||
# complete and "x" by itself is insufficient.
|
||||
assert fires("x") is False
|
||||
@router(joined)
|
||||
def emit_y(self):
|
||||
return "y"
|
||||
|
||||
StaleStateFlow().kickoff()
|
||||
|
||||
assert fired == ["joined"]
|
||||
|
||||
|
||||
def test_and_branch_inside_or_does_not_race():
|
||||
execution_order = []
|
||||
|
||||
class DiamondWithFallbackFlow(Flow):
|
||||
@start()
|
||||
def go(self):
|
||||
execution_order.append("go")
|
||||
|
||||
@listen(go)
|
||||
def a(self):
|
||||
execution_order.append("a")
|
||||
|
||||
@listen(go)
|
||||
def b(self):
|
||||
execution_order.append("b")
|
||||
|
||||
@listen(or_(and_(a, b), "fallback"))
|
||||
def done(self):
|
||||
execution_order.append("done")
|
||||
|
||||
DiamondWithFallbackFlow().kickoff()
|
||||
|
||||
assert "done" in execution_order
|
||||
assert execution_order.index("done") > execution_order.index("a")
|
||||
assert execution_order.index("done") > execution_order.index("b")
|
||||
|
||||
|
||||
def test_mixed_sync_async_execution_order():
|
||||
|
||||
Reference in New Issue
Block a user