diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 85bb077ee..84783b081 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -31,7 +31,7 @@ from crewai.flow.flow_visualizer import plot_flow from crewai.flow.persistence.base import FlowPersistence from crewai.flow.types import FlowExecutionData from crewai.flow.utils import get_possible_return_constants -from crewai.utilities.printer import Printer +from crewai.utilities.printer import Printer, PrinterColor logger = logging.getLogger(__name__) @@ -105,7 +105,7 @@ def start(condition: str | dict | Callable | None = None) -> Callable: condition : Optional[Union[str, dict, Callable]], optional Defines when the start method should execute. Can be: - str: Name of a method that triggers this start - - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers) + - dict: Result from or_() or and_(), including nested conditions - Callable: A method reference that triggers this start Default is None, meaning unconditional start. @@ -140,13 +140,18 @@ def start(condition: str | dict | Callable | None = None) -> Callable: if isinstance(condition, str): func.__trigger_methods__ = [condition] func.__condition_type__ = "OR" - elif ( - isinstance(condition, dict) - and "type" in condition - and "methods" in condition - ): - func.__trigger_methods__ = condition["methods"] - func.__condition_type__ = condition["type"] + elif isinstance(condition, dict) and "type" in condition: + if "conditions" in condition: + func.__trigger_condition__ = condition + func.__trigger_methods__ = _extract_all_methods(condition) + func.__condition_type__ = condition["type"] + elif "methods" in condition: + func.__trigger_methods__ = condition["methods"] + func.__condition_type__ = condition["type"] + else: + raise ValueError( + "Condition dict must contain 'conditions' or 'methods'" + ) elif callable(condition) and hasattr(condition, "__name__"): func.__trigger_methods__ = [condition.__name__] func.__condition_type__ = "OR" @@ -172,7 +177,7 @@ def listen(condition: str | dict | Callable) -> Callable: condition : Union[str, dict, Callable] Specifies when the listener should execute. Can be: - str: Name of a method that triggers this listener - - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers) + - dict: Result from or_() or and_(), including nested conditions - Callable: A method reference that triggers this listener Returns @@ -200,13 +205,18 @@ def listen(condition: str | dict | Callable) -> Callable: if isinstance(condition, str): func.__trigger_methods__ = [condition] func.__condition_type__ = "OR" - elif ( - isinstance(condition, dict) - and "type" in condition - and "methods" in condition - ): - func.__trigger_methods__ = condition["methods"] - func.__condition_type__ = condition["type"] + elif isinstance(condition, dict) and "type" in condition: + if "conditions" in condition: + func.__trigger_condition__ = condition + func.__trigger_methods__ = _extract_all_methods(condition) + func.__condition_type__ = condition["type"] + elif "methods" in condition: + func.__trigger_methods__ = condition["methods"] + func.__condition_type__ = condition["type"] + else: + raise ValueError( + "Condition dict must contain 'conditions' or 'methods'" + ) elif callable(condition) and hasattr(condition, "__name__"): func.__trigger_methods__ = [condition.__name__] func.__condition_type__ = "OR" @@ -233,7 +243,7 @@ def router(condition: str | dict | Callable) -> Callable: condition : Union[str, dict, Callable] Specifies when the router should execute. Can be: - str: Name of a method that triggers this router - - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers) + - dict: Result from or_() or and_(), including nested conditions - Callable: A method reference that triggers this router Returns @@ -266,13 +276,18 @@ def router(condition: str | dict | Callable) -> Callable: if isinstance(condition, str): func.__trigger_methods__ = [condition] func.__condition_type__ = "OR" - elif ( - isinstance(condition, dict) - and "type" in condition - and "methods" in condition - ): - func.__trigger_methods__ = condition["methods"] - func.__condition_type__ = condition["type"] + elif isinstance(condition, dict) and "type" in condition: + if "conditions" in condition: + func.__trigger_condition__ = condition + func.__trigger_methods__ = _extract_all_methods(condition) + func.__condition_type__ = condition["type"] + elif "methods" in condition: + func.__trigger_methods__ = condition["methods"] + func.__condition_type__ = condition["type"] + else: + raise ValueError( + "Condition dict must contain 'conditions' or 'methods'" + ) elif callable(condition) and hasattr(condition, "__name__"): func.__trigger_methods__ = [condition.__name__] func.__condition_type__ = "OR" @@ -298,14 +313,15 @@ def or_(*conditions: str | dict | Callable) -> dict: *conditions : Union[str, dict, Callable] Variable number of conditions that can be: - str: Method names - - dict: Existing condition dictionaries + - dict: Existing condition dictionaries (nested conditions) - Callable: Method references Returns ------- dict A condition dictionary with format: - {"type": "OR", "methods": list_of_method_names} + {"type": "OR", "conditions": list_of_conditions} + where each condition can be a string (method name) or a nested dict Raises ------ @@ -317,18 +333,22 @@ def or_(*conditions: str | dict | Callable) -> dict: >>> @listen(or_("success", "timeout")) >>> def handle_completion(self): ... pass + + >>> @listen(or_(and_("step1", "step2"), "step3")) + >>> def handle_nested(self): + ... pass """ - methods = [] + processed_conditions: list[str | dict[str, Any]] = [] for condition in conditions: - if isinstance(condition, dict) and "methods" in condition: - methods.extend(condition["methods"]) + if isinstance(condition, dict): + processed_conditions.append(condition) elif isinstance(condition, str): - methods.append(condition) + processed_conditions.append(condition) elif callable(condition): - methods.append(getattr(condition, "__name__", repr(condition))) + processed_conditions.append(getattr(condition, "__name__", repr(condition))) else: raise ValueError("Invalid condition in or_()") - return {"type": "OR", "methods": methods} + return {"type": "OR", "conditions": processed_conditions} def and_(*conditions: str | dict | Callable) -> dict: @@ -344,14 +364,15 @@ def and_(*conditions: str | dict | Callable) -> dict: *conditions : Union[str, dict, Callable] Variable number of conditions that can be: - str: Method names - - dict: Existing condition dictionaries + - dict: Existing condition dictionaries (nested conditions) - Callable: Method references Returns ------- dict A condition dictionary with format: - {"type": "AND", "methods": list_of_method_names} + {"type": "AND", "conditions": list_of_conditions} + where each condition can be a string (method name) or a nested dict Raises ------ @@ -363,18 +384,69 @@ def and_(*conditions: str | dict | Callable) -> dict: >>> @listen(and_("validated", "processed")) >>> def handle_complete_data(self): ... pass + + >>> @listen(and_(or_("step1", "step2"), "step3")) + >>> def handle_nested(self): + ... pass """ - methods = [] + processed_conditions: list[str | dict[str, Any]] = [] for condition in conditions: - if isinstance(condition, dict) and "methods" in condition: - methods.extend(condition["methods"]) + if isinstance(condition, dict): + processed_conditions.append(condition) elif isinstance(condition, str): - methods.append(condition) + processed_conditions.append(condition) elif callable(condition): - methods.append(getattr(condition, "__name__", repr(condition))) + processed_conditions.append(getattr(condition, "__name__", repr(condition))) else: raise ValueError("Invalid condition in and_()") - return {"type": "AND", "methods": methods} + return {"type": "AND", "conditions": processed_conditions} + + +def _normalize_condition(condition: str | dict | list) -> dict: + """Normalize a condition to standard format with 'conditions' key. + + Args: + condition: Can be a string (method name), dict (condition), or list + + Returns: + Normalized dict with 'type' and 'conditions' keys + """ + if isinstance(condition, str): + return {"type": "OR", "conditions": [condition]} + if isinstance(condition, dict): + if "conditions" in condition: + return condition + if "methods" in condition: + return {"type": condition["type"], "conditions": condition["methods"]} + return condition + if isinstance(condition, list): + return {"type": "OR", "conditions": condition} + return {"type": "OR", "conditions": [condition]} + + +def _extract_all_methods(condition: str | dict | list) -> list[str]: + """Extract all method names from a condition (including nested). + + Args: + condition: Can be a string, dict, or list + + Returns: + List of all method names in the condition tree + """ + if isinstance(condition, str): + return [condition] + if isinstance(condition, dict): + normalized = _normalize_condition(condition) + methods = [] + for sub_cond in normalized.get("conditions", []): + methods.extend(_extract_all_methods(sub_cond)) + return methods + if isinstance(condition, list): + methods = [] + for item in condition: + methods.extend(_extract_all_methods(item)) + return methods + return [] class FlowMeta(type): @@ -402,7 +474,10 @@ class FlowMeta(type): if hasattr(attr_value, "__trigger_methods__"): methods = attr_value.__trigger_methods__ condition_type = getattr(attr_value, "__condition_type__", "OR") - listeners[attr_name] = (condition_type, methods) + if hasattr(attr_value, "__trigger_condition__"): + listeners[attr_name] = attr_value.__trigger_condition__ + else: + listeners[attr_name] = (condition_type, methods) if ( hasattr(attr_value, "__is_router__") @@ -822,6 +897,7 @@ class Flow(Generic[T], metaclass=FlowMeta): # Clear completed methods and outputs for a fresh start self._completed_methods.clear() self._method_outputs.clear() + self._pending_and_listeners.clear() else: # We're restoring from persistence, set the flag self._is_execution_resuming = True @@ -1086,10 +1162,16 @@ class Flow(Generic[T], metaclass=FlowMeta): for method_name in self._start_methods: # Check if this start method is triggered by the current trigger if method_name in self._listeners: - condition_type, trigger_methods = self._listeners[ - method_name - ] - if current_trigger in trigger_methods: + condition_data = self._listeners[method_name] + should_trigger = False + if isinstance(condition_data, tuple): + _, trigger_methods = condition_data + should_trigger = current_trigger in trigger_methods + elif isinstance(condition_data, dict): + all_methods = _extract_all_methods(condition_data) + should_trigger = current_trigger in all_methods + + if should_trigger: # Only execute if this is a cycle (method was already completed) if method_name in self._completed_methods: # For router-triggered start methods in cycles, temporarily clear resumption flag @@ -1099,6 +1181,51 @@ class Flow(Generic[T], metaclass=FlowMeta): await self._execute_start_method(method_name) self._is_execution_resuming = was_resuming + def _evaluate_condition( + self, condition: str | dict, trigger_method: str, listener_name: str + ) -> bool: + """Recursively evaluate a condition (simple or nested). + + Args: + condition: Can be a string (method name) or dict (nested condition) + trigger_method: The method that just completed + listener_name: Name of the listener being evaluated + + Returns: + True if the condition is satisfied, False otherwise + """ + if isinstance(condition, str): + return condition == trigger_method + + if isinstance(condition, dict): + normalized = _normalize_condition(condition) + cond_type = normalized.get("type", "OR") + sub_conditions = normalized.get("conditions", []) + + if cond_type == "OR": + return any( + self._evaluate_condition(sub_cond, trigger_method, listener_name) + for sub_cond in sub_conditions + ) + + if cond_type == "AND": + pending_key = f"{listener_name}:{id(condition)}" + + if pending_key not in self._pending_and_listeners: + all_methods = set(_extract_all_methods(condition)) + self._pending_and_listeners[pending_key] = all_methods + + if trigger_method in self._pending_and_listeners[pending_key]: + self._pending_and_listeners[pending_key].discard(trigger_method) + + if not self._pending_and_listeners[pending_key]: + self._pending_and_listeners.pop(pending_key, None) + return True + + return False + + return False + def _find_triggered_methods( self, trigger_method: str, router_only: bool ) -> list[str]: @@ -1106,7 +1233,7 @@ class Flow(Generic[T], metaclass=FlowMeta): Finds all methods that should be triggered based on conditions. This internal method evaluates both OR and AND conditions to determine - which methods should be executed next in the flow. + which methods should be executed next in the flow. Supports nested conditions. Parameters ---------- @@ -1123,14 +1250,13 @@ class Flow(Generic[T], metaclass=FlowMeta): Notes ----- - - Handles both OR and AND conditions: - * OR: Triggers if any condition is met - * AND: Triggers only when all conditions are met + - Handles both OR and AND conditions, including nested combinations - Maintains state for AND conditions using _pending_and_listeners - Separates router and normal listener evaluation """ triggered = [] - for listener_name, (condition_type, methods) in self._listeners.items(): + + for listener_name, condition_data in self._listeners.items(): is_router = listener_name in self._routers if router_only != is_router: @@ -1139,23 +1265,29 @@ class Flow(Generic[T], metaclass=FlowMeta): if not router_only and listener_name in self._start_methods: continue - if condition_type == "OR": - # If the trigger_method matches any in methods, run this - if trigger_method in methods: - triggered.append(listener_name) - elif condition_type == "AND": - # Initialize pending methods for this listener if not already done - if listener_name not in self._pending_and_listeners: - self._pending_and_listeners[listener_name] = set(methods) - # Remove the trigger method from pending methods - if trigger_method in self._pending_and_listeners[listener_name]: - self._pending_and_listeners[listener_name].discard(trigger_method) + if isinstance(condition_data, tuple): + condition_type, methods = condition_data - if not self._pending_and_listeners[listener_name]: - # All required methods have been executed + if condition_type == "OR": + if trigger_method in methods: + triggered.append(listener_name) + elif condition_type == "AND": + if listener_name not in self._pending_and_listeners: + self._pending_and_listeners[listener_name] = set(methods) + if trigger_method in self._pending_and_listeners[listener_name]: + self._pending_and_listeners[listener_name].discard( + trigger_method + ) + + if not self._pending_and_listeners[listener_name]: + triggered.append(listener_name) + self._pending_and_listeners.pop(listener_name, None) + + elif isinstance(condition_data, dict): + if self._evaluate_condition( + condition_data, trigger_method, listener_name + ): triggered.append(listener_name) - # Reset pending methods for this listener - self._pending_and_listeners.pop(listener_name, None) return triggered @@ -1218,7 +1350,7 @@ class Flow(Generic[T], metaclass=FlowMeta): raise def _log_flow_event( - self, message: str, color: str = "yellow", level: str = "info" + self, message: str, color: PrinterColor | None = "yellow", level: str = "info" ) -> None: """Centralized logging method for flow events. diff --git a/tests/test_flow.py b/tests/test_flow.py index 504cf8e6e..f060a7a19 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -6,15 +6,15 @@ from datetime import datetime import pytest from pydantic import BaseModel -from crewai.flow.flow import Flow, and_, listen, or_, router, start from crewai.events.event_bus import crewai_event_bus from crewai.events.types.flow_events import ( FlowFinishedEvent, - FlowStartedEvent, FlowPlotEvent, + FlowStartedEvent, MethodExecutionFinishedEvent, MethodExecutionStartedEvent, ) +from crewai.flow.flow import Flow, and_, listen, or_, router, start def test_simple_sequential_flow(): @@ -679,11 +679,11 @@ def test_structured_flow_event_emission(): assert isinstance(received_events[3], MethodExecutionStartedEvent) assert received_events[3].method_name == "send_welcome_message" assert received_events[3].params == {} - assert getattr(received_events[3].state, "sent") is False + assert received_events[3].state.sent is False assert isinstance(received_events[4], MethodExecutionFinishedEvent) assert received_events[4].method_name == "send_welcome_message" - assert getattr(received_events[4].state, "sent") is True + assert received_events[4].state.sent is True assert received_events[4].result == "Welcome, Anakin!" assert isinstance(received_events[5], FlowFinishedEvent) @@ -894,3 +894,75 @@ def test_flow_name(): flow = MyFlow() assert flow.name == "MyFlow" + + +def test_nested_and_or_conditions(): + """Test nested conditions like or_(and_(A, B), and_(C, D)). + + Reproduces bug from issue #3719 where nested conditions are flattened, + causing premature execution. + """ + execution_order = [] + + class NestedConditionFlow(Flow): + @start() + def method_1(self): + execution_order.append("method_1") + + @listen(method_1) + def method_2(self): + execution_order.append("method_2") + + @router(method_2) + def method_3(self): + execution_order.append("method_3") + # Choose b_condition path + return "b_condition" + + @listen("b_condition") + def method_5(self): + execution_order.append("method_5") + + @listen(method_5) + async def method_4(self): + execution_order.append("method_4") + + @listen(or_("a_condition", "b_condition")) + async def method_6(self): + execution_order.append("method_6") + + @listen( + or_( + and_("a_condition", method_6), + and_(method_6, method_4), + ) + ) + def method_7(self): + execution_order.append("method_7") + + @listen(method_7) + async def method_8(self): + execution_order.append("method_8") + + flow = NestedConditionFlow() + flow.kickoff() + + # Verify execution happened + assert "method_1" in execution_order + assert "method_2" in execution_order + assert "method_3" in execution_order + assert "method_5" in execution_order + assert "method_4" in execution_order + assert "method_6" in execution_order + assert "method_7" in execution_order + assert "method_8" in execution_order + + # Critical assertion: method_7 should only execute AFTER both method_6 AND method_4 + # Since b_condition was returned, method_6 triggers on b_condition + # method_7 requires: (a_condition AND method_6) OR (method_6 AND method_4) + # The second condition (method_6 AND method_4) should be the one that triggers + assert execution_order.index("method_7") > execution_order.index("method_6") + assert execution_order.index("method_7") > execution_order.index("method_4") + + # method_8 should execute after method_7 + assert execution_order.index("method_8") > execution_order.index("method_7")