diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index ce206fbf5..f7e6dd84b 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -32,7 +32,7 @@ from crewai.flow.flow_visualizer import plot_flow from crewai.flow.persistence.base import FlowPersistence from crewai.flow.types import FlowExecutionData from crewai.flow.utils import get_possible_return_constants -from crewai.utilities.printer import Printer +from crewai.utilities.printer import Printer, PrinterColor logger = logging.getLogger(__name__) @@ -56,19 +56,6 @@ StateT = TypeVar( def ensure_state_type(state: Any, expected_type: type[StateT]) -> StateT: """Ensure state matches expected type with proper validation. - Args: - state: State instance to validate - expected_type: Expected type for the state - - Returns: - Validated state instance - - Raises: - TypeError: If state doesn't match expected type - ValueError: If state validation fails - """ - """Ensure state matches expected type with proper validation. - Args: state: State instance to validate expected_type: Expected type for the state @@ -106,7 +93,7 @@ def start(condition: str | dict | Callable | None = None) -> Callable: condition : Optional[Union[str, dict, Callable]], optional Defines when the start method should execute. Can be: - str: Name of a method that triggers this start - - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers) + - dict: Result from or_() or and_(), including nested conditions - Callable: A method reference that triggers this start Default is None, meaning unconditional start. @@ -141,13 +128,18 @@ def start(condition: str | dict | Callable | None = None) -> Callable: if isinstance(condition, str): func.__trigger_methods__ = [condition] func.__condition_type__ = "OR" - elif ( - isinstance(condition, dict) - and "type" in condition - and "methods" in condition - ): - func.__trigger_methods__ = condition["methods"] - func.__condition_type__ = condition["type"] + elif isinstance(condition, dict) and "type" in condition: + if "conditions" in condition: + func.__trigger_condition__ = condition + func.__trigger_methods__ = _extract_all_methods(condition) + func.__condition_type__ = condition["type"] + elif "methods" in condition: + func.__trigger_methods__ = condition["methods"] + func.__condition_type__ = condition["type"] + else: + raise ValueError( + "Condition dict must contain 'conditions' or 'methods'" + ) elif callable(condition) and hasattr(condition, "__name__"): func.__trigger_methods__ = [condition.__name__] func.__condition_type__ = "OR" @@ -173,7 +165,7 @@ def listen(condition: str | dict | Callable) -> Callable: condition : Union[str, dict, Callable] Specifies when the listener should execute. Can be: - str: Name of a method that triggers this listener - - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers) + - dict: Result from or_() or and_(), including nested conditions - Callable: A method reference that triggers this listener Returns @@ -201,13 +193,18 @@ def listen(condition: str | dict | Callable) -> Callable: if isinstance(condition, str): func.__trigger_methods__ = [condition] func.__condition_type__ = "OR" - elif ( - isinstance(condition, dict) - and "type" in condition - and "methods" in condition - ): - func.__trigger_methods__ = condition["methods"] - func.__condition_type__ = condition["type"] + elif isinstance(condition, dict) and "type" in condition: + if "conditions" in condition: + func.__trigger_condition__ = condition + func.__trigger_methods__ = _extract_all_methods(condition) + func.__condition_type__ = condition["type"] + elif "methods" in condition: + func.__trigger_methods__ = condition["methods"] + func.__condition_type__ = condition["type"] + else: + raise ValueError( + "Condition dict must contain 'conditions' or 'methods'" + ) elif callable(condition) and hasattr(condition, "__name__"): func.__trigger_methods__ = [condition.__name__] func.__condition_type__ = "OR" @@ -234,7 +231,7 @@ def router(condition: str | dict | Callable) -> Callable: condition : Union[str, dict, Callable] Specifies when the router should execute. Can be: - str: Name of a method that triggers this router - - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers) + - dict: Result from or_() or and_(), including nested conditions - Callable: A method reference that triggers this router Returns @@ -267,13 +264,18 @@ def router(condition: str | dict | Callable) -> Callable: if isinstance(condition, str): func.__trigger_methods__ = [condition] func.__condition_type__ = "OR" - elif ( - isinstance(condition, dict) - and "type" in condition - and "methods" in condition - ): - func.__trigger_methods__ = condition["methods"] - func.__condition_type__ = condition["type"] + elif isinstance(condition, dict) and "type" in condition: + if "conditions" in condition: + func.__trigger_condition__ = condition + func.__trigger_methods__ = _extract_all_methods(condition) + func.__condition_type__ = condition["type"] + elif "methods" in condition: + func.__trigger_methods__ = condition["methods"] + func.__condition_type__ = condition["type"] + else: + raise ValueError( + "Condition dict must contain 'conditions' or 'methods'" + ) elif callable(condition) and hasattr(condition, "__name__"): func.__trigger_methods__ = [condition.__name__] func.__condition_type__ = "OR" @@ -299,14 +301,15 @@ def or_(*conditions: str | dict | Callable) -> dict: *conditions : Union[str, dict, Callable] Variable number of conditions that can be: - str: Method names - - dict: Existing condition dictionaries + - dict: Existing condition dictionaries (nested conditions) - Callable: Method references Returns ------- dict A condition dictionary with format: - {"type": "OR", "methods": list_of_method_names} + {"type": "OR", "conditions": list_of_conditions} + where each condition can be a string (method name) or a nested dict Raises ------ @@ -318,18 +321,22 @@ def or_(*conditions: str | dict | Callable) -> dict: >>> @listen(or_("success", "timeout")) >>> def handle_completion(self): ... pass + + >>> @listen(or_(and_("step1", "step2"), "step3")) + >>> def handle_nested(self): + ... pass """ - methods = [] + processed_conditions: list[str | dict[str, Any]] = [] for condition in conditions: - if isinstance(condition, dict) and "methods" in condition: - methods.extend(condition["methods"]) + if isinstance(condition, dict): + processed_conditions.append(condition) elif isinstance(condition, str): - methods.append(condition) + processed_conditions.append(condition) elif callable(condition): - methods.append(getattr(condition, "__name__", repr(condition))) + processed_conditions.append(getattr(condition, "__name__", repr(condition))) else: raise ValueError("Invalid condition in or_()") - return {"type": "OR", "methods": methods} + return {"type": "OR", "conditions": processed_conditions} def and_(*conditions: str | dict | Callable) -> dict: @@ -345,14 +352,15 @@ def and_(*conditions: str | dict | Callable) -> dict: *conditions : Union[str, dict, Callable] Variable number of conditions that can be: - str: Method names - - dict: Existing condition dictionaries + - dict: Existing condition dictionaries (nested conditions) - Callable: Method references Returns ------- dict A condition dictionary with format: - {"type": "AND", "methods": list_of_method_names} + {"type": "AND", "conditions": list_of_conditions} + where each condition can be a string (method name) or a nested dict Raises ------ @@ -364,18 +372,69 @@ def and_(*conditions: str | dict | Callable) -> dict: >>> @listen(and_("validated", "processed")) >>> def handle_complete_data(self): ... pass + + >>> @listen(and_(or_("step1", "step2"), "step3")) + >>> def handle_nested(self): + ... pass """ - methods = [] + processed_conditions: list[str | dict[str, Any]] = [] for condition in conditions: - if isinstance(condition, dict) and "methods" in condition: - methods.extend(condition["methods"]) + if isinstance(condition, dict): + processed_conditions.append(condition) elif isinstance(condition, str): - methods.append(condition) + processed_conditions.append(condition) elif callable(condition): - methods.append(getattr(condition, "__name__", repr(condition))) + processed_conditions.append(getattr(condition, "__name__", repr(condition))) else: raise ValueError("Invalid condition in and_()") - return {"type": "AND", "methods": methods} + return {"type": "AND", "conditions": processed_conditions} + + +def _normalize_condition(condition: str | dict | list) -> dict: + """Normalize a condition to standard format with 'conditions' key. + + Args: + condition: Can be a string (method name), dict (condition), or list + + Returns: + Normalized dict with 'type' and 'conditions' keys + """ + if isinstance(condition, str): + return {"type": "OR", "conditions": [condition]} + if isinstance(condition, dict): + if "conditions" in condition: + return condition + if "methods" in condition: + return {"type": condition["type"], "conditions": condition["methods"]} + return condition + if isinstance(condition, list): + return {"type": "OR", "conditions": condition} + return {"type": "OR", "conditions": [condition]} + + +def _extract_all_methods(condition: str | dict | list) -> list[str]: + """Extract all method names from a condition (including nested). + + Args: + condition: Can be a string, dict, or list + + Returns: + List of all method names in the condition tree + """ + if isinstance(condition, str): + return [condition] + if isinstance(condition, dict): + normalized = _normalize_condition(condition) + methods = [] + for sub_cond in normalized.get("conditions", []): + methods.extend(_extract_all_methods(sub_cond)) + return methods + if isinstance(condition, list): + methods = [] + for item in condition: + methods.extend(_extract_all_methods(item)) + return methods + return [] class FlowMeta(type): @@ -403,7 +462,10 @@ class FlowMeta(type): if hasattr(attr_value, "__trigger_methods__"): methods = attr_value.__trigger_methods__ condition_type = getattr(attr_value, "__condition_type__", "OR") - listeners[attr_name] = (condition_type, methods) + if hasattr(attr_value, "__trigger_condition__"): + listeners[attr_name] = attr_value.__trigger_condition__ + else: + listeners[attr_name] = (condition_type, methods) if ( hasattr(attr_value, "__is_router__") @@ -824,6 +886,7 @@ class Flow(Generic[T], metaclass=FlowMeta): # Clear completed methods and outputs for a fresh start self._completed_methods.clear() self._method_outputs.clear() + self._pending_and_listeners.clear() else: # We're restoring from persistence, set the flag self._is_execution_resuming = True @@ -1115,10 +1178,16 @@ class Flow(Generic[T], metaclass=FlowMeta): for method_name in self._start_methods: # Check if this start method is triggered by the current trigger if method_name in self._listeners: - _condition_type, trigger_methods = self._listeners[ - method_name - ] - if current_trigger in trigger_methods: + condition_data = self._listeners[method_name] + should_trigger = False + if isinstance(condition_data, tuple): + _, trigger_methods = condition_data + should_trigger = current_trigger in trigger_methods + elif isinstance(condition_data, dict): + all_methods = _extract_all_methods(condition_data) + should_trigger = current_trigger in all_methods + + if should_trigger: # Only execute if this is a cycle (method was already completed) if method_name in self._completed_methods: # For router-triggered start methods in cycles, temporarily clear resumption flag @@ -1128,6 +1197,51 @@ class Flow(Generic[T], metaclass=FlowMeta): await self._execute_start_method(method_name) self._is_execution_resuming = was_resuming + def _evaluate_condition( + self, condition: str | dict, trigger_method: str, listener_name: str + ) -> bool: + """Recursively evaluate a condition (simple or nested). + + Args: + condition: Can be a string (method name) or dict (nested condition) + trigger_method: The method that just completed + listener_name: Name of the listener being evaluated + + Returns: + True if the condition is satisfied, False otherwise + """ + if isinstance(condition, str): + return condition == trigger_method + + if isinstance(condition, dict): + normalized = _normalize_condition(condition) + cond_type = normalized.get("type", "OR") + sub_conditions = normalized.get("conditions", []) + + if cond_type == "OR": + return any( + self._evaluate_condition(sub_cond, trigger_method, listener_name) + for sub_cond in sub_conditions + ) + + if cond_type == "AND": + pending_key = f"{listener_name}:{id(condition)}" + + if pending_key not in self._pending_and_listeners: + all_methods = set(_extract_all_methods(condition)) + self._pending_and_listeners[pending_key] = all_methods + + if trigger_method in self._pending_and_listeners[pending_key]: + self._pending_and_listeners[pending_key].discard(trigger_method) + + if not self._pending_and_listeners[pending_key]: + self._pending_and_listeners.pop(pending_key, None) + return True + + return False + + return False + def _find_triggered_methods( self, trigger_method: str, router_only: bool ) -> list[str]: @@ -1135,7 +1249,7 @@ class Flow(Generic[T], metaclass=FlowMeta): Finds all methods that should be triggered based on conditions. This internal method evaluates both OR and AND conditions to determine - which methods should be executed next in the flow. + which methods should be executed next in the flow. Supports nested conditions. Parameters ---------- @@ -1152,14 +1266,13 @@ class Flow(Generic[T], metaclass=FlowMeta): Notes ----- - - Handles both OR and AND conditions: - * OR: Triggers if any condition is met - * AND: Triggers only when all conditions are met + - Handles both OR and AND conditions, including nested combinations - Maintains state for AND conditions using _pending_and_listeners - Separates router and normal listener evaluation """ triggered = [] - for listener_name, (condition_type, methods) in self._listeners.items(): + + for listener_name, condition_data in self._listeners.items(): is_router = listener_name in self._routers if router_only != is_router: @@ -1168,23 +1281,29 @@ class Flow(Generic[T], metaclass=FlowMeta): if not router_only and listener_name in self._start_methods: continue - if condition_type == "OR": - # If the trigger_method matches any in methods, run this - if trigger_method in methods: - triggered.append(listener_name) - elif condition_type == "AND": - # Initialize pending methods for this listener if not already done - if listener_name not in self._pending_and_listeners: - self._pending_and_listeners[listener_name] = set(methods) - # Remove the trigger method from pending methods - if trigger_method in self._pending_and_listeners[listener_name]: - self._pending_and_listeners[listener_name].discard(trigger_method) + if isinstance(condition_data, tuple): + condition_type, methods = condition_data - if not self._pending_and_listeners[listener_name]: - # All required methods have been executed + if condition_type == "OR": + if trigger_method in methods: + triggered.append(listener_name) + elif condition_type == "AND": + if listener_name not in self._pending_and_listeners: + self._pending_and_listeners[listener_name] = set(methods) + if trigger_method in self._pending_and_listeners[listener_name]: + self._pending_and_listeners[listener_name].discard( + trigger_method + ) + + if not self._pending_and_listeners[listener_name]: + triggered.append(listener_name) + self._pending_and_listeners.pop(listener_name, None) + + elif isinstance(condition_data, dict): + if self._evaluate_condition( + condition_data, trigger_method, listener_name + ): triggered.append(listener_name) - # Reset pending methods for this listener - self._pending_and_listeners.pop(listener_name, None) return triggered @@ -1247,7 +1366,7 @@ class Flow(Generic[T], metaclass=FlowMeta): raise def _log_flow_event( - self, message: str, color: str = "yellow", level: str = "info" + self, message: str, color: PrinterColor | None = "yellow", level: str = "info" ) -> None: """Centralized logging method for flow events. diff --git a/lib/crewai/src/crewai/utilities/printer.py b/lib/crewai/src/crewai/utilities/printer.py index 18ed3ed5b..caeaf60b8 100644 --- a/lib/crewai/src/crewai/utilities/printer.py +++ b/lib/crewai/src/crewai/utilities/printer.py @@ -1,6 +1,11 @@ """Utility for colored console output.""" -from typing import Final, Literal, NamedTuple +from __future__ import annotations + +from typing import TYPE_CHECKING, Final, Literal, NamedTuple + +if TYPE_CHECKING: + from _typeshed import SupportsWrite PrinterColor = Literal[ "purple", @@ -54,13 +59,22 @@ class Printer: @staticmethod def print( - content: str | list[ColoredText], color: PrinterColor | None = None + content: str | list[ColoredText], + color: PrinterColor | None = None, + sep: str | None = " ", + end: str | None = "\n", + file: SupportsWrite[str] | None = None, + flush: Literal[False] = False, ) -> None: """Prints content to the console with optional color formatting. Args: content: Either a string or a list of ColoredText objects for multicolor output. color: Optional color for the text when content is a string. Ignored when content is a list. + sep: Separator to use between the text and color. + end: String appended after the last value. + file: A file-like object (stream); defaults to the current sys.stdout. + flush: Whether to forcibly flush the stream. """ if isinstance(content, str): content = [ColoredText(content, color)] @@ -68,5 +82,9 @@ class Printer: "".join( f"{_COLOR_CODES[c.color] if c.color else ''}{c.text}{RESET}" for c in content - ) + ), + sep=sep, + end=end, + file=file, + flush=flush, ) diff --git a/lib/crewai/tests/test_flow.py b/lib/crewai/tests/test_flow.py index a4bc08f93..8142b6491 100644 --- a/lib/crewai/tests/test_flow.py +++ b/lib/crewai/tests/test_flow.py @@ -961,3 +961,75 @@ def test_flow_name(): flow = MyFlow() assert flow.name == "MyFlow" + + +def test_nested_and_or_conditions(): + """Test nested conditions like or_(and_(A, B), and_(C, D)). + + Reproduces bug from issue #3719 where nested conditions are flattened, + causing premature execution. + """ + execution_order = [] + + class NestedConditionFlow(Flow): + @start() + def method_1(self): + execution_order.append("method_1") + + @listen(method_1) + def method_2(self): + execution_order.append("method_2") + + @router(method_2) + def method_3(self): + execution_order.append("method_3") + # Choose b_condition path + return "b_condition" + + @listen("b_condition") + def method_5(self): + execution_order.append("method_5") + + @listen(method_5) + async def method_4(self): + execution_order.append("method_4") + + @listen(or_("a_condition", "b_condition")) + async def method_6(self): + execution_order.append("method_6") + + @listen( + or_( + and_("a_condition", method_6), + and_(method_6, method_4), + ) + ) + def method_7(self): + execution_order.append("method_7") + + @listen(method_7) + async def method_8(self): + execution_order.append("method_8") + + flow = NestedConditionFlow() + flow.kickoff() + + # Verify execution happened + assert "method_1" in execution_order + assert "method_2" in execution_order + assert "method_3" in execution_order + assert "method_5" in execution_order + assert "method_4" in execution_order + assert "method_6" in execution_order + assert "method_7" in execution_order + assert "method_8" in execution_order + + # Critical assertion: method_7 should only execute AFTER both method_6 AND method_4 + # Since b_condition was returned, method_6 triggers on b_condition + # method_7 requires: (a_condition AND method_6) OR (method_6 AND method_4) + # The second condition (method_6 AND method_4) should be the one that triggers + assert execution_order.index("method_7") > execution_order.index("method_6") + assert execution_order.index("method_7") > execution_order.index("method_4") + + # method_8 should execute after method_7 + assert execution_order.index("method_8") > execution_order.index("method_7")