diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 0feb67def..ead6322cc 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -6,6 +6,7 @@ import logging import threading from typing import ( Any, + Awaitable, Callable, Dict, Generic, @@ -572,7 +573,7 @@ class Flow(Generic[T], metaclass=FlowMeta): f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" ) - def _get_thread_safe_primitive_type(self, value: Any) -> Optional[Type]: + def _get_thread_safe_primitive_type(self, value: Any) -> Optional[Type[Union[threading.Lock, threading.RLock, threading.Semaphore, threading.Event, threading.Condition, asyncio.Lock, asyncio.Event, asyncio.Condition, asyncio.Semaphore]]]: """Get the type of a thread-safe primitive for recreation. Args: @@ -602,7 +603,7 @@ class Flow(Generic[T], metaclass=FlowMeta): return asyncio.Semaphore return None - def _serialize_dataclass(self, value: Any) -> Any: + def _serialize_dataclass(self, value: Any) -> Union[Dict[str, Any], Any]: """Serialize a dataclass instance. Args: @@ -685,36 +686,6 @@ class Flow(Generic[T], metaclass=FlowMeta): # Handle other types return value - - def _serialize_value(self, value: Any) -> Any: - """Recursively serialize a value, handling nested objects and locks. - - Args: - value: Any Python value to serialize - - Returns: - Serialized version of the value with locks properly handled - """ - if isinstance(value, BaseModel): - return type(value)(**{ - k: self._serialize_value(v) - for k, v in value.model_dump().items() - }) - elif isinstance(value, dict): - return { - k: self._serialize_value(v) - for k, v in value.items() - } - elif isinstance(value, list): - return [self._serialize_value(item) for item in value] - elif isinstance(value, tuple): - return tuple(self._serialize_value(item) for item in value) - elif isinstance(value, set): - return {self._serialize_value(item) for item in value} - elif hasattr(value, '_is_owned') and hasattr(value, 'acquire'): - # Skip thread locks and similar synchronization primitives - return None - return value def _serialize_state(self) -> Union[Dict[str, Any], BaseModel]: """Serialize the current state for event emission. @@ -724,7 +695,7 @@ class Flow(Generic[T], metaclass=FlowMeta): when state hasn't changed. Handles nested objects and locks recursively. Returns: - Serialized state as either a new BaseModel instance or dictionary + Union[Dict[str, Any], BaseModel]: Serialized state as either a new BaseModel instance or dictionary Raises: ValueError: If state has invalid type @@ -749,7 +720,7 @@ class Flow(Generic[T], metaclass=FlowMeta): return serialized except Exception as e: logger.error(f"State serialization failed: {str(e)}") - return {} + return cast(Dict[str, Any], {}) @property def state(self) -> T: @@ -881,7 +852,7 @@ class Flow(Generic[T], metaclass=FlowMeta): else: raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}") - def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any: + def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Union[Any, None]: """Start the flow execution. Args: @@ -984,7 +955,7 @@ class Flow(Generic[T], metaclass=FlowMeta): await self._execute_listeners(start_method_name, result) async def _execute_method( - self, method_name: str, method: Callable, *args: Any, **kwargs: Any + self, method_name: str, method: Union[Callable[..., Any], Callable[..., Awaitable[Any]]], *args: Any, **kwargs: Any ) -> Any: # Serialize state before event emission to avoid pickling issues state_copy = self._serialize_state() @@ -1077,7 +1048,7 @@ class Flow(Generic[T], metaclass=FlowMeta): await asyncio.gather(*tasks) def _find_triggered_methods( - self, trigger_method: str, router_only: bool + self, trigger_method: str, router_only: bool = False ) -> List[str]: """ Finds all methods that should be triggered based on conditions. @@ -1186,7 +1157,7 @@ class Flow(Generic[T], metaclass=FlowMeta): traceback.print_exc() def _log_flow_event( - self, message: str, color: str = "yellow", level: str = "info" + self, message: str, color: Optional[str] = "yellow", level: Optional[str] = "info" ) -> None: """Centralized logging method for flow events. @@ -1211,7 +1182,7 @@ class Flow(Generic[T], metaclass=FlowMeta): elif level == "warning": logger.warning(message) - def plot(self, filename: str = "crewai_flow") -> None: + def plot(self, filename: Optional[str] = "crewai_flow") -> None: self._telemetry.flow_plotting_span( self.__class__.__name__, list(self._methods.keys()) )