diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py
index f1242a2bf..8f85ac76a 100644
--- a/src/crewai/flow/flow.py
+++ b/src/crewai/flow/flow.py
@@ -1,7 +1,9 @@
 import asyncio
 import copy
+import dataclasses
 import inspect
 import logging
+import threading
 from typing import (
     Any,
     Callable,
@@ -569,8 +571,138 @@ class Flow(Generic[T], metaclass=FlowMeta):
             f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
         )
 
+    def _get_thread_safe_primitive_type(
+        self, value: Any
+    ) -> Optional[Callable[[], Any]]:
+        """Return the factory used to recreate a thread-safe primitive.
+
+        Args:
+            value: Any Python value to check.
+
+        Returns:
+            The zero-argument ``threading``/``asyncio`` factory matching the
+            primitive (e.g. ``threading.RLock``), or ``None`` when ``value``
+            is not one of the supported primitives.
+        """
+        # threading.Lock and threading.RLock may be factory functions rather
+        # than classes, so their concrete types are taken from an instance.
+        # The other primitives are classes and are used directly, which
+        # avoids constructing asyncio objects just to inspect their type.
+        # Bounded semaphores subclass the plain ones, so they come first.
+        primitive_factories = (
+            (type(threading.RLock()), threading.RLock),
+            (type(threading.Lock()), threading.Lock),
+            (threading.BoundedSemaphore, threading.BoundedSemaphore),
+            (threading.Semaphore, threading.Semaphore),
+            (threading.Event, threading.Event),
+            (threading.Condition, threading.Condition),
+            (asyncio.Lock, asyncio.Lock),
+            (asyncio.Event, asyncio.Event),
+            (asyncio.Condition, asyncio.Condition),
+            (asyncio.BoundedSemaphore, asyncio.BoundedSemaphore),
+            (asyncio.Semaphore, asyncio.Semaphore),
+        )
+        for primitive_type, factory in primitive_factories:
+            if isinstance(value, primitive_type):
+                return factory
+        return None
+
+    def _serialize_dataclass(self, value: Any) -> Any:
+        """Rebuild a dataclass instance, recreating thread-safe primitives.
+
+        Args:
+            value: A dataclass instance.
+
+        Returns:
+            A new instance of the same dataclass whose field values were
+            copied with ``_serialize_value``.
+        """
+        init_values: Dict[str, Any] = {}
+        deferred_values: Dict[str, Any] = {}
+        for field in dataclasses.fields(value):
+            copied = self._serialize_value(getattr(value, field.name))
+            # Fields declared with init=False are rejected by the generated
+            # __init__, so they are assigned after construction instead.
+            if field.init:
+                init_values[field.name] = copied
+            else:
+                deferred_values[field.name] = copied
+
+        instance = value.__class__(**init_values)
+        for name, copied in deferred_values.items():
+            # object.__setattr__ also works on frozen dataclasses.
+            object.__setattr__(instance, name, copied)
+        return instance
+
+    def _serialize_value(self, value: Any) -> Any:
+        """Recursively copy a value, recreating thread-safe primitives.
+
+        Args:
+            value: Any Python value to copy.
+
+        Returns:
+            A copy of ``value`` in which supported thread/async primitives
+            are replaced by newly created instances.
+        """
+        if value is None:
+            return None
+
+        # Locks cannot be deep-copied, so a fresh primitive is created.
+        primitive_factory = self._get_thread_safe_primitive_type(value)
+        if primitive_factory is not None:
+            return primitive_factory()
+
+        if isinstance(value, BaseModel):
+            # NOTE(review): exclude_none=True omits fields that are None, so
+            # the rebuilt model uses their defaults instead - confirm this is
+            # intended for fields whose default is not None.
+            model_data = value.model_dump(exclude_none=True)
+            instance = type(value)(**model_data)
+
+            # Excluded fields are absent from the dump; primitives stored in
+            # them are recreated on the new instance.
+            for field_name, field in type(value).model_fields.items():
+                if not field.exclude:
+                    continue
+                excluded_factory = self._get_thread_safe_primitive_type(
+                    getattr(value, field_name, None)
+                )
+                if excluded_factory is not None:
+                    setattr(instance, field_name, excluded_factory())
+            return instance
+
+        # is_dataclass() is also true for dataclass classes themselves, which
+        # fall through to the generic handling below instead.
+        if dataclasses.is_dataclass(value) and not isinstance(value, type):
+            return self._serialize_dataclass(value)
+
+        if isinstance(value, dict):
+            return {k: self._serialize_value(v) for k, v in value.items()}
+
+        if isinstance(value, list):
+            return [self._serialize_value(item) for item in value]
+
+        if isinstance(value, tuple):
+            items = [self._serialize_value(item) for item in value]
+            # Named tuples take their items as positional arguments.
+            if hasattr(value, "_fields"):
+                return type(value)(*items)
+            return tuple(items)
+
+        if isinstance(value, set):
+            return {self._serialize_value(item) for item in value}
+
+        # Anything else is deep-copied when possible. Objects that cannot be
+        # copied (for example ones holding a lock internally) are returned
+        # as-is so that one such value does not fail the whole state copy.
+        try:
+            return copy.deepcopy(value)
+        except Exception:
+            return value
+
     def _copy_state(self) -> T:
-        return copy.deepcopy(self._state)
+        """Create a deep copy of the current state."""
+        return self._serialize_value(self._state)
 
     @property
     def state(self) -> T:
