mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
Fix flows to support cycles and added in test
This commit is contained in:
@@ -131,7 +131,6 @@ class FlowMeta(type):
|
|||||||
condition_type = getattr(attr_value, "__condition_type__", "OR")
|
condition_type = getattr(attr_value, "__condition_type__", "OR")
|
||||||
listeners[attr_name] = (condition_type, methods)
|
listeners[attr_name] = (condition_type, methods)
|
||||||
|
|
||||||
# TODO: should we add a check for __condition_type__ 'AND'?
|
|
||||||
elif hasattr(attr_value, "__is_router__"):
|
elif hasattr(attr_value, "__is_router__"):
|
||||||
routers[attr_value.__router_for__] = attr_name
|
routers[attr_value.__router_for__] = attr_name
|
||||||
possible_returns = get_possible_return_constants(attr_value)
|
possible_returns = get_possible_return_constants(attr_value)
|
||||||
@@ -171,8 +170,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._methods: Dict[str, Callable] = {}
|
self._methods: Dict[str, Callable] = {}
|
||||||
self._state: T = self._create_initial_state()
|
self._state: T = self._create_initial_state()
|
||||||
self._executed_methods: Set[str] = set()
|
self._method_execution_counts: Dict[str, int] = {}
|
||||||
self._scheduled_tasks: Set[str] = set()
|
|
||||||
self._pending_and_listeners: Dict[str, Set[str]] = {}
|
self._pending_and_listeners: Dict[str, Set[str]] = {}
|
||||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||||
|
|
||||||
@@ -309,7 +307,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
)
|
)
|
||||||
self._method_outputs.append(result) # Store the output
|
self._method_outputs.append(result) # Store the output
|
||||||
|
|
||||||
self._executed_methods.add(method_name)
|
# Track method execution counts
|
||||||
|
self._method_execution_counts[method_name] = (
|
||||||
|
self._method_execution_counts.get(method_name, 0) + 1
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -319,35 +320,34 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
if trigger_method in self._routers:
|
if trigger_method in self._routers:
|
||||||
router_method = self._methods[self._routers[trigger_method]]
|
router_method = self._methods[self._routers[trigger_method]]
|
||||||
path = await self._execute_method(
|
path = await self._execute_method(
|
||||||
trigger_method, router_method
|
self._routers[trigger_method], router_method
|
||||||
) # TODO: Change or not?
|
)
|
||||||
# Use the path as the new trigger method
|
|
||||||
trigger_method = path
|
trigger_method = path
|
||||||
|
|
||||||
for listener_name, (condition_type, methods) in self._listeners.items():
|
for listener_name, (condition_type, methods) in self._listeners.items():
|
||||||
if condition_type == "OR":
|
if condition_type == "OR":
|
||||||
if trigger_method in methods:
|
if trigger_method in methods:
|
||||||
if (
|
# Schedule the listener without preventing re-execution
|
||||||
listener_name not in self._executed_methods
|
listener_tasks.append(
|
||||||
and listener_name not in self._scheduled_tasks
|
self._execute_single_listener(listener_name, result)
|
||||||
):
|
)
|
||||||
self._scheduled_tasks.add(listener_name)
|
|
||||||
listener_tasks.append(
|
|
||||||
self._execute_single_listener(listener_name, result)
|
|
||||||
)
|
|
||||||
elif condition_type == "AND":
|
elif condition_type == "AND":
|
||||||
if all(method in self._executed_methods for method in methods):
|
# Initialize pending methods for this listener if not already done
|
||||||
if (
|
if listener_name not in self._pending_and_listeners:
|
||||||
listener_name not in self._executed_methods
|
self._pending_and_listeners[listener_name] = set(methods)
|
||||||
and listener_name not in self._scheduled_tasks
|
# Remove the trigger method from pending methods
|
||||||
):
|
self._pending_and_listeners[listener_name].discard(trigger_method)
|
||||||
self._scheduled_tasks.add(listener_name)
|
if not self._pending_and_listeners[listener_name]:
|
||||||
listener_tasks.append(
|
# All required methods have been executed
|
||||||
self._execute_single_listener(listener_name, result)
|
listener_tasks.append(
|
||||||
)
|
self._execute_single_listener(listener_name, result)
|
||||||
|
)
|
||||||
|
# Reset pending methods for this listener
|
||||||
|
self._pending_and_listeners.pop(listener_name, None)
|
||||||
|
|
||||||
# Run all listener tasks concurrently and wait for them to complete
|
# Run all listener tasks concurrently and wait for them to complete
|
||||||
await asyncio.gather(*listener_tasks)
|
if listener_tasks:
|
||||||
|
await asyncio.gather(*listener_tasks)
|
||||||
|
|
||||||
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
|
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
|
||||||
try:
|
try:
|
||||||
@@ -367,9 +367,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
# If listener does not expect parameters, call without arguments
|
# If listener does not expect parameters, call without arguments
|
||||||
listener_result = await self._execute_method(listener_name, method)
|
listener_result = await self._execute_method(listener_name, method)
|
||||||
|
|
||||||
# Remove from scheduled tasks after execution
|
|
||||||
self._scheduled_tasks.discard(listener_name)
|
|
||||||
|
|
||||||
# Execute listeners of this listener
|
# Execute listeners of this listener
|
||||||
await self._execute_listeners(listener_name, listener_result)
|
await self._execute_listeners(listener_name, listener_result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
264
tests/flow_test.py
Normal file
264
tests/flow_test.py
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
"""Test Flow creation and execution basic functionality."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_sequential_flow():
|
||||||
|
"""Test a simple flow with two steps called sequentially."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class SimpleFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
execution_order.append("step_1")
|
||||||
|
|
||||||
|
@listen(step_1)
|
||||||
|
def step_2(self):
|
||||||
|
execution_order.append("step_2")
|
||||||
|
|
||||||
|
flow = SimpleFlow()
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert execution_order == ["step_1", "step_2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_with_multiple_starts():
|
||||||
|
"""Test a flow with multiple start methods."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class MultiStartFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def step_a(self):
|
||||||
|
execution_order.append("step_a")
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def step_b(self):
|
||||||
|
execution_order.append("step_b")
|
||||||
|
|
||||||
|
@listen(step_a)
|
||||||
|
def step_c(self):
|
||||||
|
execution_order.append("step_c")
|
||||||
|
|
||||||
|
@listen(step_b)
|
||||||
|
def step_d(self):
|
||||||
|
execution_order.append("step_d")
|
||||||
|
|
||||||
|
flow = MultiStartFlow()
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert "step_a" in execution_order
|
||||||
|
assert "step_b" in execution_order
|
||||||
|
assert "step_c" in execution_order
|
||||||
|
assert "step_d" in execution_order
|
||||||
|
assert execution_order.index("step_c") > execution_order.index("step_a")
|
||||||
|
assert execution_order.index("step_d") > execution_order.index("step_b")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cyclic_flow():
|
||||||
|
"""Test a cyclic flow that runs a finite number of iterations."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class CyclicFlow(Flow):
|
||||||
|
iteration = 0
|
||||||
|
max_iterations = 3
|
||||||
|
|
||||||
|
@start("loop")
|
||||||
|
def step_1(self):
|
||||||
|
if self.iteration >= self.max_iterations:
|
||||||
|
return # Do not proceed further
|
||||||
|
execution_order.append(f"step_1_{self.iteration}")
|
||||||
|
|
||||||
|
@listen(step_1)
|
||||||
|
def step_2(self):
|
||||||
|
execution_order.append(f"step_2_{self.iteration}")
|
||||||
|
|
||||||
|
@router(step_2)
|
||||||
|
def step_3(self):
|
||||||
|
execution_order.append(f"step_3_{self.iteration}")
|
||||||
|
self.iteration += 1
|
||||||
|
if self.iteration < self.max_iterations:
|
||||||
|
return "loop"
|
||||||
|
|
||||||
|
return "exit"
|
||||||
|
|
||||||
|
flow = CyclicFlow()
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
expected_order = []
|
||||||
|
for i in range(flow.max_iterations):
|
||||||
|
expected_order.extend([f"step_1_{i}", f"step_2_{i}", f"step_3_{i}"])
|
||||||
|
|
||||||
|
assert execution_order == expected_order
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_with_and_condition():
|
||||||
|
"""Test a flow where a step waits for multiple other steps to complete."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class AndConditionFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
execution_order.append("step_1")
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def step_2(self):
|
||||||
|
execution_order.append("step_2")
|
||||||
|
|
||||||
|
@listen(and_(step_1, step_2))
|
||||||
|
def step_3(self):
|
||||||
|
execution_order.append("step_3")
|
||||||
|
|
||||||
|
flow = AndConditionFlow()
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert "step_1" in execution_order
|
||||||
|
assert "step_2" in execution_order
|
||||||
|
assert execution_order[-1] == "step_3"
|
||||||
|
assert execution_order.index("step_3") > execution_order.index("step_1")
|
||||||
|
assert execution_order.index("step_3") > execution_order.index("step_2")
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_with_or_condition():
|
||||||
|
"""Test a flow where a step is triggered when any of multiple steps complete."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class OrConditionFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def step_a(self):
|
||||||
|
execution_order.append("step_a")
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def step_b(self):
|
||||||
|
execution_order.append("step_b")
|
||||||
|
|
||||||
|
@listen(or_(step_a, step_b))
|
||||||
|
def step_c(self):
|
||||||
|
execution_order.append("step_c")
|
||||||
|
|
||||||
|
flow = OrConditionFlow()
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert "step_a" in execution_order or "step_b" in execution_order
|
||||||
|
assert "step_c" in execution_order
|
||||||
|
assert execution_order.index("step_c") > min(
|
||||||
|
execution_order.index("step_a"), execution_order.index("step_b")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_with_router():
|
||||||
|
"""Test a flow that uses a router method to determine the next step."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class RouterFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def start_method(self):
|
||||||
|
execution_order.append("start_method")
|
||||||
|
|
||||||
|
@router(start_method)
|
||||||
|
def router(self):
|
||||||
|
execution_order.append("router")
|
||||||
|
# Ensure the condition is set to True to follow the "step_if_true" path
|
||||||
|
condition = True
|
||||||
|
return "step_if_true" if condition else "step_if_false"
|
||||||
|
|
||||||
|
@listen("step_if_true")
|
||||||
|
def truthy(self):
|
||||||
|
execution_order.append("step_if_true")
|
||||||
|
|
||||||
|
@listen("step_if_false")
|
||||||
|
def falsy(self):
|
||||||
|
execution_order.append("step_if_false")
|
||||||
|
|
||||||
|
flow = RouterFlow()
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert execution_order == ["start_method", "router", "step_if_true"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_flow():
|
||||||
|
"""Test an asynchronous flow."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class AsyncFlow(Flow):
|
||||||
|
@start()
|
||||||
|
async def step_1(self):
|
||||||
|
execution_order.append("step_1")
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
@listen(step_1)
|
||||||
|
async def step_2(self):
|
||||||
|
execution_order.append("step_2")
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
flow = AsyncFlow()
|
||||||
|
asyncio.run(flow.kickoff_async())
|
||||||
|
|
||||||
|
assert execution_order == ["step_1", "step_2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_with_exceptions():
|
||||||
|
"""Test flow behavior when exceptions occur in steps."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class ExceptionFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
execution_order.append("step_1")
|
||||||
|
raise ValueError("An error occurred in step_1")
|
||||||
|
|
||||||
|
@listen(step_1)
|
||||||
|
def step_2(self):
|
||||||
|
execution_order.append("step_2")
|
||||||
|
|
||||||
|
flow = ExceptionFlow()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
# Ensure step_2 did not execute
|
||||||
|
assert execution_order == ["step_1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_restart():
|
||||||
|
"""Test restarting a flow after it has completed."""
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class RestartableFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
execution_order.append("step_1")
|
||||||
|
|
||||||
|
@listen(step_1)
|
||||||
|
def step_2(self):
|
||||||
|
execution_order.append("step_2")
|
||||||
|
|
||||||
|
flow = RestartableFlow()
|
||||||
|
flow.kickoff()
|
||||||
|
flow.kickoff() # Restart the flow
|
||||||
|
|
||||||
|
assert execution_order == ["step_1", "step_2", "step_1", "step_2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_with_custom_state():
|
||||||
|
"""Test a flow that maintains and modifies internal state."""
|
||||||
|
|
||||||
|
class StateFlow(Flow):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.counter = 0
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
self.counter += 1
|
||||||
|
|
||||||
|
@listen(step_1)
|
||||||
|
def step_2(self):
|
||||||
|
self.counter *= 2
|
||||||
|
assert self.counter == 2
|
||||||
|
|
||||||
|
flow = StateFlow()
|
||||||
|
flow.kickoff()
|
||||||
|
assert flow.counter == 2
|
||||||
Reference in New Issue
Block a user