mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-27 09:08:14 +00:00
fix: clarify @listen decorator method vs output behavior (#2666)
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -106,7 +106,23 @@ The `@listen()` decorator is used to mark a method as a listener for the output
|
|||||||
|
|
||||||
The `@listen()` decorator can be used in several ways:
|
The `@listen()` decorator can be used in several ways:
|
||||||
|
|
||||||
1. **Listening to a Method by Name**: You can pass the name of the method you want to listen to as a string. When that method completes, the listener method will be triggered.
|
1. **Listening to a Method by Name Explicitly**: You can explicitly specify that you're listening for a method by name, which removes any ambiguity.
|
||||||
|
|
||||||
|
```python Code
|
||||||
|
@listen(method="generate_city")
|
||||||
|
def generate_fun_fact(self, random_city):
|
||||||
|
# Implementation
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Listening to an Output Value Explicitly**: You can explicitly specify that you're listening for an output value, which removes any ambiguity.
|
||||||
|
|
||||||
|
```python Code
|
||||||
|
@listen(output="success")
|
||||||
|
def handle_success(self):
|
||||||
|
# Implementation
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Legacy Usage - Listening to a Method by Name**: You can pass the name of the method you want to listen to as a string. When that method completes, the listener method will be triggered.
|
||||||
|
|
||||||
```python Code
|
```python Code
|
||||||
@listen("generate_city")
|
@listen("generate_city")
|
||||||
@@ -114,7 +130,7 @@ The `@listen()` decorator can be used in several ways:
|
|||||||
# Implementation
|
# Implementation
|
||||||
```
|
```
|
||||||
|
|
||||||
2. **Listening to a Method Directly**: You can pass the method itself. When that method completes, the listener method will be triggered.
|
4. **Legacy Usage - Listening to a Method Directly**: You can pass the method itself. When that method completes, the listener method will be triggered.
|
||||||
```python Code
|
```python Code
|
||||||
@listen(generate_city)
|
@listen(generate_city)
|
||||||
def generate_fun_fact(self, random_city):
|
def generate_fun_fact(self, random_city):
|
||||||
|
|||||||
@@ -162,7 +162,9 @@ def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def listen(condition: Union[str, dict, Callable]) -> Callable:
|
def listen(
|
||||||
|
condition: Union[str, dict, Callable] = None, *, method: str = None, output: str = None
|
||||||
|
) -> Callable:
|
||||||
"""
|
"""
|
||||||
Creates a listener that executes when specified conditions are met.
|
Creates a listener that executes when specified conditions are met.
|
||||||
|
|
||||||
@@ -172,11 +174,17 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
condition : Union[str, dict, Callable]
|
condition : Union[str, dict, Callable], optional
|
||||||
Specifies when the listener should execute. Can be:
|
Legacy parameter specifies when the listener should execute. Can be:
|
||||||
- str: Name of a method that triggers this listener
|
- str: Name of a method that triggers this listener or an output string
|
||||||
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
|
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
|
||||||
- Callable: A method reference that triggers this listener
|
- Callable: A method reference that triggers this listener
|
||||||
|
method : str, optional
|
||||||
|
Name of a method that triggers this listener. This explicitly indicates
|
||||||
|
you're listening for a method execution rather than an output string.
|
||||||
|
output : str, optional
|
||||||
|
Output string value that triggers this listener. This explicitly indicates
|
||||||
|
you're listening for an output value rather than a method name.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@@ -186,23 +194,46 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
|
|||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If the condition format is invalid.
|
If the parameters are invalid or incompatible.
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
>>> @listen("process_data") # Listen to single method
|
>>> @listen(method="process_data") # Explicitly listen to method name
|
||||||
>>> def handle_processed_data(self):
|
>>> def handle_processed_data(self):
|
||||||
... pass
|
... pass
|
||||||
|
|
||||||
>>> @listen(or_("success", "failure")) # Listen to multiple methods
|
>>> @listen(output="success") # Explicitly listen to output string
|
||||||
|
>>> def handle_success_case(self):
|
||||||
|
... pass
|
||||||
|
|
||||||
|
>>> @listen(or_("success", "failure")) # Listen to multiple outputs
|
||||||
>>> def handle_completion(self):
|
>>> def handle_completion(self):
|
||||||
... pass
|
... pass
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
|
if method is not None and output is not None:
|
||||||
|
raise ValueError("Cannot specify both 'method' and 'output' parameters")
|
||||||
|
|
||||||
|
if method is not None:
|
||||||
|
func.__trigger_methods__ = [method]
|
||||||
|
func.__condition_type__ = "OR"
|
||||||
|
func.__is_output__ = False
|
||||||
|
return func
|
||||||
|
|
||||||
|
if output is not None:
|
||||||
|
func.__trigger_methods__ = [output]
|
||||||
|
func.__condition_type__ = "OR"
|
||||||
|
func.__is_output__ = True
|
||||||
|
return func
|
||||||
|
|
||||||
|
if condition is None:
|
||||||
|
raise ValueError("Must provide either 'condition', 'method', or 'output' parameter")
|
||||||
|
|
||||||
if isinstance(condition, str):
|
if isinstance(condition, str):
|
||||||
func.__trigger_methods__ = [condition]
|
func.__trigger_methods__ = [condition]
|
||||||
func.__condition_type__ = "OR"
|
func.__condition_type__ = "OR"
|
||||||
|
func.__is_output__ = None
|
||||||
elif (
|
elif (
|
||||||
isinstance(condition, dict)
|
isinstance(condition, dict)
|
||||||
and "type" in condition
|
and "type" in condition
|
||||||
@@ -210,9 +241,11 @@ def listen(condition: Union[str, dict, Callable]) -> Callable:
|
|||||||
):
|
):
|
||||||
func.__trigger_methods__ = condition["methods"]
|
func.__trigger_methods__ = condition["methods"]
|
||||||
func.__condition_type__ = condition["type"]
|
func.__condition_type__ = condition["type"]
|
||||||
|
func.__is_output__ = None
|
||||||
elif callable(condition) and hasattr(condition, "__name__"):
|
elif callable(condition) and hasattr(condition, "__name__"):
|
||||||
func.__trigger_methods__ = [condition.__name__]
|
func.__trigger_methods__ = [condition.__name__]
|
||||||
func.__condition_type__ = "OR"
|
func.__condition_type__ = "OR"
|
||||||
|
func.__is_output__ = False
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Condition must be a method, string, or a result of or_() or and_()"
|
"Condition must be a method, string, or a result of or_() or and_()"
|
||||||
@@ -897,10 +930,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
- Each router's result becomes a new trigger_method
|
- Each router's result becomes a new trigger_method
|
||||||
- Normal listeners are executed in parallel for efficiency
|
- Normal listeners are executed in parallel for efficiency
|
||||||
- Listeners can receive the trigger method's result as a parameter
|
- Listeners can receive the trigger method's result as a parameter
|
||||||
|
- Prevents infinite loops by tracking processed triggers
|
||||||
"""
|
"""
|
||||||
|
processed_triggers = set()
|
||||||
|
|
||||||
# First, handle routers repeatedly until no router triggers anymore
|
# First, handle routers repeatedly until no router triggers anymore
|
||||||
router_results = []
|
router_results = []
|
||||||
current_trigger = trigger_method
|
current_trigger = trigger_method
|
||||||
|
processed_triggers.add(current_trigger)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
routers_triggered = self._find_triggered_methods(
|
routers_triggered = self._find_triggered_methods(
|
||||||
@@ -913,11 +950,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
await self._execute_single_listener(router_name, result)
|
await self._execute_single_listener(router_name, result)
|
||||||
# After executing router, the router's result is the path
|
# After executing router, the router's result is the path
|
||||||
router_result = self._method_outputs[-1]
|
router_result = self._method_outputs[-1]
|
||||||
if router_result: # Only add non-None results
|
if router_result and router_result not in processed_triggers: # Only add non-None and unprocessed results
|
||||||
router_results.append(router_result)
|
router_results.append(router_result)
|
||||||
current_trigger = (
|
processed_triggers.add(router_result)
|
||||||
router_result # Update for next iteration of router chain
|
current_trigger = router_result # Update for next iteration of router chain
|
||||||
)
|
|
||||||
|
|
||||||
# Now execute normal listeners for all router results and the original trigger
|
# Now execute normal listeners for all router results and the original trigger
|
||||||
all_triggers = [trigger_method] + router_results
|
all_triggers = [trigger_method] + router_results
|
||||||
@@ -946,7 +982,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
trigger_method : str
|
trigger_method : str
|
||||||
The name of the method that just completed execution.
|
The name of the method that just completed execution or an output value.
|
||||||
router_only : bool
|
router_only : bool
|
||||||
If True, only consider router methods.
|
If True, only consider router methods.
|
||||||
If False, only consider non-router methods.
|
If False, only consider non-router methods.
|
||||||
@@ -963,6 +999,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
* AND: Triggers only when all conditions are met
|
* AND: Triggers only when all conditions are met
|
||||||
- Maintains state for AND conditions using _pending_and_listeners
|
- Maintains state for AND conditions using _pending_and_listeners
|
||||||
- Separates router and normal listener evaluation
|
- Separates router and normal listener evaluation
|
||||||
|
- Respects the __is_output__ attribute to disambiguate between method names and output strings
|
||||||
"""
|
"""
|
||||||
triggered = []
|
triggered = []
|
||||||
for listener_name, (condition_type, methods) in self._listeners.items():
|
for listener_name, (condition_type, methods) in self._listeners.items():
|
||||||
@@ -971,10 +1008,21 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
if router_only != is_router:
|
if router_only != is_router:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
method = self._methods.get(listener_name)
|
||||||
|
is_output = getattr(method, "__is_output__", None)
|
||||||
|
|
||||||
if condition_type == "OR":
|
if condition_type == "OR":
|
||||||
# If the trigger_method matches any in methods, run this
|
# For methods with explicit output=True, only match if trigger is not a method name
|
||||||
|
# For methods with explicit method=True, only match if trigger is a method name
|
||||||
|
|
||||||
if trigger_method in methods:
|
if trigger_method in methods:
|
||||||
triggered.append(listener_name)
|
trigger_is_method = trigger_method in self._methods
|
||||||
|
|
||||||
|
if (is_output is None or # Legacy behavior - always match
|
||||||
|
(is_output is True and not trigger_is_method) or # Output string listener
|
||||||
|
(is_output is False and trigger_is_method)): # Method name listener
|
||||||
|
triggered.append(listener_name)
|
||||||
|
|
||||||
elif condition_type == "AND":
|
elif condition_type == "AND":
|
||||||
# Initialize pending methods for this listener if not already done
|
# Initialize pending methods for this listener if not already done
|
||||||
if listener_name not in self._pending_and_listeners:
|
if listener_name not in self._pending_and_listeners:
|
||||||
|
|||||||
168
tests/test_listen_decorator.py
Normal file
168
tests/test_listen_decorator.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Test @listen decorator for method vs output disambiguation."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from crewai.flow.flow import Flow, listen, router, start
|
||||||
|
|
||||||
|
|
||||||
|
def test_listen_with_explicit_method():
|
||||||
|
"""Test @listen with explicit method parameter."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class ExplicitFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def method_to_listen_for(self):
|
||||||
|
execution_order.append("method_to_listen_for")
|
||||||
|
return "method_output"
|
||||||
|
|
||||||
|
@listen(method="method_to_listen_for")
|
||||||
|
def explicit_method_listener(self):
|
||||||
|
execution_order.append("explicit_method_listener")
|
||||||
|
|
||||||
|
flow = ExplicitFlow()
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert "method_to_listen_for" in execution_order
|
||||||
|
assert "explicit_method_listener" in execution_order
|
||||||
|
assert execution_order.index("explicit_method_listener") > execution_order.index("method_to_listen_for")
|
||||||
|
|
||||||
|
|
||||||
|
def test_listen_with_explicit_output():
|
||||||
|
"""Test @listen with explicit output parameter."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class ExplicitOutputFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def start_method(self):
|
||||||
|
execution_order.append("start_method")
|
||||||
|
|
||||||
|
@router(start_method)
|
||||||
|
def router_method(self):
|
||||||
|
execution_order.append("router_method")
|
||||||
|
return "output_value"
|
||||||
|
|
||||||
|
@listen(output="output_value")
|
||||||
|
def output_listener(self):
|
||||||
|
execution_order.append("output_listener")
|
||||||
|
|
||||||
|
flow = ExplicitOutputFlow()
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert "start_method" in execution_order
|
||||||
|
assert "router_method" in execution_order
|
||||||
|
assert "output_listener" in execution_order
|
||||||
|
assert execution_order.index("output_listener") > execution_order.index("router_method")
|
||||||
|
|
||||||
|
|
||||||
|
def test_ambiguous_case_with_explicit_parameters():
|
||||||
|
"""Test case where method name matches a possible output value."""
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class AmbiguousFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def start_method(self):
|
||||||
|
print("Executing start_method")
|
||||||
|
execution_order.append("start_method")
|
||||||
|
return "start output"
|
||||||
|
|
||||||
|
@router(start_method)
|
||||||
|
def router_method(self):
|
||||||
|
print("Executing router_method")
|
||||||
|
execution_order.append("router_method")
|
||||||
|
return "ambiguous_name"
|
||||||
|
|
||||||
|
def ambiguous_name(self):
|
||||||
|
print("This method should not be called directly")
|
||||||
|
execution_order.append("ambiguous_name_direct_call")
|
||||||
|
return "should not happen"
|
||||||
|
|
||||||
|
@listen(method="ambiguous_name") # Listen to method name explicitly
|
||||||
|
def method_listener(self):
|
||||||
|
print("Executing method_listener")
|
||||||
|
execution_order.append("method_listener")
|
||||||
|
|
||||||
|
@listen(output="ambiguous_name") # Listen to output string explicitly
|
||||||
|
def output_listener(self):
|
||||||
|
print("Executing output_listener")
|
||||||
|
execution_order.append("output_listener")
|
||||||
|
|
||||||
|
print("Creating flow instance")
|
||||||
|
flow = AmbiguousFlow()
|
||||||
|
|
||||||
|
async def run_with_timeout():
|
||||||
|
task = asyncio.create_task(flow.kickoff_async())
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(task, timeout=5.0) # 5 second timeout
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print("Test timed out - likely an infinite loop")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
print("Starting flow kickoff with timeout")
|
||||||
|
success = asyncio.run(run_with_timeout())
|
||||||
|
|
||||||
|
print(f"Execution order: {execution_order}")
|
||||||
|
|
||||||
|
if success:
|
||||||
|
assert "start_method" in execution_order
|
||||||
|
assert "router_method" in execution_order
|
||||||
|
assert "output_listener" in execution_order
|
||||||
|
|
||||||
|
assert "method_listener" not in execution_order
|
||||||
|
|
||||||
|
assert "ambiguous_name_direct_call" not in execution_order
|
||||||
|
|
||||||
|
assert execution_order.index("output_listener") > execution_order.index("router_method")
|
||||||
|
else:
|
||||||
|
pytest.fail("Test timed out - likely an infinite loop in the flow execution")
|
||||||
|
|
||||||
|
|
||||||
|
def test_listen_with_backward_compatibility():
|
||||||
|
"""Test that the old way of using @listen still works."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class BackwardCompatFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def start_method(self):
|
||||||
|
execution_order.append("start_method")
|
||||||
|
|
||||||
|
@router(start_method)
|
||||||
|
def router_method(self):
|
||||||
|
execution_order.append("router_method")
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
@listen("start_method") # Old way - listen to method by name
|
||||||
|
def method_listener(self):
|
||||||
|
execution_order.append("method_listener")
|
||||||
|
|
||||||
|
@listen("success") # Old way - listen to output string
|
||||||
|
def output_listener(self):
|
||||||
|
execution_order.append("output_listener")
|
||||||
|
|
||||||
|
flow = BackwardCompatFlow()
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert "start_method" in execution_order
|
||||||
|
assert "router_method" in execution_order
|
||||||
|
assert "method_listener" in execution_order
|
||||||
|
assert "output_listener" in execution_order
|
||||||
|
|
||||||
|
|
||||||
|
def test_listen_with_invalid_parameters():
|
||||||
|
"""Test that invalid parameters raise exceptions."""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
@listen(method="method_name", output="output_value")
|
||||||
|
def invalid_listener(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
@listen()
|
||||||
|
def no_param_listener(self):
|
||||||
|
pass
|
||||||
Reference in New Issue
Block a user