diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 5d2df2988..4eed89cc2 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -6,6 +6,7 @@ import logging import threading from typing import ( Any, + Awaitable, Callable, Dict, Generic, @@ -582,7 +583,7 @@ class Flow(Generic[T], metaclass=FlowMeta): f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" ) - def _get_thread_safe_primitive_type(self, value: Any) -> Optional[Type]: + def _get_thread_safe_primitive_type(self, value: Any) -> Optional[Type[Union[threading.Lock, threading.RLock, threading.Semaphore, threading.Event, threading.Condition, asyncio.Lock, asyncio.Event, asyncio.Condition, asyncio.Semaphore]]]: """Get the type of a thread-safe primitive for recreation. Args: @@ -612,7 +613,7 @@ class Flow(Generic[T], metaclass=FlowMeta): return asyncio.Semaphore return None - def _serialize_dataclass(self, value: Any) -> Any: + def _serialize_dataclass(self, value: Any) -> Union[Dict[str, Any], Any]: """Serialize a dataclass instance. Args: @@ -695,36 +696,6 @@ class Flow(Generic[T], metaclass=FlowMeta): # Handle other types return value - - def _serialize_value(self, value: Any) -> Any: - """Recursively serialize a value, handling nested objects and locks. - - Args: - value: Any Python value to serialize - - Returns: - Serialized version of the value with locks properly handled - """ - if isinstance(value, BaseModel): - return type(value)(**{ - k: self._serialize_value(v) - for k, v in value.model_dump().items() - }) - elif isinstance(value, dict): - return { - k: self._serialize_value(v) - for k, v in value.items() - } - elif isinstance(value, list): - return [self._serialize_value(item) for item in value] - elif isinstance(value, tuple): - return tuple(self._serialize_value(item) for item in value) - elif isinstance(value, set): - return {self._serialize_value(item) for item in value} - elif hasattr(value, '_is_owned') and hasattr(value, 'acquire'): - # Skip thread locks and similar synchronization primitives - return None - return value def _serialize_state(self) -> Union[Dict[str, Any], BaseModel]: """Serialize the current state for event emission. @@ -734,7 +705,7 @@ class Flow(Generic[T], metaclass=FlowMeta): when state hasn't changed. Handles nested objects and locks recursively. Returns: - Serialized state as either a new BaseModel instance or dictionary + Union[Dict[str, Any], BaseModel]: Serialized state as either a new BaseModel instance or dictionary Raises: ValueError: If state has invalid type @@ -759,7 +730,7 @@ class Flow(Generic[T], metaclass=FlowMeta): return serialized except Exception as e: logger.error(f"State serialization failed: {str(e)}") - return {} + return cast(Dict[str, Any], {}) @property def state(self) -> T: @@ -891,7 +862,7 @@ class Flow(Generic[T], metaclass=FlowMeta): else: raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}") - def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any: + def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Union[Any, None]: """Start the flow execution. Args: @@ -995,8 +966,22 @@ class Flow(Generic[T], metaclass=FlowMeta): @trace_flow_step async def _execute_method( - self, method_name: str, method: Callable, *args: Any, **kwargs: Any + self, method_name: str, method: Union[Callable[..., Any], Callable[..., Awaitable[Any]]], *args: Any, **kwargs: Any ) -> Any: + """Execute a flow method with proper event handling and state management. + + Args: + method_name: Name of the method to execute + method: The method to execute + *args: Positional arguments for the method + **kwargs: Keyword arguments for the method + + Returns: + The result of the method execution + + Raises: + Any exception that occurs during method execution + """ try: # Serialize state before event emission to avoid pickling issues state_copy = self._serialize_state() @@ -1051,87 +1036,6 @@ class Flow(Generic[T], metaclass=FlowMeta): ) raise e - dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (kwargs or {}) - crewai_event_bus.emit( - self, - MethodExecutionStartedEvent( - type="method_execution_started", - method_name=method_name, - flow_name=self.__class__.__name__, - params=dumped_params, - state=state_copy, - ), - ) - -<<<<<<< HEAD - result = ( - await method(*args, **kwargs) - if asyncio.iscoroutinefunction(method) - else method(*args, **kwargs) - ) -||||||| parent of ed877467 (refactor: Improve Flow state serialization) - # Serialize state after execution - state_copy = ( - type(self._state)(**self._state.model_dump()) - if isinstance(self._state, BaseModel) - else dict(self._state) - ) - - self.event_emitter.send( - self, - event=MethodExecutionFinishedEvent( - type="method_execution_finished", - method_name=method_name, - flow_name=self.__class__.__name__, - state=state_copy, - result=result, - ), - ) -======= - # Serialize state after execution - state_copy = self._serialize_state() - - self.event_emitter.send( - self, - event=MethodExecutionFinishedEvent( - type="method_execution_finished", - method_name=method_name, - flow_name=self.__class__.__name__, - state=state_copy, - result=result, - ), - ) ->>>>>>> ed877467 (refactor: Improve Flow state serialization) - - self._method_outputs.append(result) - self._method_execution_counts[method_name] = ( - self._method_execution_counts.get(method_name, 0) + 1 - ) - - crewai_event_bus.emit( - self, - MethodExecutionFinishedEvent( - type="method_execution_finished", - method_name=method_name, - flow_name=self.__class__.__name__, - state=state_copy, - result=result, - ), - ) - - return result - except Exception as e: - crewai_event_bus.emit( - self, - MethodExecutionFailedEvent( - type="method_execution_failed", - method_name=method_name, - flow_name=self.__class__.__name__, - error=e, - ), - ) - raise e - async def _execute_listeners(self, trigger_method: str, result: Any) -> None: """ Executes all listeners and routers triggered by a method completion. @@ -1182,7 +1086,7 @@ class Flow(Generic[T], metaclass=FlowMeta): await asyncio.gather(*tasks) def _find_triggered_methods( - self, trigger_method: str, router_only: bool + self, trigger_method: str, router_only: bool = False ) -> List[str]: """ Finds all methods that should be triggered based on conditions. @@ -1292,7 +1196,7 @@ class Flow(Generic[T], metaclass=FlowMeta): traceback.print_exc() def _log_flow_event( - self, message: str, color: str = "yellow", level: str = "info" + self, message: str, color: Optional[str] = "yellow", level: Optional[str] = "info" ) -> None: """Centralized logging method for flow events. @@ -1317,7 +1221,13 @@ class Flow(Generic[T], metaclass=FlowMeta): elif level == "warning": logger.warning(message) - def plot(self, filename: str = "crewai_flow") -> None: +<<<<<<< HEAD + def plot(self, filename: Optional[str] = "crewai_flow") -> None: + """Plot the flow graph visualization. + + Args: + filename: Optional name for the output file (default: "crewai_flow") + """ crewai_event_bus.emit( self, FlowPlotEvent( @@ -1325,4 +1235,7 @@ class Flow(Generic[T], metaclass=FlowMeta): flow_name=self.__class__.__name__, ), ) + self._telemetry.flow_plotting_span( + self.__class__.__name__, list(self._methods.keys()) + ) plot_flow(self, filename)