diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index e5d14a793..0feb67def 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -1,7 +1,9 @@ import asyncio import copy +import dataclasses import inspect import logging +import threading from typing import ( Any, Callable, @@ -562,27 +564,164 @@ class Flow(Generic[T], metaclass=FlowMeta): k: v for k, v in model.__dict__.items() if not k.startswith("_") } - # Create new instance of the same class + # Create new instance of the same class, handling thread locks model_class = type(model) - return cast(T, model_class(**state_dict)) + serialized_dict = self._serialize_value(state_dict) + return cast(T, model_class(**serialized_dict)) raise TypeError( f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" ) + def _get_thread_safe_primitive_type(self, value: Any) -> Optional[Type]: + """Get the type of a thread-safe primitive for recreation. + + Args: + value: Any Python value to check + + Returns: + The type of the thread-safe primitive, or None if not a primitive + """ + if hasattr(value, '_is_owned') and hasattr(value, 'acquire'): + if isinstance(value, threading.RLock): + return threading.RLock + elif isinstance(value, threading.Lock): + return threading.Lock + elif isinstance(value, threading.Semaphore): + return threading.Semaphore + elif isinstance(value, threading.Event): + return threading.Event + elif isinstance(value, threading.Condition): + return threading.Condition + elif isinstance(value, asyncio.Lock): + return asyncio.Lock + elif isinstance(value, asyncio.Event): + return asyncio.Event + elif isinstance(value, asyncio.Condition): + return asyncio.Condition + elif isinstance(value, asyncio.Semaphore): + return asyncio.Semaphore + return None + + def _serialize_dataclass(self, value: Any) -> Any: + """Serialize a dataclass instance. + + Args: + value: A dataclass instance + + Returns: + A new instance of the dataclass with thread-safe primitives recreated + """ + if not hasattr(value, '__class__'): + return value + + if hasattr(value, '__pydantic_validate__'): + return value.__pydantic_validate__() + + # Get field values, handling thread-safe primitives + field_values = {} + for field in dataclasses.fields(value): + field_value = getattr(value, field.name) + primitive_type = self._get_thread_safe_primitive_type(field_value) + if primitive_type is not None: + field_values[field.name] = primitive_type() + else: + field_values[field.name] = self._serialize_value(field_value) + + # Create new instance + return value.__class__(**field_values) + def _copy_state(self) -> T: """Create a deep copy of the current state. Returns: A deep copy of the current state object """ - return copy.deepcopy(self._state) + return self._serialize_value(self._state) + + def _serialize_value(self, value: Any) -> Any: + """Recursively serialize a value, handling nested objects and locks. + Args: + value: Any Python value to serialize + + Returns: + Serialized version of the value with thread-safe primitives handled + """ + # Handle None + if value is None: + return None + + # Handle thread-safe primitives + primitive_type = self._get_thread_safe_primitive_type(value) + if primitive_type is not None: + return None + + # Handle Pydantic models + if isinstance(value, BaseModel): + return type(value)(**{ + k: self._serialize_value(v) + for k, v in value.model_dump().items() + }) + + # Handle dataclasses + if dataclasses.is_dataclass(value): + return self._serialize_dataclass(value) + + # Handle dictionaries + if isinstance(value, dict): + return { + k: self._serialize_value(v) + for k, v in value.items() + } + + # Handle lists, tuples, and sets + if isinstance(value, (list, tuple, set)): + serialized = [self._serialize_value(item) for item in value] + return ( + serialized if isinstance(value, list) + else tuple(serialized) if isinstance(value, tuple) + else set(serialized) + ) + + # Handle other types + return value + + def _serialize_value(self, value: Any) -> Any: + """Recursively serialize a value, handling nested objects and locks. + + Args: + value: Any Python value to serialize + + Returns: + Serialized version of the value with locks properly handled + """ + if isinstance(value, BaseModel): + return type(value)(**{ + k: self._serialize_value(v) + for k, v in value.model_dump().items() + }) + elif isinstance(value, dict): + return { + k: self._serialize_value(v) + for k, v in value.items() + } + elif isinstance(value, list): + return [self._serialize_value(item) for item in value] + elif isinstance(value, tuple): + return tuple(self._serialize_value(item) for item in value) + elif isinstance(value, set): + return {self._serialize_value(item) for item in value} + elif hasattr(value, '_is_owned') and hasattr(value, 'acquire'): + # Skip thread locks and similar synchronization primitives + return None + return value + def _serialize_state(self) -> Union[Dict[str, Any], BaseModel]: """Serialize the current state for event emission. This method handles the serialization of both BaseModel and dictionary states, ensuring thread-safe copying of state data. Uses caching to improve performance - when state hasn't changed. + when state hasn't changed. Handles nested objects and locks recursively. Returns: Serialized state as either a new BaseModel instance or dictionary @@ -603,11 +742,7 @@ class Flow(Generic[T], metaclass=FlowMeta): if current_hash == self._last_state_hash: return self._last_serialized_state - serialized = ( - type(self._state)(**self._state.model_dump()) - if isinstance(self._state, BaseModel) - else dict(self._state) - ) + serialized = self._serialize_value(self._state) self._last_state_hash = current_hash self._last_serialized_state = serialized diff --git a/tests/flow_test.py b/tests/flow_test.py index e956eadd0..b1ef6ba4b 100644 --- a/tests/flow_test.py +++ b/tests/flow_test.py @@ -4,7 +4,8 @@ import asyncio from datetime import datetime import pytest -from pydantic import BaseModel +from uuid import uuid4 +from pydantic import BaseModel, Field from crewai.flow.flow import Flow, and_, listen, or_, router, start from crewai.flow.flow_events import ( @@ -377,6 +378,80 @@ def test_flow_with_thread_lock(): assert flow.counter == 2 +def test_flow_with_nested_objects_and_locks(): + """Test that Flow properly handles nested objects containing locks.""" + import threading + from dataclasses import dataclass + from typing import Dict, List, Optional + + @dataclass + class NestedState: + value: str + lock: threading.RLock = None + + def __post_init__(self): + if self.lock is None: + self.lock = threading.RLock() + + def __pydantic_validate__(self): + return {"value": self.value, "lock": threading.RLock()} + + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): + from pydantic_core.core_schema import ( + with_info_plain_validator_function, + str_schema, + ) + def validate(value, _): + if isinstance(value, cls): + return value + if isinstance(value, dict): + return cls(value["value"]) + raise ValueError(f"Invalid value type for {cls.__name__}") + return with_info_plain_validator_function(validate) + + class ComplexState(BaseModel): + id: str = Field(default_factory=lambda: str(uuid4())) + name: str + nested: NestedState + items: List[NestedState] + mapping: Dict[str, NestedState] + optional: Optional[NestedState] = None + + class ComplexStateFlow(Flow[ComplexState]): + def __init__(self): + self.initial_state = ComplexState( + name="test", + nested=NestedState("nested", threading.RLock()), + items=[ + NestedState("item1", threading.RLock()), + NestedState("item2", threading.RLock()) + ], + mapping={ + "key1": NestedState("map1", threading.RLock()), + "key2": NestedState("map2", threading.RLock()) + }, + optional=NestedState("optional", threading.RLock()) + ) + super().__init__() + + @start() + async def step_1(self): + with self.state.nested.lock: + return "step 1" + + @listen(step_1) + async def step_2(self, result): + with self.state.items[0].lock: + with self.state.mapping["key1"].lock: + with self.state.optional.lock: + return result + " -> step 2" + + flow = ComplexStateFlow() + result = flow.kickoff() + + assert result == "step 1 -> step 2" + def test_flow_with_nested_locks(): """Test that Flow properly handles nested thread locks.""" import threading