From 60d43c95b95b6dcd39e6dfe5416895cecfc606ea Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Tue, 29 Oct 2024 17:44:07 -0400 Subject: [PATCH] bugfix/flows-with-multiple-starts-plus-ands-breaking --- src/crewai/flow/flow.py | 79 ++++++++++++++++++++++++++++------------- 1 file changed, 55 insertions(+), 24 deletions(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index de9b2eeb2..982f14df0 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -120,6 +120,8 @@ class FlowMeta(type): methods = attr_value.__trigger_methods__ condition_type = getattr(attr_value, "__condition_type__", "OR") listeners[attr_name] = (condition_type, methods) + + # TODO: should we add a check for __condition_type__ 'AND'? elif hasattr(attr_value, "__is_router__"): routers[attr_value.__router_for__] = attr_name possible_returns = get_possible_return_constants(attr_value) @@ -137,6 +139,11 @@ class FlowMeta(type): setattr(cls, "_routers", routers) setattr(cls, "_router_paths", router_paths) + print("_start_methods", start_methods) + print("_listeners", listeners) + print("_routers", routers) + print("_router_paths", router_paths) + return cls @@ -159,7 +166,8 @@ class Flow(Generic[T], metaclass=FlowMeta): def __init__(self) -> None: self._methods: Dict[str, Callable] = {} self._state: T = self._create_initial_state() - self._completed_methods: Set[str] = set() + self._executed_methods: Set[str] = set() + self._scheduled_tasks: Set[str] = set() self._pending_and_listeners: Dict[str, Set[str]] = {} self._method_outputs: List[Any] = [] # List to store all method outputs @@ -169,6 +177,7 @@ class Flow(Generic[T], metaclass=FlowMeta): if callable(getattr(self, method_name)) and not method_name.startswith( "__" ): + print("Method", method_name) self._methods[method_name] = getattr(self, method_name) def _create_initial_state(self) -> T: @@ -216,17 +225,24 @@ class Flow(Generic[T], metaclass=FlowMeta): else: return None # Or raise an exception if no methods were executed - async def _execute_start_method(self, start_method: str) -> None: - result = await self._execute_method(self._methods[start_method]) - await self._execute_listeners(start_method, result) + async def _execute_start_method(self, start_method_name: str) -> None: + result = await self._execute_method( + start_method_name, self._methods[start_method_name] + ) + await self._execute_listeners(start_method_name, result) - async def _execute_method(self, method: Callable, *args: Any, **kwargs: Any) -> Any: + async def _execute_method( + self, method_name: str, method: Callable, *args: Any, **kwargs: Any + ) -> Any: result = ( await method(*args, **kwargs) if asyncio.iscoroutinefunction(method) else method(*args, **kwargs) ) self._method_outputs.append(result) # Store the output + + self._executed_methods.add(method_name) + return result async def _execute_listeners(self, trigger_method: str, result: Any) -> None: @@ -234,32 +250,40 @@ class Flow(Generic[T], metaclass=FlowMeta): if trigger_method in self._routers: router_method = self._methods[self._routers[trigger_method]] - path = await self._execute_method(router_method) + path = await self._execute_method( + trigger_method, router_method + ) # TODO: Change or not? # Use the path as the new trigger method trigger_method = path - for listener, (condition_type, methods) in self._listeners.items(): + for listener_name, (condition_type, methods) in self._listeners.items(): if condition_type == "OR": if trigger_method in methods: - listener_tasks.append( - self._execute_single_listener(listener, result) - ) + if ( + listener_name not in self._executed_methods + and listener_name not in self._scheduled_tasks + ): + self._scheduled_tasks.add(listener_name) + listener_tasks.append( + self._execute_single_listener(listener_name, result) + ) elif condition_type == "AND": - if listener not in self._pending_and_listeners: - self._pending_and_listeners[listener] = set() - self._pending_and_listeners[listener].add(trigger_method) - if set(methods) == self._pending_and_listeners[listener]: - listener_tasks.append( - self._execute_single_listener(listener, result) - ) - del self._pending_and_listeners[listener] + if all(method in self._executed_methods for method in methods): + if ( + listener_name not in self._executed_methods + and listener_name not in self._scheduled_tasks + ): + self._scheduled_tasks.add(listener_name) + listener_tasks.append( + self._execute_single_listener(listener_name, result) + ) # Run all listener tasks concurrently and wait for them to complete await asyncio.gather(*listener_tasks) - async def _execute_single_listener(self, listener: str, result: Any) -> None: + async def _execute_single_listener(self, listener_name: str, result: Any) -> None: try: - method = self._methods[listener] + method = self._methods[listener_name] sig = inspect.signature(method) params = list(sig.parameters.values()) @@ -268,15 +292,22 @@ class Flow(Generic[T], metaclass=FlowMeta): if method_params: # If listener expects parameters, pass the result - listener_result = await self._execute_method(method, result) + listener_result = await self._execute_method( + listener_name, method, result + ) else: # If listener does not expect parameters, call without arguments - listener_result = await self._execute_method(method) + listener_result = await self._execute_method(listener_name, method) + + # Remove from scheduled tasks after execution + self._scheduled_tasks.discard(listener_name) # Execute listeners of this listener - await self._execute_listeners(listener, listener_result) + await self._execute_listeners(listener_name, listener_result) except Exception as e: - print(f"[Flow._execute_single_listener] Error in method {listener}: {e}") + print( + f"[Flow._execute_single_listener] Error in method {listener_name}: {e}" + ) import traceback traceback.print_exc()