mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Add in or and and in router
This commit is contained in:
@@ -80,10 +80,27 @@ def listen(condition):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def router(method):
|
def router(condition):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
func.__is_router__ = True
|
func.__is_router__ = True
|
||||||
func.__router_for__ = method.__name__
|
# Handle conditions like listen/start
|
||||||
|
if isinstance(condition, str):
|
||||||
|
func.__trigger_methods__ = [condition]
|
||||||
|
func.__condition_type__ = "OR"
|
||||||
|
elif (
|
||||||
|
isinstance(condition, dict)
|
||||||
|
and "type" in condition
|
||||||
|
and "methods" in condition
|
||||||
|
):
|
||||||
|
func.__trigger_methods__ = condition["methods"]
|
||||||
|
func.__condition_type__ = condition["type"]
|
||||||
|
elif callable(condition) and hasattr(condition, "__name__"):
|
||||||
|
func.__trigger_methods__ = [condition.__name__]
|
||||||
|
func.__condition_type__ = "OR"
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Condition must be a method, string, or a result of or_() or and_()"
|
||||||
|
)
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
@@ -123,8 +140,8 @@ class FlowMeta(type):
|
|||||||
|
|
||||||
start_methods = []
|
start_methods = []
|
||||||
listeners = {}
|
listeners = {}
|
||||||
routers = {}
|
|
||||||
router_paths = {}
|
router_paths = {}
|
||||||
|
routers = set()
|
||||||
|
|
||||||
for attr_name, attr_value in dct.items():
|
for attr_name, attr_value in dct.items():
|
||||||
if hasattr(attr_value, "__is_start_method__"):
|
if hasattr(attr_value, "__is_start_method__"):
|
||||||
@@ -137,18 +154,11 @@ class FlowMeta(type):
|
|||||||
methods = attr_value.__trigger_methods__
|
methods = attr_value.__trigger_methods__
|
||||||
condition_type = getattr(attr_value, "__condition_type__", "OR")
|
condition_type = getattr(attr_value, "__condition_type__", "OR")
|
||||||
listeners[attr_name] = (condition_type, methods)
|
listeners[attr_name] = (condition_type, methods)
|
||||||
|
if hasattr(attr_value, "__is_router__") and attr_value.__is_router__:
|
||||||
elif hasattr(attr_value, "__is_router__"):
|
routers.add(attr_name)
|
||||||
routers[attr_value.__router_for__] = attr_name
|
possible_returns = get_possible_return_constants(attr_value)
|
||||||
possible_returns = get_possible_return_constants(attr_value)
|
if possible_returns:
|
||||||
if possible_returns:
|
router_paths[attr_name] = possible_returns
|
||||||
router_paths[attr_name] = possible_returns
|
|
||||||
|
|
||||||
# Register router as a listener to its triggering method
|
|
||||||
trigger_method_name = attr_value.__router_for__
|
|
||||||
methods = [trigger_method_name]
|
|
||||||
condition_type = "OR"
|
|
||||||
listeners[attr_name] = (condition_type, methods)
|
|
||||||
|
|
||||||
setattr(cls, "_start_methods", start_methods)
|
setattr(cls, "_start_methods", start_methods)
|
||||||
setattr(cls, "_listeners", listeners)
|
setattr(cls, "_listeners", listeners)
|
||||||
@@ -163,7 +173,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
|
|
||||||
_start_methods: List[str] = []
|
_start_methods: List[str] = []
|
||||||
_listeners: Dict[str, tuple[str, List[str]]] = {}
|
_listeners: Dict[str, tuple[str, List[str]]] = {}
|
||||||
_routers: Dict[str, str] = {}
|
_routers: Set[str] = set()
|
||||||
_router_paths: Dict[str, List[str]] = {}
|
_router_paths: Dict[str, List[str]] = {}
|
||||||
initial_state: Union[Type[T], T, None] = None
|
initial_state: Union[Type[T], T, None] = None
|
||||||
event_emitter = Signal("event_emitter")
|
event_emitter = Signal("event_emitter")
|
||||||
@@ -210,20 +220,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
return self._method_outputs
|
return self._method_outputs
|
||||||
|
|
||||||
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
|
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
|
||||||
"""
|
|
||||||
Initializes or updates the state with the provided inputs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs: Dictionary of inputs to initialize or update the state.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If inputs do not match the structured state model.
|
|
||||||
TypeError: If state is neither a BaseModel instance nor a dictionary.
|
|
||||||
"""
|
|
||||||
if isinstance(self._state, BaseModel):
|
if isinstance(self._state, BaseModel):
|
||||||
# Structured state management
|
# Structured state
|
||||||
try:
|
try:
|
||||||
# Define a function to create the dynamic class
|
|
||||||
def create_model_with_extra_forbid(
|
def create_model_with_extra_forbid(
|
||||||
base_model: Type[BaseModel],
|
base_model: Type[BaseModel],
|
||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
@@ -233,34 +233,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
|
|
||||||
return ModelWithExtraForbid
|
return ModelWithExtraForbid
|
||||||
|
|
||||||
# Create the dynamic class
|
|
||||||
ModelWithExtraForbid = create_model_with_extra_forbid(
|
ModelWithExtraForbid = create_model_with_extra_forbid(
|
||||||
self._state.__class__
|
self._state.__class__
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a new instance using the combined state and inputs
|
|
||||||
self._state = cast(
|
self._state = cast(
|
||||||
T, ModelWithExtraForbid(**{**self._state.model_dump(), **inputs})
|
T, ModelWithExtraForbid(**{**self._state.model_dump(), **inputs})
|
||||||
)
|
)
|
||||||
|
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
raise ValueError(f"Invalid inputs for structured state: {e}") from e
|
raise ValueError(f"Invalid inputs for structured state: {e}") from e
|
||||||
elif isinstance(self._state, dict):
|
elif isinstance(self._state, dict):
|
||||||
# Unstructured state management
|
|
||||||
self._state.update(inputs)
|
self._state.update(inputs)
|
||||||
else:
|
else:
|
||||||
raise TypeError("State must be a BaseModel instance or a dictionary.")
|
raise TypeError("State must be a BaseModel instance or a dictionary.")
|
||||||
|
|
||||||
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||||
"""
|
|
||||||
Starts the execution of the flow synchronously.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs: Optional dictionary of inputs to initialize or update the state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The final output from the flow execution.
|
|
||||||
"""
|
|
||||||
self.event_emitter.send(
|
self.event_emitter.send(
|
||||||
self,
|
self,
|
||||||
event=FlowStartedEvent(
|
event=FlowStartedEvent(
|
||||||
@@ -274,15 +260,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
return asyncio.run(self.kickoff_async())
|
return asyncio.run(self.kickoff_async())
|
||||||
|
|
||||||
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||||
"""
|
|
||||||
Starts the execution of the flow asynchronously.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs: Optional dictionary of inputs to initialize or update the state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The final output from the flow execution.
|
|
||||||
"""
|
|
||||||
if not self._start_methods:
|
if not self._start_methods:
|
||||||
raise ValueError("No start method defined")
|
raise ValueError("No start method defined")
|
||||||
|
|
||||||
@@ -290,16 +267,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
self.__class__.__name__, list(self._methods.keys())
|
self.__class__.__name__, list(self._methods.keys())
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create tasks for all start methods
|
|
||||||
tasks = [
|
tasks = [
|
||||||
self._execute_start_method(start_method)
|
self._execute_start_method(start_method)
|
||||||
for start_method in self._start_methods
|
for start_method in self._start_methods
|
||||||
]
|
]
|
||||||
|
|
||||||
# Run all start methods concurrently
|
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
# Determine the final output (from the last executed method)
|
|
||||||
final_output = self._method_outputs[-1] if self._method_outputs else None
|
final_output = self._method_outputs[-1] if self._method_outputs else None
|
||||||
|
|
||||||
self.event_emitter.send(
|
self.event_emitter.send(
|
||||||
@@ -310,7 +283,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
result=final_output,
|
result=final_output,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return final_output
|
return final_output
|
||||||
|
|
||||||
async def _execute_start_method(self, start_method_name: str) -> None:
|
async def _execute_start_method(self, start_method_name: str) -> None:
|
||||||
@@ -327,49 +299,68 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
if asyncio.iscoroutinefunction(method)
|
if asyncio.iscoroutinefunction(method)
|
||||||
else method(*args, **kwargs)
|
else method(*args, **kwargs)
|
||||||
)
|
)
|
||||||
self._method_outputs.append(result) # Store the output
|
self._method_outputs.append(result)
|
||||||
|
|
||||||
# Track method execution counts
|
|
||||||
self._method_execution_counts[method_name] = (
|
self._method_execution_counts[method_name] = (
|
||||||
self._method_execution_counts.get(method_name, 0) + 1
|
self._method_execution_counts.get(method_name, 0) + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
|
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
|
||||||
listener_tasks = []
|
# First, handle routers repeatedly until no router triggers anymore
|
||||||
|
while True:
|
||||||
if trigger_method in self._routers:
|
routers_triggered = self._find_triggered_methods(
|
||||||
router_method = self._methods[self._routers[trigger_method]]
|
trigger_method, router_only=True
|
||||||
path = await self._execute_method(
|
|
||||||
self._routers[trigger_method], router_method
|
|
||||||
)
|
)
|
||||||
trigger_method = path
|
if not routers_triggered:
|
||||||
|
break
|
||||||
|
for router_name in routers_triggered:
|
||||||
|
await self._execute_single_listener(router_name, result)
|
||||||
|
# After executing router, the router's result is the path
|
||||||
|
# The last router executed sets the trigger_method
|
||||||
|
# The router result is the last element in self._method_outputs
|
||||||
|
trigger_method = self._method_outputs[-1]
|
||||||
|
|
||||||
|
# Now that no more routers are triggered by current trigger_method,
|
||||||
|
# execute normal listeners
|
||||||
|
listeners_triggered = self._find_triggered_methods(
|
||||||
|
trigger_method, router_only=False
|
||||||
|
)
|
||||||
|
if listeners_triggered:
|
||||||
|
tasks = [
|
||||||
|
self._execute_single_listener(listener_name, result)
|
||||||
|
for listener_name in listeners_triggered
|
||||||
|
]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
def _find_triggered_methods(
|
||||||
|
self, trigger_method: str, router_only: bool
|
||||||
|
) -> List[str]:
|
||||||
|
triggered = []
|
||||||
for listener_name, (condition_type, methods) in self._listeners.items():
|
for listener_name, (condition_type, methods) in self._listeners.items():
|
||||||
|
is_router = listener_name in self._routers
|
||||||
|
|
||||||
|
if router_only != is_router:
|
||||||
|
continue
|
||||||
|
|
||||||
if condition_type == "OR":
|
if condition_type == "OR":
|
||||||
|
# If the trigger_method matches any in methods, run this
|
||||||
if trigger_method in methods:
|
if trigger_method in methods:
|
||||||
# Schedule the listener without preventing re-execution
|
triggered.append(listener_name)
|
||||||
listener_tasks.append(
|
|
||||||
self._execute_single_listener(listener_name, result)
|
|
||||||
)
|
|
||||||
elif condition_type == "AND":
|
elif condition_type == "AND":
|
||||||
# Initialize pending methods for this listener if not already done
|
# Initialize pending methods for this listener if not already done
|
||||||
if listener_name not in self._pending_and_listeners:
|
if listener_name not in self._pending_and_listeners:
|
||||||
self._pending_and_listeners[listener_name] = set(methods)
|
self._pending_and_listeners[listener_name] = set(methods)
|
||||||
# Remove the trigger method from pending methods
|
# Remove the trigger method from pending methods
|
||||||
self._pending_and_listeners[listener_name].discard(trigger_method)
|
if trigger_method in self._pending_and_listeners[listener_name]:
|
||||||
|
self._pending_and_listeners[listener_name].discard(trigger_method)
|
||||||
|
|
||||||
if not self._pending_and_listeners[listener_name]:
|
if not self._pending_and_listeners[listener_name]:
|
||||||
# All required methods have been executed
|
# All required methods have been executed
|
||||||
listener_tasks.append(
|
triggered.append(listener_name)
|
||||||
self._execute_single_listener(listener_name, result)
|
|
||||||
)
|
|
||||||
# Reset pending methods for this listener
|
# Reset pending methods for this listener
|
||||||
self._pending_and_listeners.pop(listener_name, None)
|
self._pending_and_listeners.pop(listener_name, None)
|
||||||
|
|
||||||
# Run all listener tasks concurrently and wait for them to complete
|
return triggered
|
||||||
if listener_tasks:
|
|
||||||
await asyncio.gather(*listener_tasks)
|
|
||||||
|
|
||||||
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
|
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
|
||||||
try:
|
try:
|
||||||
@@ -386,17 +377,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
|
|
||||||
sig = inspect.signature(method)
|
sig = inspect.signature(method)
|
||||||
params = list(sig.parameters.values())
|
params = list(sig.parameters.values())
|
||||||
|
|
||||||
# Exclude 'self' parameter
|
|
||||||
method_params = [p for p in params if p.name != "self"]
|
method_params = [p for p in params if p.name != "self"]
|
||||||
|
|
||||||
if method_params:
|
if method_params:
|
||||||
# If listener expects parameters, pass the result
|
|
||||||
listener_result = await self._execute_method(
|
listener_result = await self._execute_method(
|
||||||
listener_name, method, result
|
listener_name, method, result
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# If listener does not expect parameters, call without arguments
|
|
||||||
listener_result = await self._execute_method(listener_name, method)
|
listener_result = await self._execute_method(listener_name, method)
|
||||||
|
|
||||||
self.event_emitter.send(
|
self.event_emitter.send(
|
||||||
@@ -408,8 +395,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute listeners of this listener
|
# Execute listeners (and possibly routers) of this listener
|
||||||
await self._execute_listeners(listener_name, listener_result)
|
await self._execute_listeners(listener_name, listener_result)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(
|
print(
|
||||||
f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
|
f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
|
||||||
@@ -422,5 +410,4 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
self._telemetry.flow_plotting_span(
|
self._telemetry.flow_plotting_span(
|
||||||
self.__class__.__name__, list(self._methods.keys())
|
self.__class__.__name__, list(self._methods.keys())
|
||||||
)
|
)
|
||||||
|
|
||||||
plot_flow(self, filename)
|
plot_flow(self, filename)
|
||||||
|
|||||||
@@ -263,3 +263,62 @@ def test_flow_with_custom_state():
|
|||||||
flow = StateFlow()
|
flow = StateFlow()
|
||||||
flow.kickoff()
|
flow.kickoff()
|
||||||
assert flow.counter == 2
|
assert flow.counter == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_router_with_multiple_conditions():
|
||||||
|
"""Test a router that triggers when any of multiple steps complete (OR condition),
|
||||||
|
and another router that triggers only after all specified steps complete (AND condition).
|
||||||
|
"""
|
||||||
|
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
class ComplexRouterFlow(Flow):
|
||||||
|
@start()
|
||||||
|
def step_a(self):
|
||||||
|
execution_order.append("step_a")
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def step_b(self):
|
||||||
|
execution_order.append("step_b")
|
||||||
|
|
||||||
|
@router(or_("step_a", "step_b"))
|
||||||
|
def router_or(self):
|
||||||
|
execution_order.append("router_or")
|
||||||
|
return "next_step_or"
|
||||||
|
|
||||||
|
@listen("next_step_or")
|
||||||
|
def handle_next_step_or_event(self):
|
||||||
|
execution_order.append("handle_next_step_or_event")
|
||||||
|
|
||||||
|
@listen(handle_next_step_or_event)
|
||||||
|
def branch_2_step(self):
|
||||||
|
execution_order.append("branch_2_step")
|
||||||
|
|
||||||
|
@router(and_(handle_next_step_or_event, branch_2_step))
|
||||||
|
def router_and(self):
|
||||||
|
execution_order.append("router_and")
|
||||||
|
return "final_step"
|
||||||
|
|
||||||
|
@listen("final_step")
|
||||||
|
def log_final_step(self):
|
||||||
|
execution_order.append("log_final_step")
|
||||||
|
|
||||||
|
flow = ComplexRouterFlow()
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert "step_a" in execution_order
|
||||||
|
assert "step_b" in execution_order
|
||||||
|
assert "router_or" in execution_order
|
||||||
|
assert "handle_next_step_or_event" in execution_order
|
||||||
|
assert "branch_2_step" in execution_order
|
||||||
|
assert "router_and" in execution_order
|
||||||
|
assert "log_final_step" in execution_order
|
||||||
|
|
||||||
|
# Check that the AND router triggered after both relevant steps:
|
||||||
|
assert execution_order.index("router_and") > execution_order.index(
|
||||||
|
"handle_next_step_or_event"
|
||||||
|
)
|
||||||
|
assert execution_order.index("router_and") > execution_order.index("branch_2_step")
|
||||||
|
|
||||||
|
# final_step should run after router_and
|
||||||
|
assert execution_order.index("log_final_step") > execution_order.index("router_and")
|
||||||
|
|||||||
Reference in New Issue
Block a user