fix: preserve nested condition structure in Flow decorators
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Update Test Durations / update-durations (3.10) (push) Has been cancelled
Update Test Durations / update-durations (3.11) (push) Has been cancelled
Update Test Durations / update-durations (3.12) (push) Has been cancelled
Update Test Durations / update-durations (3.13) (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled

Fixes nested boolean conditions being flattened in @listen, @start, and @router decorators. The or_() and and_() combinators now preserve their nested structure using a "conditions" key instead of flattening to a list. Added recursive evaluation logic to properly handle complex patterns like or_(and_(A, B), and_(C, D)).
This commit is contained in:
Greyson LaLonde
2025-10-17 17:06:19 -04:00
committed by GitHub
parent 0229390ad1
commit 42f2b4d551
2 changed files with 275 additions and 71 deletions

View File

@@ -31,7 +31,7 @@ from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.persistence.base import FlowPersistence from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData from crewai.flow.types import FlowExecutionData
from crewai.flow.utils import get_possible_return_constants from crewai.flow.utils import get_possible_return_constants
from crewai.utilities.printer import Printer from crewai.utilities.printer import Printer, PrinterColor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -105,7 +105,7 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
condition : Optional[Union[str, dict, Callable]], optional condition : Optional[Union[str, dict, Callable]], optional
Defines when the start method should execute. Can be: Defines when the start method should execute. Can be:
- str: Name of a method that triggers this start - str: Name of a method that triggers this start
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers) - dict: Result from or_() or and_(), including nested conditions
- Callable: A method reference that triggers this start - Callable: A method reference that triggers this start
Default is None, meaning unconditional start. Default is None, meaning unconditional start.
@@ -140,13 +140,18 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
if isinstance(condition, str): if isinstance(condition, str):
func.__trigger_methods__ = [condition] func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR" func.__condition_type__ = "OR"
elif ( elif isinstance(condition, dict) and "type" in condition:
isinstance(condition, dict) if "conditions" in condition:
and "type" in condition func.__trigger_condition__ = condition
and "methods" in condition func.__trigger_methods__ = _extract_all_methods(condition)
): func.__condition_type__ = condition["type"]
func.__trigger_methods__ = condition["methods"] elif "methods" in condition:
func.__condition_type__ = condition["type"] func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
else:
raise ValueError(
"Condition dict must contain 'conditions' or 'methods'"
)
elif callable(condition) and hasattr(condition, "__name__"): elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__] func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR" func.__condition_type__ = "OR"
@@ -172,7 +177,7 @@ def listen(condition: str | dict | Callable) -> Callable:
condition : Union[str, dict, Callable] condition : Union[str, dict, Callable]
Specifies when the listener should execute. Can be: Specifies when the listener should execute. Can be:
- str: Name of a method that triggers this listener - str: Name of a method that triggers this listener
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers) - dict: Result from or_() or and_(), including nested conditions
- Callable: A method reference that triggers this listener - Callable: A method reference that triggers this listener
Returns Returns
@@ -200,13 +205,18 @@ def listen(condition: str | dict | Callable) -> Callable:
if isinstance(condition, str): if isinstance(condition, str):
func.__trigger_methods__ = [condition] func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR" func.__condition_type__ = "OR"
elif ( elif isinstance(condition, dict) and "type" in condition:
isinstance(condition, dict) if "conditions" in condition:
and "type" in condition func.__trigger_condition__ = condition
and "methods" in condition func.__trigger_methods__ = _extract_all_methods(condition)
): func.__condition_type__ = condition["type"]
func.__trigger_methods__ = condition["methods"] elif "methods" in condition:
func.__condition_type__ = condition["type"] func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
else:
raise ValueError(
"Condition dict must contain 'conditions' or 'methods'"
)
elif callable(condition) and hasattr(condition, "__name__"): elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__] func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR" func.__condition_type__ = "OR"
@@ -233,7 +243,7 @@ def router(condition: str | dict | Callable) -> Callable:
condition : Union[str, dict, Callable] condition : Union[str, dict, Callable]
Specifies when the router should execute. Can be: Specifies when the router should execute. Can be:
- str: Name of a method that triggers this router - str: Name of a method that triggers this router
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers) - dict: Result from or_() or and_(), including nested conditions
- Callable: A method reference that triggers this router - Callable: A method reference that triggers this router
Returns Returns
@@ -266,13 +276,18 @@ def router(condition: str | dict | Callable) -> Callable:
if isinstance(condition, str): if isinstance(condition, str):
func.__trigger_methods__ = [condition] func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR" func.__condition_type__ = "OR"
elif ( elif isinstance(condition, dict) and "type" in condition:
isinstance(condition, dict) if "conditions" in condition:
and "type" in condition func.__trigger_condition__ = condition
and "methods" in condition func.__trigger_methods__ = _extract_all_methods(condition)
): func.__condition_type__ = condition["type"]
func.__trigger_methods__ = condition["methods"] elif "methods" in condition:
func.__condition_type__ = condition["type"] func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
else:
raise ValueError(
"Condition dict must contain 'conditions' or 'methods'"
)
elif callable(condition) and hasattr(condition, "__name__"): elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__] func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR" func.__condition_type__ = "OR"
@@ -298,14 +313,15 @@ def or_(*conditions: str | dict | Callable) -> dict:
*conditions : Union[str, dict, Callable] *conditions : Union[str, dict, Callable]
Variable number of conditions that can be: Variable number of conditions that can be:
- str: Method names - str: Method names
- dict: Existing condition dictionaries - dict: Existing condition dictionaries (nested conditions)
- Callable: Method references - Callable: Method references
Returns Returns
------- -------
dict dict
A condition dictionary with format: A condition dictionary with format:
{"type": "OR", "methods": list_of_method_names} {"type": "OR", "conditions": list_of_conditions}
where each condition can be a string (method name) or a nested dict
Raises Raises
------ ------
@@ -317,18 +333,22 @@ def or_(*conditions: str | dict | Callable) -> dict:
>>> @listen(or_("success", "timeout")) >>> @listen(or_("success", "timeout"))
>>> def handle_completion(self): >>> def handle_completion(self):
... pass ... pass
>>> @listen(or_(and_("step1", "step2"), "step3"))
>>> def handle_nested(self):
... pass
""" """
methods = [] processed_conditions: list[str | dict[str, Any]] = []
for condition in conditions: for condition in conditions:
if isinstance(condition, dict) and "methods" in condition: if isinstance(condition, dict):
methods.extend(condition["methods"]) processed_conditions.append(condition)
elif isinstance(condition, str): elif isinstance(condition, str):
methods.append(condition) processed_conditions.append(condition)
elif callable(condition): elif callable(condition):
methods.append(getattr(condition, "__name__", repr(condition))) processed_conditions.append(getattr(condition, "__name__", repr(condition)))
else: else:
raise ValueError("Invalid condition in or_()") raise ValueError("Invalid condition in or_()")
return {"type": "OR", "methods": methods} return {"type": "OR", "conditions": processed_conditions}
def and_(*conditions: str | dict | Callable) -> dict: def and_(*conditions: str | dict | Callable) -> dict:
@@ -344,14 +364,15 @@ def and_(*conditions: str | dict | Callable) -> dict:
*conditions : Union[str, dict, Callable] *conditions : Union[str, dict, Callable]
Variable number of conditions that can be: Variable number of conditions that can be:
- str: Method names - str: Method names
- dict: Existing condition dictionaries - dict: Existing condition dictionaries (nested conditions)
- Callable: Method references - Callable: Method references
Returns Returns
------- -------
dict dict
A condition dictionary with format: A condition dictionary with format:
{"type": "AND", "methods": list_of_method_names} {"type": "AND", "conditions": list_of_conditions}
where each condition can be a string (method name) or a nested dict
Raises Raises
------ ------
@@ -363,18 +384,69 @@ def and_(*conditions: str | dict | Callable) -> dict:
>>> @listen(and_("validated", "processed")) >>> @listen(and_("validated", "processed"))
>>> def handle_complete_data(self): >>> def handle_complete_data(self):
... pass ... pass
>>> @listen(and_(or_("step1", "step2"), "step3"))
>>> def handle_nested(self):
... pass
""" """
methods = [] processed_conditions: list[str | dict[str, Any]] = []
for condition in conditions: for condition in conditions:
if isinstance(condition, dict) and "methods" in condition: if isinstance(condition, dict):
methods.extend(condition["methods"]) processed_conditions.append(condition)
elif isinstance(condition, str): elif isinstance(condition, str):
methods.append(condition) processed_conditions.append(condition)
elif callable(condition): elif callable(condition):
methods.append(getattr(condition, "__name__", repr(condition))) processed_conditions.append(getattr(condition, "__name__", repr(condition)))
else: else:
raise ValueError("Invalid condition in and_()") raise ValueError("Invalid condition in and_()")
return {"type": "AND", "methods": methods} return {"type": "AND", "conditions": processed_conditions}
def _normalize_condition(condition: str | dict | list) -> dict:
"""Normalize a condition to standard format with 'conditions' key.
Args:
condition: Can be a string (method name), dict (condition), or list
Returns:
Normalized dict with 'type' and 'conditions' keys
"""
if isinstance(condition, str):
return {"type": "OR", "conditions": [condition]}
if isinstance(condition, dict):
if "conditions" in condition:
return condition
if "methods" in condition:
return {"type": condition["type"], "conditions": condition["methods"]}
return condition
if isinstance(condition, list):
return {"type": "OR", "conditions": condition}
return {"type": "OR", "conditions": [condition]}
def _extract_all_methods(condition: str | dict | list) -> list[str]:
"""Extract all method names from a condition (including nested).
Args:
condition: Can be a string, dict, or list
Returns:
List of all method names in the condition tree
"""
if isinstance(condition, str):
return [condition]
if isinstance(condition, dict):
normalized = _normalize_condition(condition)
methods = []
for sub_cond in normalized.get("conditions", []):
methods.extend(_extract_all_methods(sub_cond))
return methods
if isinstance(condition, list):
methods = []
for item in condition:
methods.extend(_extract_all_methods(item))
return methods
return []
class FlowMeta(type): class FlowMeta(type):
@@ -402,7 +474,10 @@ class FlowMeta(type):
if hasattr(attr_value, "__trigger_methods__"): if hasattr(attr_value, "__trigger_methods__"):
methods = attr_value.__trigger_methods__ methods = attr_value.__trigger_methods__
condition_type = getattr(attr_value, "__condition_type__", "OR") condition_type = getattr(attr_value, "__condition_type__", "OR")
listeners[attr_name] = (condition_type, methods) if hasattr(attr_value, "__trigger_condition__"):
listeners[attr_name] = attr_value.__trigger_condition__
else:
listeners[attr_name] = (condition_type, methods)
if ( if (
hasattr(attr_value, "__is_router__") hasattr(attr_value, "__is_router__")
@@ -822,6 +897,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Clear completed methods and outputs for a fresh start # Clear completed methods and outputs for a fresh start
self._completed_methods.clear() self._completed_methods.clear()
self._method_outputs.clear() self._method_outputs.clear()
self._pending_and_listeners.clear()
else: else:
# We're restoring from persistence, set the flag # We're restoring from persistence, set the flag
self._is_execution_resuming = True self._is_execution_resuming = True
@@ -1086,10 +1162,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
for method_name in self._start_methods: for method_name in self._start_methods:
# Check if this start method is triggered by the current trigger # Check if this start method is triggered by the current trigger
if method_name in self._listeners: if method_name in self._listeners:
condition_type, trigger_methods = self._listeners[ condition_data = self._listeners[method_name]
method_name should_trigger = False
] if isinstance(condition_data, tuple):
if current_trigger in trigger_methods: _, trigger_methods = condition_data
should_trigger = current_trigger in trigger_methods
elif isinstance(condition_data, dict):
all_methods = _extract_all_methods(condition_data)
should_trigger = current_trigger in all_methods
if should_trigger:
# Only execute if this is a cycle (method was already completed) # Only execute if this is a cycle (method was already completed)
if method_name in self._completed_methods: if method_name in self._completed_methods:
# For router-triggered start methods in cycles, temporarily clear resumption flag # For router-triggered start methods in cycles, temporarily clear resumption flag
@@ -1099,6 +1181,51 @@ class Flow(Generic[T], metaclass=FlowMeta):
await self._execute_start_method(method_name) await self._execute_start_method(method_name)
self._is_execution_resuming = was_resuming self._is_execution_resuming = was_resuming
def _evaluate_condition(
self, condition: str | dict, trigger_method: str, listener_name: str
) -> bool:
"""Recursively evaluate a condition (simple or nested).
Args:
condition: Can be a string (method name) or dict (nested condition)
trigger_method: The method that just completed
listener_name: Name of the listener being evaluated
Returns:
True if the condition is satisfied, False otherwise
"""
if isinstance(condition, str):
return condition == trigger_method
if isinstance(condition, dict):
normalized = _normalize_condition(condition)
cond_type = normalized.get("type", "OR")
sub_conditions = normalized.get("conditions", [])
if cond_type == "OR":
return any(
self._evaluate_condition(sub_cond, trigger_method, listener_name)
for sub_cond in sub_conditions
)
if cond_type == "AND":
pending_key = f"{listener_name}:{id(condition)}"
if pending_key not in self._pending_and_listeners:
all_methods = set(_extract_all_methods(condition))
self._pending_and_listeners[pending_key] = all_methods
if trigger_method in self._pending_and_listeners[pending_key]:
self._pending_and_listeners[pending_key].discard(trigger_method)
if not self._pending_and_listeners[pending_key]:
self._pending_and_listeners.pop(pending_key, None)
return True
return False
return False
def _find_triggered_methods( def _find_triggered_methods(
self, trigger_method: str, router_only: bool self, trigger_method: str, router_only: bool
) -> list[str]: ) -> list[str]:
@@ -1106,7 +1233,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
Finds all methods that should be triggered based on conditions. Finds all methods that should be triggered based on conditions.
This internal method evaluates both OR and AND conditions to determine This internal method evaluates both OR and AND conditions to determine
which methods should be executed next in the flow. which methods should be executed next in the flow. Supports nested conditions.
Parameters Parameters
---------- ----------
@@ -1123,14 +1250,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
Notes Notes
----- -----
- Handles both OR and AND conditions: - Handles both OR and AND conditions, including nested combinations
* OR: Triggers if any condition is met
* AND: Triggers only when all conditions are met
- Maintains state for AND conditions using _pending_and_listeners - Maintains state for AND conditions using _pending_and_listeners
- Separates router and normal listener evaluation - Separates router and normal listener evaluation
""" """
triggered = [] triggered = []
for listener_name, (condition_type, methods) in self._listeners.items():
for listener_name, condition_data in self._listeners.items():
is_router = listener_name in self._routers is_router = listener_name in self._routers
if router_only != is_router: if router_only != is_router:
@@ -1139,23 +1265,29 @@ class Flow(Generic[T], metaclass=FlowMeta):
if not router_only and listener_name in self._start_methods: if not router_only and listener_name in self._start_methods:
continue continue
if condition_type == "OR": if isinstance(condition_data, tuple):
# If the trigger_method matches any in methods, run this condition_type, methods = condition_data
if trigger_method in methods:
triggered.append(listener_name)
elif condition_type == "AND":
# Initialize pending methods for this listener if not already done
if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods)
# Remove the trigger method from pending methods
if trigger_method in self._pending_and_listeners[listener_name]:
self._pending_and_listeners[listener_name].discard(trigger_method)
if not self._pending_and_listeners[listener_name]: if condition_type == "OR":
# All required methods have been executed if trigger_method in methods:
triggered.append(listener_name)
elif condition_type == "AND":
if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods)
if trigger_method in self._pending_and_listeners[listener_name]:
self._pending_and_listeners[listener_name].discard(
trigger_method
)
if not self._pending_and_listeners[listener_name]:
triggered.append(listener_name)
self._pending_and_listeners.pop(listener_name, None)
elif isinstance(condition_data, dict):
if self._evaluate_condition(
condition_data, trigger_method, listener_name
):
triggered.append(listener_name) triggered.append(listener_name)
# Reset pending methods for this listener
self._pending_and_listeners.pop(listener_name, None)
return triggered return triggered
@@ -1218,7 +1350,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
raise raise
def _log_flow_event( def _log_flow_event(
self, message: str, color: str = "yellow", level: str = "info" self, message: str, color: PrinterColor | None = "yellow", level: str = "info"
) -> None: ) -> None:
"""Centralized logging method for flow events. """Centralized logging method for flow events.

