diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 5f4953de7..34e4955fc 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -149,15 +149,16 @@ class Flow(Generic[T], metaclass=FlowMeta): _router_paths: Dict[str, List[str]] = {} initial_state: Union[Type[T], T, None] = None - def __class_getitem__(cls, item): + def __class_getitem__(cls, item: Type[T]) -> Type["Flow"]: class _FlowGeneric(cls): - _initial_state_T = item + _initial_state_T: Type[T] = item + _FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]" return _FlowGeneric - def __init__(self): + def __init__(self) -> None: self._methods: Dict[str, Callable] = {} - self._state = self._create_initial_state() + self._state: T = self._create_initial_state() self._completed_methods: Set[str] = set() self._pending_and_listeners: Dict[str, Set[str]] = {} self._method_outputs: List[Any] = [] # List to store all method outputs @@ -212,11 +213,11 @@ class Flow(Generic[T], metaclass=FlowMeta): else: return None # Or raise an exception if no methods were executed - async def _execute_start_method(self, start_method: str): + async def _execute_start_method(self, start_method: str) -> None: result = await self._execute_method(self._methods[start_method]) await self._execute_listeners(start_method, result) - async def _execute_method(self, method: Callable, *args, **kwargs): + async def _execute_method(self, method: Callable, *args: Any, **kwargs: Any) -> Any: result = ( await method(*args, **kwargs) if asyncio.iscoroutinefunction(method) @@ -225,7 +226,7 @@ class Flow(Generic[T], metaclass=FlowMeta): self._method_outputs.append(result) # Store the output return result - async def _execute_listeners(self, trigger_method: str, result: Any): + async def _execute_listeners(self, trigger_method: str, result: Any) -> None: listener_tasks = [] if trigger_method in self._routers: @@ -253,7 +254,7 @@ class Flow(Generic[T], metaclass=FlowMeta): # Run all listener tasks concurrently and wait for them to complete await asyncio.gather(*listener_tasks) - async def _execute_single_listener(self, listener: str, result: Any): + async def _execute_single_listener(self, listener: str, result: Any) -> None: try: method = self._methods[listener] sig = inspect.signature(method) @@ -277,7 +278,7 @@ class Flow(Generic[T], metaclass=FlowMeta): traceback.print_exc() - def plot(self, filename: str = "crewai_flow"): + def plot(self, filename: str = "crewai_flow") -> None: self._telemetry.flow_plotting_span( self.__class__.__name__, list(self._methods.keys()) )