From cdfbd5f62b23ed8753a48bd8df1a0ef4dbf68363 Mon Sep 17 00:00:00 2001 From: "Brandon Hancock (bhancock_ai)" <109994880+bhancockio@users.noreply.github.com> Date: Tue, 29 Oct 2024 18:36:53 -0400 Subject: [PATCH] Bugfix/flows with multiple starts plus ands breaking (#1531) * bugfix/flows-with-multiple-starts-plus-ands-breaking * fix user found issue * remove prints --- src/crewai/flow/flow.py | 75 +++++++++++++++++++++++++++-------------- 1 file changed, 49 insertions(+), 26 deletions(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index de9b2eeb2..e7231e13f 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -1,5 +1,3 @@ -# flow.py - import asyncio import inspect from typing import Any, Callable, Dict, Generic, List, Set, Type, TypeVar, Union @@ -120,6 +118,8 @@ class FlowMeta(type): methods = attr_value.__trigger_methods__ condition_type = getattr(attr_value, "__condition_type__", "OR") listeners[attr_name] = (condition_type, methods) + + # TODO: should we add a check for __condition_type__ 'AND'? elif hasattr(attr_value, "__is_router__"): routers[attr_value.__router_for__] = attr_name possible_returns = get_possible_return_constants(attr_value) @@ -159,7 +159,8 @@ class Flow(Generic[T], metaclass=FlowMeta): def __init__(self) -> None: self._methods: Dict[str, Callable] = {} self._state: T = self._create_initial_state() - self._completed_methods: Set[str] = set() + self._executed_methods: Set[str] = set() + self._scheduled_tasks: Set[str] = set() self._pending_and_listeners: Dict[str, Set[str]] = {} self._method_outputs: List[Any] = [] # List to store all method outputs @@ -216,17 +217,24 @@ class Flow(Generic[T], metaclass=FlowMeta): else: return None # Or raise an exception if no methods were executed - async def _execute_start_method(self, start_method: str) -> None: - result = await self._execute_method(self._methods[start_method]) - await self._execute_listeners(start_method, result) + async def _execute_start_method(self, start_method_name: str) -> None: + result = await self._execute_method( + start_method_name, self._methods[start_method_name] + ) + await self._execute_listeners(start_method_name, result) - async def _execute_method(self, method: Callable, *args: Any, **kwargs: Any) -> Any: + async def _execute_method( + self, method_name: str, method: Callable, *args: Any, **kwargs: Any + ) -> Any: result = ( await method(*args, **kwargs) if asyncio.iscoroutinefunction(method) else method(*args, **kwargs) ) self._method_outputs.append(result) # Store the output + + self._executed_methods.add(method_name) + return result async def _execute_listeners(self, trigger_method: str, result: Any) -> None: @@ -234,32 +242,40 @@ class Flow(Generic[T], metaclass=FlowMeta): if trigger_method in self._routers: router_method = self._methods[self._routers[trigger_method]] - path = await self._execute_method(router_method) + path = await self._execute_method( + trigger_method, router_method + ) # TODO: Change or not? # Use the path as the new trigger method trigger_method = path - for listener, (condition_type, methods) in self._listeners.items(): + for listener_name, (condition_type, methods) in self._listeners.items(): if condition_type == "OR": if trigger_method in methods: - listener_tasks.append( - self._execute_single_listener(listener, result) - ) + if ( + listener_name not in self._executed_methods + and listener_name not in self._scheduled_tasks + ): + self._scheduled_tasks.add(listener_name) + listener_tasks.append( + self._execute_single_listener(listener_name, result) + ) elif condition_type == "AND": - if listener not in self._pending_and_listeners: - self._pending_and_listeners[listener] = set() - self._pending_and_listeners[listener].add(trigger_method) - if set(methods) == self._pending_and_listeners[listener]: - listener_tasks.append( - self._execute_single_listener(listener, result) - ) - del self._pending_and_listeners[listener] + if all(method in self._executed_methods for method in methods): + if ( + listener_name not in self._executed_methods + and listener_name not in self._scheduled_tasks + ): + self._scheduled_tasks.add(listener_name) + listener_tasks.append( + self._execute_single_listener(listener_name, result) + ) # Run all listener tasks concurrently and wait for them to complete await asyncio.gather(*listener_tasks) - async def _execute_single_listener(self, listener: str, result: Any) -> None: + async def _execute_single_listener(self, listener_name: str, result: Any) -> None: try: - method = self._methods[listener] + method = self._methods[listener_name] sig = inspect.signature(method) params = list(sig.parameters.values()) @@ -268,15 +284,22 @@ class Flow(Generic[T], metaclass=FlowMeta): if method_params: # If listener expects parameters, pass the result - listener_result = await self._execute_method(method, result) + listener_result = await self._execute_method( + listener_name, method, result + ) else: # If listener does not expect parameters, call without arguments - listener_result = await self._execute_method(method) + listener_result = await self._execute_method(listener_name, method) + + # Remove from scheduled tasks after execution + self._scheduled_tasks.discard(listener_name) # Execute listeners of this listener - await self._execute_listeners(listener, listener_result) + await self._execute_listeners(listener_name, listener_result) except Exception as e: - print(f"[Flow._execute_single_listener] Error in method {listener}: {e}") + print( + f"[Flow._execute_single_listener] Error in method {listener_name}: {e}" + ) import traceback traceback.print_exc()