View File

@@ -6,15 +6,15 @@ from datetime import datetime
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.flow_events import ( from crewai.events.types.flow_events import (
FlowFinishedEvent, FlowFinishedEvent,
FlowStartedEvent,
FlowPlotEvent, FlowPlotEvent,
FlowStartedEvent,
MethodExecutionFinishedEvent, MethodExecutionFinishedEvent,
MethodExecutionStartedEvent, MethodExecutionStartedEvent,
) )
from crewai.flow.flow import Flow, and_, listen, or_, router, start
def test_simple_sequential_flow(): def test_simple_sequential_flow():
@@ -679,11 +679,11 @@ def test_structured_flow_event_emission():
assert isinstance(received_events[3], MethodExecutionStartedEvent) assert isinstance(received_events[3], MethodExecutionStartedEvent)
assert received_events[3].method_name == "send_welcome_message" assert received_events[3].method_name == "send_welcome_message"
assert received_events[3].params == {} assert received_events[3].params == {}
assert getattr(received_events[3].state, "sent") is False assert received_events[3].state.sent is False
assert isinstance(received_events[4], MethodExecutionFinishedEvent) assert isinstance(received_events[4], MethodExecutionFinishedEvent)
assert received_events[4].method_name == "send_welcome_message" assert received_events[4].method_name == "send_welcome_message"
assert getattr(received_events[4].state, "sent") is True assert received_events[4].state.sent is True
assert received_events[4].result == "Welcome, Anakin!" assert received_events[4].result == "Welcome, Anakin!"
assert isinstance(received_events[5], FlowFinishedEvent) assert isinstance(received_events[5], FlowFinishedEvent)
@@ -894,3 +894,75 @@ def test_flow_name():
flow = MyFlow() flow = MyFlow()
assert flow.name == "MyFlow" assert flow.name == "MyFlow"
def test_nested_and_or_conditions():
"""Test nested conditions like or_(and_(A, B), and_(C, D)).
Reproduces bug from issue #3719 where nested conditions are flattened,
causing premature execution.
"""
execution_order = []
class NestedConditionFlow(Flow):
@start()
def method_1(self):
execution_order.append("method_1")
@listen(method_1)
def method_2(self):
execution_order.append("method_2")
@router(method_2)
def method_3(self):
execution_order.append("method_3")
# Choose b_condition path
return "b_condition"
@listen("b_condition")
def method_5(self):
execution_order.append("method_5")
@listen(method_5)
async def method_4(self):
execution_order.append("method_4")
@listen(or_("a_condition", "b_condition"))
async def method_6(self):
execution_order.append("method_6")
@listen(
or_(
and_("a_condition", method_6),
and_(method_6, method_4),
)
)
def method_7(self):
execution_order.append("method_7")
@listen(method_7)
async def method_8(self):
execution_order.append("method_8")
flow = NestedConditionFlow()
flow.kickoff()
# Verify execution happened
assert "method_1" in execution_order
assert "method_2" in execution_order
assert "method_3" in execution_order
assert "method_5" in execution_order
assert "method_4" in execution_order
assert "method_6" in execution_order
assert "method_7" in execution_order
assert "method_8" in execution_order
# Critical assertion: method_7 should only execute AFTER both method_6 AND method_4
# Since b_condition was returned, method_6 triggers on b_condition
# method_7 requires: (a_condition AND method_6) OR (method_6 AND method_4)
# The second condition (method_6 AND method_4) should be the one that triggers
assert execution_order.index("method_7") > execution_order.index("method_6")
assert execution_order.index("method_7") > execution_order.index("method_4")
# method_8 should execute after method_7
assert execution_order.index("method_8") > execution_order.index("method_7")