diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py
index 4e9e43162..e5d14a793 100644
--- a/src/crewai/flow/flow.py
+++ b/src/crewai/flow/flow.py
@@ -570,7 +570,37 @@ class Flow(Generic[T], metaclass=FlowMeta):
             )
 
     def _copy_state(self) -> T:
+        """Create a deep copy of the current state.
+
+        Returns:
+            A deep copy of the current state object
+        """
         return copy.deepcopy(self._state)
+
+    def _serialize_state(self) -> Union[Dict[str, Any], BaseModel]:
+        """Build a fresh snapshot of the current state for event emission.
+
+        BaseModel states are rebuilt from ``model_dump()`` and dict states are
+        shallow-copied; the state object itself is never deep-copied here.
+        A new object is returned on every call, so a listener that keeps or
+        mutates one snapshot cannot affect the state seen by later events.
+
+        Returns:
+            A new BaseModel instance or dictionary mirroring the state. If the
+            state has an unsupported type or cannot be rebuilt, the error is
+            logged and an empty dict is returned; this method does not raise.
+        """
+        state = self._state
+        try:
+            if isinstance(state, BaseModel):
+                return type(state)(**state.model_dump())
+            if isinstance(state, dict):
+                return dict(state)
+            raise ValueError(f"Invalid state type: {type(state)}")
+        except Exception as e:
+            # Best-effort: a failed snapshot must not abort method execution.
+            logger.error("State serialization failed: %s", e)
+            return {}
 
     @property
     def state(self) -> T:
@@ -808,11 +838,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
         self, method_name: str, method: Callable, *args: Any, **kwargs: Any
     ) -> Any:
         # Serialize state before event emission to avoid pickling issues
-        state_copy = (
-            type(self._state)(**self._state.model_dump())
-            if isinstance(self._state, BaseModel)
-            else dict(self._state)
-        )
+        state_copy = self._serialize_state()
 
         dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (kwargs or {})
         self.event_emitter.send(
@@ -837,11 +863,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
         )
 
         # Serialize state after execution
-        state_copy = (
-            type(self._state)(**self._state.model_dump())
-            if isinstance(self._state, BaseModel)
-            else dict(self._state)
-        )
+        state_copy = self._serialize_state()
 
         self.event_emitter.send(
             self,
diff --git a/src/crewai/flow/flow_events.py b/src/crewai/flow/flow_events.py
index 27746e9c8..09053d341 100644
--- a/src/crewai/flow/flow_events.py
+++ b/src/crewai/flow/flow_events.py
@@ -15,35 +15,55 @@ class Event:
         self.timestamp = datetime.now()
 
 
+@dataclass
+class BaseStateEvent(Event):
+    """Base class for events that carry a snapshot of the flow state.
+
+    On construction the state is type-checked and, when it is a BaseModel,
+    replaced by a new instance rebuilt from ``model_dump()``. Subclasses
+    inherit this via ``__post_init__`` and must not repeat the processing.
+
+    Raises:
+        ValueError: If state is neither a dict nor a BaseModel
+    """
+    state: Union[Dict[str, Any], BaseModel]
+
+    def __post_init__(self):
+        super().__post_init__()
+        self._process_state()
+
+    def _process_state(self):
+        """Validate the state type and detach BaseModel states.
+
+        A BaseModel state is replaced by a new instance of the same class
+        built from its ``model_dump()``; dict states are kept as passed in.
+
+        Raises:
+            ValueError: If state is neither a dict nor a BaseModel
+        """
+        if not isinstance(self.state, (dict, BaseModel)):
+            raise ValueError(f"Invalid state type: {type(self.state)}")
+        if isinstance(self.state, BaseModel):
+            self.state = type(self.state)(**self.state.model_dump())
+
+
 @dataclass
 class FlowStartedEvent(Event):
     inputs: Optional[Dict[str, Any]] = None
 
 
 @dataclass
-class MethodExecutionStartedEvent(Event):
+class MethodExecutionStartedEvent(BaseStateEvent):
     method_name: str
     state: Union[Dict[str, Any], BaseModel]
     params: Optional[Dict[str, Any]] = None
 
-    def __post_init__(self):
-        super().__post_init__()
-        # Create a new instance of BaseModel state to avoid pickling issues
-        if isinstance(self.state, BaseModel):
-            self.state = type(self.state)(**self.state.model_dump())
-
 
 @dataclass
-class MethodExecutionFinishedEvent(Event):
+class MethodExecutionFinishedEvent(BaseStateEvent):
     method_name: str
     state: Union[Dict[str, Any], BaseModel]
     result: Any = None
 
-    def __post_init__(self):
-        super().__post_init__()
-        # Create a new instance of BaseModel state to avoid pickling issues
-        if isinstance(self.state, BaseModel):
-            self.state = type(self.state)(**self.state.model_dump())
-
 
 @dataclass
diff --git a/tests/flow_test.py b/tests/flow_test.py
index fa324fab7..e956eadd0 100644
--- a/tests/flow_test.py
+++ b/tests/flow_test.py
@@ -377,6 +377,72 @@ def test_flow_with_thread_lock():
     assert flow.counter == 2
 
 
+def test_flow_with_nested_locks():
+    """Test that Flow properly handles nested thread locks."""
+    import threading
+
+    class NestedLockFlow(Flow):
+        def __init__(self):
+            super().__init__()
+            self.outer_lock = threading.RLock()
+            self.inner_lock = threading.RLock()
+            self.counter = 0
+
+        @start()
+        async def step_1(self):
+            with self.outer_lock:
+                with self.inner_lock:
+                    self.counter += 1
+                    return "step 1"
+
+        @listen(step_1)
+        async def step_2(self, result):
+            with self.outer_lock:
+                with self.inner_lock:
+                    self.counter += 1
+                    return result + " -> step 2"
+
+    flow = NestedLockFlow()
+    result = flow.kickoff()
+
+    assert result == "step 1 -> step 2"
+    assert flow.counter == 2
+
+
+@pytest.mark.asyncio
+async def test_flow_with_async_locks():
+    """Test that Flow properly handles locks in async context."""
+    import asyncio
+    import threading
+
+    class AsyncLockFlow(Flow):
+        def __init__(self):
+            super().__init__()
+            self.lock = threading.RLock()
+            self.async_lock = asyncio.Lock()
+            self.counter = 0
+
+        @start()
+        async def step_1(self):
+            async with self.async_lock:
+                with self.lock:
+                    self.counter += 1
+                    return "step 1"
+
+        @listen(step_1)
+        async def step_2(self, result):
+            async with self.async_lock:
+                with self.lock:
+                    self.counter += 1
+                    return result + " -> step 2"
+
+    flow = AsyncLockFlow()
+    result = await flow.kickoff_async()
+
+    assert result == "step 1 -> step 2"
+    assert flow.counter == 2
+
+
 def test_router_with_multiple_conditions():
     """Test a router that triggers when any of multiple steps complete (OR condition),
     and another router that triggers only after all specified steps complete (AND condition).