Add in or and and in router

This commit is contained in:
Brandon Hancock
2024-12-16 12:10:33 -05:00
parent 6d7c1b0743
commit 8eea6bc090
2 changed files with 130 additions and 84 deletions

View File

@@ -80,10 +80,27 @@ def listen(condition):
return decorator return decorator
def router(method): def router(condition):
def decorator(func): def decorator(func):
func.__is_router__ = True func.__is_router__ = True
func.__router_for__ = method.__name__ # Handle conditions like listen/start
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
elif (
isinstance(condition, dict)
and "type" in condition
and "methods" in condition
):
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
else:
raise ValueError(
"Condition must be a method, string, or a result of or_() or and_()"
)
return func return func
return decorator return decorator
@@ -123,8 +140,8 @@ class FlowMeta(type):
start_methods = [] start_methods = []
listeners = {} listeners = {}
routers = {}
router_paths = {} router_paths = {}
routers = set()
for attr_name, attr_value in dct.items(): for attr_name, attr_value in dct.items():
if hasattr(attr_value, "__is_start_method__"): if hasattr(attr_value, "__is_start_method__"):
@@ -137,18 +154,11 @@ class FlowMeta(type):
methods = attr_value.__trigger_methods__ methods = attr_value.__trigger_methods__
condition_type = getattr(attr_value, "__condition_type__", "OR") condition_type = getattr(attr_value, "__condition_type__", "OR")
listeners[attr_name] = (condition_type, methods) listeners[attr_name] = (condition_type, methods)
if hasattr(attr_value, "__is_router__") and attr_value.__is_router__:
elif hasattr(attr_value, "__is_router__"): routers.add(attr_name)
routers[attr_value.__router_for__] = attr_name possible_returns = get_possible_return_constants(attr_value)
possible_returns = get_possible_return_constants(attr_value) if possible_returns:
if possible_returns: router_paths[attr_name] = possible_returns
router_paths[attr_name] = possible_returns
# Register router as a listener to its triggering method
trigger_method_name = attr_value.__router_for__
methods = [trigger_method_name]
condition_type = "OR"
listeners[attr_name] = (condition_type, methods)
setattr(cls, "_start_methods", start_methods) setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners) setattr(cls, "_listeners", listeners)
@@ -163,7 +173,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
_start_methods: List[str] = [] _start_methods: List[str] = []
_listeners: Dict[str, tuple[str, List[str]]] = {} _listeners: Dict[str, tuple[str, List[str]]] = {}
_routers: Dict[str, str] = {} _routers: Set[str] = set()
_router_paths: Dict[str, List[str]] = {} _router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None initial_state: Union[Type[T], T, None] = None
event_emitter = Signal("event_emitter") event_emitter = Signal("event_emitter")
@@ -210,20 +220,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
return self._method_outputs return self._method_outputs
def _initialize_state(self, inputs: Dict[str, Any]) -> None: def _initialize_state(self, inputs: Dict[str, Any]) -> None:
"""
Initializes or updates the state with the provided inputs.
Args:
inputs: Dictionary of inputs to initialize or update the state.
Raises:
ValueError: If inputs do not match the structured state model.
TypeError: If state is neither a BaseModel instance nor a dictionary.
"""
if isinstance(self._state, BaseModel): if isinstance(self._state, BaseModel):
# Structured state management # Structured state
try: try:
# Define a function to create the dynamic class
def create_model_with_extra_forbid( def create_model_with_extra_forbid(
base_model: Type[BaseModel], base_model: Type[BaseModel],
) -> Type[BaseModel]: ) -> Type[BaseModel]:
@@ -233,34 +233,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
return ModelWithExtraForbid return ModelWithExtraForbid
# Create the dynamic class
ModelWithExtraForbid = create_model_with_extra_forbid( ModelWithExtraForbid = create_model_with_extra_forbid(
self._state.__class__ self._state.__class__
) )
# Create a new instance using the combined state and inputs
self._state = cast( self._state = cast(
T, ModelWithExtraForbid(**{**self._state.model_dump(), **inputs}) T, ModelWithExtraForbid(**{**self._state.model_dump(), **inputs})
) )
except ValidationError as e: except ValidationError as e:
raise ValueError(f"Invalid inputs for structured state: {e}") from e raise ValueError(f"Invalid inputs for structured state: {e}") from e
elif isinstance(self._state, dict): elif isinstance(self._state, dict):
# Unstructured state management
self._state.update(inputs) self._state.update(inputs)
else: else:
raise TypeError("State must be a BaseModel instance or a dictionary.") raise TypeError("State must be a BaseModel instance or a dictionary.")
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any: def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Starts the execution of the flow synchronously.
Args:
inputs: Optional dictionary of inputs to initialize or update the state.
Returns:
The final output from the flow execution.
"""
self.event_emitter.send( self.event_emitter.send(
self, self,
event=FlowStartedEvent( event=FlowStartedEvent(
@@ -274,15 +260,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
return asyncio.run(self.kickoff_async()) return asyncio.run(self.kickoff_async())
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any: async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Starts the execution of the flow asynchronously.
Args:
inputs: Optional dictionary of inputs to initialize or update the state.
Returns:
The final output from the flow execution.
"""
if not self._start_methods: if not self._start_methods:
raise ValueError("No start method defined") raise ValueError("No start method defined")
@@ -290,16 +267,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
self.__class__.__name__, list(self._methods.keys()) self.__class__.__name__, list(self._methods.keys())
) )
# Create tasks for all start methods
tasks = [ tasks = [
self._execute_start_method(start_method) self._execute_start_method(start_method)
for start_method in self._start_methods for start_method in self._start_methods
] ]
# Run all start methods concurrently
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
# Determine the final output (from the last executed method)
final_output = self._method_outputs[-1] if self._method_outputs else None final_output = self._method_outputs[-1] if self._method_outputs else None
self.event_emitter.send( self.event_emitter.send(
@@ -310,7 +283,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
result=final_output, result=final_output,
), ),
) )
return final_output return final_output
async def _execute_start_method(self, start_method_name: str) -> None: async def _execute_start_method(self, start_method_name: str) -> None:
@@ -327,49 +299,68 @@ class Flow(Generic[T], metaclass=FlowMeta):
if asyncio.iscoroutinefunction(method) if asyncio.iscoroutinefunction(method)
else method(*args, **kwargs) else method(*args, **kwargs)
) )
self._method_outputs.append(result) # Store the output self._method_outputs.append(result)
# Track method execution counts
self._method_execution_counts[method_name] = ( self._method_execution_counts[method_name] = (
self._method_execution_counts.get(method_name, 0) + 1 self._method_execution_counts.get(method_name, 0) + 1
) )
return result return result
async def _execute_listeners(self, trigger_method: str, result: Any) -> None: async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
listener_tasks = [] # First, handle routers repeatedly until no router triggers anymore
while True:
if trigger_method in self._routers: routers_triggered = self._find_triggered_methods(
router_method = self._methods[self._routers[trigger_method]] trigger_method, router_only=True
path = await self._execute_method(
self._routers[trigger_method], router_method
) )
trigger_method = path if not routers_triggered:
break
for router_name in routers_triggered:
await self._execute_single_listener(router_name, result)
# After executing router, the router's result is the path
# The last router executed sets the trigger_method
# The router result is the last element in self._method_outputs
trigger_method = self._method_outputs[-1]
# Now that no more routers are triggered by current trigger_method,
# execute normal listeners
listeners_triggered = self._find_triggered_methods(
trigger_method, router_only=False
)
if listeners_triggered:
tasks = [
self._execute_single_listener(listener_name, result)
for listener_name in listeners_triggered
]
await asyncio.gather(*tasks)
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
) -> List[str]:
triggered = []
for listener_name, (condition_type, methods) in self._listeners.items(): for listener_name, (condition_type, methods) in self._listeners.items():
is_router = listener_name in self._routers
if router_only != is_router:
continue
if condition_type == "OR": if condition_type == "OR":
# If the trigger_method matches any in methods, run this
if trigger_method in methods: if trigger_method in methods:
# Schedule the listener without preventing re-execution triggered.append(listener_name)
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
elif condition_type == "AND": elif condition_type == "AND":
# Initialize pending methods for this listener if not already done # Initialize pending methods for this listener if not already done
if listener_name not in self._pending_and_listeners: if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods) self._pending_and_listeners[listener_name] = set(methods)
# Remove the trigger method from pending methods # Remove the trigger method from pending methods
self._pending_and_listeners[listener_name].discard(trigger_method) if trigger_method in self._pending_and_listeners[listener_name]:
self._pending_and_listeners[listener_name].discard(trigger_method)
if not self._pending_and_listeners[listener_name]: if not self._pending_and_listeners[listener_name]:
# All required methods have been executed # All required methods have been executed
listener_tasks.append( triggered.append(listener_name)
self._execute_single_listener(listener_name, result)
)
# Reset pending methods for this listener # Reset pending methods for this listener
self._pending_and_listeners.pop(listener_name, None) self._pending_and_listeners.pop(listener_name, None)
# Run all listener tasks concurrently and wait for them to complete return triggered
if listener_tasks:
await asyncio.gather(*listener_tasks)
async def _execute_single_listener(self, listener_name: str, result: Any) -> None: async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
try: try:
@@ -386,17 +377,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
sig = inspect.signature(method) sig = inspect.signature(method)
params = list(sig.parameters.values()) params = list(sig.parameters.values())
# Exclude 'self' parameter
method_params = [p for p in params if p.name != "self"] method_params = [p for p in params if p.name != "self"]
if method_params: if method_params:
# If listener expects parameters, pass the result
listener_result = await self._execute_method( listener_result = await self._execute_method(
listener_name, method, result listener_name, method, result
) )
else: else:
# If listener does not expect parameters, call without arguments
listener_result = await self._execute_method(listener_name, method) listener_result = await self._execute_method(listener_name, method)
self.event_emitter.send( self.event_emitter.send(
@@ -408,8 +395,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
), ),
) )
# Execute listeners of this listener # Execute listeners (and possibly routers) of this listener
await self._execute_listeners(listener_name, listener_result) await self._execute_listeners(listener_name, listener_result)
except Exception as e: except Exception as e:
print( print(
f"[Flow._execute_single_listener] Error in method {listener_name}: {e}" f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
@@ -422,5 +410,4 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._telemetry.flow_plotting_span( self._telemetry.flow_plotting_span(
self.__class__.__name__, list(self._methods.keys()) self.__class__.__name__, list(self._methods.keys())
) )
plot_flow(self, filename) plot_flow(self, filename)

View File

@@ -263,3 +263,62 @@ def test_flow_with_custom_state():
flow = StateFlow() flow = StateFlow()
flow.kickoff() flow.kickoff()
assert flow.counter == 2 assert flow.counter == 2
def test_router_with_multiple_conditions():
"""Test a router that triggers when any of multiple steps complete (OR condition),
and another router that triggers only after all specified steps complete (AND condition).
"""
execution_order = []
class ComplexRouterFlow(Flow):
@start()
def step_a(self):
execution_order.append("step_a")
@start()
def step_b(self):
execution_order.append("step_b")
@router(or_("step_a", "step_b"))
def router_or(self):
execution_order.append("router_or")
return "next_step_or"
@listen("next_step_or")
def handle_next_step_or_event(self):
execution_order.append("handle_next_step_or_event")
@listen(handle_next_step_or_event)
def branch_2_step(self):
execution_order.append("branch_2_step")
@router(and_(handle_next_step_or_event, branch_2_step))
def router_and(self):
execution_order.append("router_and")
return "final_step"
@listen("final_step")
def log_final_step(self):
execution_order.append("log_final_step")
flow = ComplexRouterFlow()
flow.kickoff()
assert "step_a" in execution_order
assert "step_b" in execution_order
assert "router_or" in execution_order
assert "handle_next_step_or_event" in execution_order
assert "branch_2_step" in execution_order
assert "router_and" in execution_order
assert "log_final_step" in execution_order
# Check that the AND router triggered after both relevant steps:
assert execution_order.index("router_and") > execution_order.index(
"handle_next_step_or_event"
)
assert execution_order.index("router_and") > execution_order.index("branch_2_step")
# final_step should run after router_and
assert execution_order.index("log_final_step") > execution_order.index("router_and")