From d5f23c439a9ed90fa593ccc5b8194aba3901bfbe Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Tue, 22 Apr 2025 18:56:27 +0000 Subject: [PATCH] fix: clarify @listen decorator method vs output behavior (#2666) Co-Authored-By: Joe Moura --- docs/concepts/flows.mdx | 20 +++- src/crewai/flow/flow.py | 78 ++++++++++++--- tests/test_listen_decorator.py | 168 +++++++++++++++++++++++++++++++++ 3 files changed, 249 insertions(+), 17 deletions(-) create mode 100644 tests/test_listen_decorator.py diff --git a/docs/concepts/flows.mdx b/docs/concepts/flows.mdx index b7793c60c..87cb06d27 100644 --- a/docs/concepts/flows.mdx +++ b/docs/concepts/flows.mdx @@ -106,7 +106,23 @@ The `@listen()` decorator is used to mark a method as a listener for the output The `@listen()` decorator can be used in several ways: -1. **Listening to a Method by Name**: You can pass the name of the method you want to listen to as a string. When that method completes, the listener method will be triggered. +1. **Listening to a Method by Name Explicitly**: You can explicitly specify that you're listening for a method by name, which removes any ambiguity. + + ```python Code + @listen(method="generate_city") + def generate_fun_fact(self, random_city): + # Implementation + ``` + +2. **Listening to an Output Value Explicitly**: You can explicitly specify that you're listening for an output value, which removes any ambiguity. + + ```python Code + @listen(output="success") + def handle_success(self): + # Implementation + ``` + +3. **Legacy Usage - Listening to a Method by Name**: You can pass the name of the method you want to listen to as a string. When that method completes, the listener method will be triggered. ```python Code @listen("generate_city") @@ -114,7 +130,7 @@ The `@listen()` decorator can be used in several ways: # Implementation ``` -2. **Listening to a Method Directly**: You can pass the method itself. When that method completes, the listener method will be triggered. +4. **Legacy Usage - Listening to a Method Directly**: You can pass the method itself. When that method completes, the listener method will be triggered. ```python Code @listen(generate_city) def generate_fun_fact(self, random_city): diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 99ae82c96..96b0857c2 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -162,7 +162,9 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable: return decorator -def listen(condition: Union[str, dict, Callable]) -> Callable: +def listen( + condition: Union[str, dict, Callable] = None, *, method: str = None, output: str = None +) -> Callable: """ Creates a listener that executes when specified conditions are met. @@ -172,11 +174,17 @@ def listen(condition: Union[str, dict, Callable]) -> Callable: Parameters ---------- - condition : Union[str, dict, Callable] - Specifies when the listener should execute. Can be: - - str: Name of a method that triggers this listener + condition : Union[str, dict, Callable], optional + Legacy parameter specifies when the listener should execute. Can be: + - str: Name of a method that triggers this listener or an output string - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers) - Callable: A method reference that triggers this listener + method : str, optional + Name of a method that triggers this listener. This explicitly indicates + you're listening for a method execution rather than an output string. + output : str, optional + Output string value that triggers this listener. This explicitly indicates + you're listening for an output value rather than a method name. Returns ------- @@ -186,23 +194,46 @@ def listen(condition: Union[str, dict, Callable]) -> Callable: Raises ------ ValueError - If the condition format is invalid. + If the parameters are invalid or incompatible. Examples -------- - >>> @listen("process_data") # Listen to single method + >>> @listen(method="process_data") # Explicitly listen to method name >>> def handle_processed_data(self): ... pass - - >>> @listen(or_("success", "failure")) # Listen to multiple methods + + >>> @listen(output="success") # Explicitly listen to output string + >>> def handle_success_case(self): + ... pass + + >>> @listen(or_("success", "failure")) # Listen to multiple outputs >>> def handle_completion(self): ... pass """ def decorator(func): + if method is not None and output is not None: + raise ValueError("Cannot specify both 'method' and 'output' parameters") + + if method is not None: + func.__trigger_methods__ = [method] + func.__condition_type__ = "OR" + func.__is_output__ = False + return func + + if output is not None: + func.__trigger_methods__ = [output] + func.__condition_type__ = "OR" + func.__is_output__ = True + return func + + if condition is None: + raise ValueError("Must provide either 'condition', 'method', or 'output' parameter") + if isinstance(condition, str): func.__trigger_methods__ = [condition] func.__condition_type__ = "OR" + func.__is_output__ = None elif ( isinstance(condition, dict) and "type" in condition @@ -210,9 +241,11 @@ def listen(condition: Union[str, dict, Callable]) -> Callable: ): func.__trigger_methods__ = condition["methods"] func.__condition_type__ = condition["type"] + func.__is_output__ = None elif callable(condition) and hasattr(condition, "__name__"): func.__trigger_methods__ = [condition.__name__] func.__condition_type__ = "OR" + func.__is_output__ = False else: raise ValueError( "Condition must be a method, string, or a result of or_() or and_()" @@ -897,10 +930,14 @@ class Flow(Generic[T], metaclass=FlowMeta): - Each router's result becomes a new trigger_method - Normal listeners are executed in parallel for efficiency - Listeners can receive the trigger method's result as a parameter + - Prevents infinite loops by tracking processed triggers """ + processed_triggers = set() + # First, handle routers repeatedly until no router triggers anymore router_results = [] current_trigger = trigger_method + processed_triggers.add(current_trigger) while True: routers_triggered = self._find_triggered_methods( @@ -913,11 +950,10 @@ class Flow(Generic[T], metaclass=FlowMeta): await self._execute_single_listener(router_name, result) # After executing router, the router's result is the path router_result = self._method_outputs[-1] - if router_result: # Only add non-None results + if router_result and router_result not in processed_triggers: # Only add non-None and unprocessed results router_results.append(router_result) - current_trigger = ( - router_result # Update for next iteration of router chain - ) + processed_triggers.add(router_result) + current_trigger = router_result # Update for next iteration of router chain # Now execute normal listeners for all router results and the original trigger all_triggers = [trigger_method] + router_results @@ -946,7 +982,7 @@ class Flow(Generic[T], metaclass=FlowMeta): Parameters ---------- trigger_method : str - The name of the method that just completed execution. + The name of the method that just completed execution or an output value. router_only : bool If True, only consider router methods. If False, only consider non-router methods. @@ -963,6 +999,7 @@ class Flow(Generic[T], metaclass=FlowMeta): * AND: Triggers only when all conditions are met - Maintains state for AND conditions using _pending_and_listeners - Separates router and normal listener evaluation + - Respects the __is_output__ attribute to disambiguate between method names and output strings """ triggered = [] for listener_name, (condition_type, methods) in self._listeners.items(): @@ -971,10 +1008,21 @@ class Flow(Generic[T], metaclass=FlowMeta): if router_only != is_router: continue + method = self._methods.get(listener_name) + is_output = getattr(method, "__is_output__", None) + if condition_type == "OR": - # If the trigger_method matches any in methods, run this + # For methods with explicit output=True, only match if trigger is not a method name + # For methods with explicit method=True, only match if trigger is a method name + if trigger_method in methods: - triggered.append(listener_name) + trigger_is_method = trigger_method in self._methods + + if (is_output is None or # Legacy behavior - always match + (is_output is True and not trigger_is_method) or # Output string listener + (is_output is False and trigger_is_method)): # Method name listener + triggered.append(listener_name) + elif condition_type == "AND": # Initialize pending methods for this listener if not already done if listener_name not in self._pending_and_listeners: diff --git a/tests/test_listen_decorator.py b/tests/test_listen_decorator.py new file mode 100644 index 000000000..fd0d0ec20 --- /dev/null +++ b/tests/test_listen_decorator.py @@ -0,0 +1,168 @@ +"""Test @listen decorator for method vs output disambiguation.""" + +import pytest +from pydantic import BaseModel + +from crewai.flow.flow import Flow, listen, router, start + + +def test_listen_with_explicit_method(): + """Test @listen with explicit method parameter.""" + execution_order = [] + + class ExplicitFlow(Flow): + @start() + def method_to_listen_for(self): + execution_order.append("method_to_listen_for") + return "method_output" + + @listen(method="method_to_listen_for") + def explicit_method_listener(self): + execution_order.append("explicit_method_listener") + + flow = ExplicitFlow() + flow.kickoff() + + assert "method_to_listen_for" in execution_order + assert "explicit_method_listener" in execution_order + assert execution_order.index("explicit_method_listener") > execution_order.index("method_to_listen_for") + + +def test_listen_with_explicit_output(): + """Test @listen with explicit output parameter.""" + execution_order = [] + + class ExplicitOutputFlow(Flow): + @start() + def start_method(self): + execution_order.append("start_method") + + @router(start_method) + def router_method(self): + execution_order.append("router_method") + return "output_value" + + @listen(output="output_value") + def output_listener(self): + execution_order.append("output_listener") + + flow = ExplicitOutputFlow() + flow.kickoff() + + assert "start_method" in execution_order + assert "router_method" in execution_order + assert "output_listener" in execution_order + assert execution_order.index("output_listener") > execution_order.index("router_method") + + +def test_ambiguous_case_with_explicit_parameters(): + """Test case where method name matches a possible output value.""" + import logging + import asyncio + import time + logging.basicConfig(level=logging.DEBUG) + + execution_order = [] + + class AmbiguousFlow(Flow): + @start() + def start_method(self): + print("Executing start_method") + execution_order.append("start_method") + return "start output" + + @router(start_method) + def router_method(self): + print("Executing router_method") + execution_order.append("router_method") + return "ambiguous_name" + + def ambiguous_name(self): + print("This method should not be called directly") + execution_order.append("ambiguous_name_direct_call") + return "should not happen" + + @listen(method="ambiguous_name") # Listen to method name explicitly + def method_listener(self): + print("Executing method_listener") + execution_order.append("method_listener") + + @listen(output="ambiguous_name") # Listen to output string explicitly + def output_listener(self): + print("Executing output_listener") + execution_order.append("output_listener") + + print("Creating flow instance") + flow = AmbiguousFlow() + + async def run_with_timeout(): + task = asyncio.create_task(flow.kickoff_async()) + + try: + await asyncio.wait_for(task, timeout=5.0) # 5 second timeout + except asyncio.TimeoutError: + print("Test timed out - likely an infinite loop") + return False + return True + + print("Starting flow kickoff with timeout") + success = asyncio.run(run_with_timeout()) + + print(f"Execution order: {execution_order}") + + if success: + assert "start_method" in execution_order + assert "router_method" in execution_order + assert "output_listener" in execution_order + + assert "method_listener" not in execution_order + + assert "ambiguous_name_direct_call" not in execution_order + + assert execution_order.index("output_listener") > execution_order.index("router_method") + else: + pytest.fail("Test timed out - likely an infinite loop in the flow execution") + + +def test_listen_with_backward_compatibility(): + """Test that the old way of using @listen still works.""" + execution_order = [] + + class BackwardCompatFlow(Flow): + @start() + def start_method(self): + execution_order.append("start_method") + + @router(start_method) + def router_method(self): + execution_order.append("router_method") + return "success" + + @listen("start_method") # Old way - listen to method by name + def method_listener(self): + execution_order.append("method_listener") + + @listen("success") # Old way - listen to output string + def output_listener(self): + execution_order.append("output_listener") + + flow = BackwardCompatFlow() + flow.kickoff() + + assert "start_method" in execution_order + assert "router_method" in execution_order + assert "method_listener" in execution_order + assert "output_listener" in execution_order + + +def test_listen_with_invalid_parameters(): + """Test that invalid parameters raise exceptions.""" + with pytest.raises(ValueError): + @listen(method="method_name", output="output_value") + def invalid_listener(self): + pass + + with pytest.raises(ValueError): + @listen() + def no_param_listener(self): + pass