diff --git a/tests/test_flow_thread_locks.py b/tests/test_flow_thread_locks.py
new file mode 100644
index 000000000..8c252cf94
--- /dev/null
+++ b/tests/test_flow_thread_locks.py
@@ -0,0 +1,178 @@
+"""Tests for Flow with thread locks."""
+import asyncio
+import threading
+from typing import Optional
+from uuid import uuid4
+
+import pytest
+from pydantic import BaseModel, Field, field_validator
+
+from crewai.flow.flow import Flow, start, listen
+
+
+class ThreadSafeState(BaseModel):
+    """Test state model with thread locks."""
+    model_config = {
+        "arbitrary_types_allowed": True,
+        "exclude": {"lock"}
+    }
+
+    id: str = Field(default_factory=lambda: str(uuid4()))
+    lock: Optional[threading.RLock] = Field(default=None, exclude=True)
+    value: str = ""
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        if self.lock is None:
+            self.lock = threading.RLock()
+
+
+class LockFlow(Flow[ThreadSafeState]):
+    """Test flow with thread locks."""
+    initial_state = ThreadSafeState
+
+    @start()
+    async def step_1(self):
+        with self.state.lock:
+            self.state.value = "step 1"
+        return "step 1"
+
+    @listen(step_1)
+    async def step_2(self, result):
+        with self.state.lock:
+            self.state.value += " -> step 2"
+        return result + " -> step 2"
+
+
+def test_flow_with_thread_locks():
+    """Test Flow with thread locks in state."""
+    flow = LockFlow()
+    result = asyncio.run(flow.kickoff_async())
+    assert result == "step 1 -> step 2"
+    assert flow.state.value == "step 1 -> step 2"
+
+
+def test_kickoff_async_with_lock_inputs():
+    """Test kickoff_async with thread lock inputs."""
+    flow = LockFlow()
+    inputs = {
+        "lock": threading.RLock(),
+        "value": "test"
+    }
+    result = asyncio.run(flow.kickoff_async(inputs=inputs))
+    assert result == "step 1 -> step 2"
+    assert flow.state.value == "step 1 -> step 2"
+
+
+class ComplexState(BaseModel):
+    """Test state model with nested thread locks."""
+    model_config = {
+        "arbitrary_types_allowed": True,
+        "exclude": {"outer_lock"}
+    }
+
+    id: str = Field(default_factory=lambda: str(uuid4()))
+    outer_lock: Optional[threading.RLock] = Field(default=None, exclude=True)
+    inner: Optional[ThreadSafeState] = Field(default_factory=ThreadSafeState)
+    value: str = ""
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        if self.outer_lock is None:
+            self.outer_lock = threading.RLock()
+
+
+class NestedLockFlow(Flow[ComplexState]):
+    """Test flow with nested thread locks."""
+    initial_state = ComplexState
+
+    @start()
+    async def step_1(self):
+        with self.state.outer_lock:
+            with self.state.inner.lock:
+                self.state.value = "outer"
+                self.state.inner.value = "inner"
+        return "step 1"
+
+    @listen(step_1)
+    async def step_2(self, result):
+        with self.state.outer_lock:
+            with self.state.inner.lock:
+                self.state.value += " -> outer 2"
+                self.state.inner.value += " -> inner 2"
+        return result + " -> step 2"
+
+
+def test_flow_with_nested_locks():
+    """Test Flow with nested thread locks in state."""
+    flow = NestedLockFlow()
+    result = asyncio.run(flow.kickoff_async())
+    assert result == "step 1 -> step 2"
+    assert flow.state.value == "outer -> outer 2"
+    assert flow.state.inner.value == "inner -> inner 2"
+
+
+class AsyncLockState(BaseModel):
+    """Test state model with async locks."""
+    model_config = {
+        "arbitrary_types_allowed": True,
+        "exclude": {"lock", "event"}
+    }
+
+    id: str = Field(default_factory=lambda: str(uuid4()))
+    lock: Optional[asyncio.Lock] = Field(default=None, exclude=True)
+    event: Optional[asyncio.Event] = Field(default=None, exclude=True)
+    value: str = ""
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        if self.lock is None:
+            self.lock = asyncio.Lock()
+        if self.event is None:
+            self.event = asyncio.Event()
+
+
+class AsyncLockFlow(Flow[AsyncLockState]):
+    """Test flow with async locks."""
+    initial_state = AsyncLockState
+
+    @start()
+    async def step_1(self):
+        async with self.state.lock:
+            self.state.value = "step 1"
+            self.state.event.set()
+        return "step 1"
+
+    @listen(step_1)
+    async def step_2(self, result):
+        async with self.state.lock:
+            await self.state.event.wait()
+            self.state.value += " -> step 2"
+        return result + " -> step 2"
+
+
+def test_flow_with_async_locks():
+    """Test Flow with async locks in state."""
+    flow = AsyncLockFlow()
+    result = asyncio.run(flow.kickoff_async())
+    assert result == "step 1 -> step 2"
+    assert flow.state.value == "step 1 -> step 2"
+
+
+def test_primitive_detection_covers_all_supported_types():
+    """Test that every supported primitive is recognised, not only RLock."""
+    flow = LockFlow()
+    factories = [
+        threading.Lock,
+        threading.RLock,
+        threading.Semaphore,
+        threading.Event,
+        threading.Condition,
+        asyncio.Lock,
+        asyncio.Event,
+        asyncio.Condition,
+        asyncio.Semaphore,
+    ]
+    for factory in factories:
+        assert flow._get_thread_safe_primitive_type(factory()) is factory
+    assert flow._get_thread_safe_primitive_type("not a lock") is None