diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 8f85ac76a..4579cad6e 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -1,8 +1,11 @@ import asyncio import dataclasses +import functools import inspect import logging import threading +import time +from contextlib import contextmanager from typing import ( Any, Callable, @@ -10,12 +13,29 @@ from typing import ( Generic, List, Optional, + Protocol, Set, Type, TypeVar, Union, cast, ) + +from typing_extensions import Protocol + +logger = logging.getLogger(__name__) + + +class SerializationError(Exception): + """Error during state serialization.""" + pass + + +class LockProtocol(Protocol): + """Protocol for thread-safe primitives.""" + def acquire(self) -> bool: ... + def release(self) -> None: ... + def _is_owned(self) -> bool: ... from uuid import uuid4 from blinker import Signal @@ -438,6 +458,23 @@ class Flow(Generic[T], metaclass=FlowMeta): initial_state: Union[Type[T], T, None] = None event_emitter = Signal("event_emitter") + @contextmanager + def _performance_monitor(self, operation: str): + """Monitor performance of an operation. + + Args: + operation: Name of the operation being monitored + + Yields: + None + """ + start = time.perf_counter() + try: + yield + finally: + duration = time.perf_counter() - start + logger.debug(f"{operation} took {duration:.4f} seconds") + def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]: class _FlowGeneric(cls): # type: ignore _initial_state_T = item # type: ignore @@ -570,7 +607,20 @@ class Flow(Generic[T], metaclass=FlowMeta): f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" ) - def _get_thread_safe_primitive_type(self, value: Any) -> Optional[Type[Union[threading.Lock, threading.RLock, threading.Semaphore, threading.Event, threading.Condition, asyncio.Lock, asyncio.Event, asyncio.Condition, asyncio.Semaphore]]]: + # Cache thread-safe primitive types + THREAD_SAFE_TYPES = { + type(threading.RLock()): threading.RLock, + type(threading.Lock()): threading.Lock, + type(threading.Semaphore()): threading.Semaphore, + type(threading.Event()): threading.Event, + type(threading.Condition()): threading.Condition, + type(asyncio.Lock()): asyncio.Lock, + type(asyncio.Event()): asyncio.Event, + type(asyncio.Condition()): asyncio.Condition, + type(asyncio.Semaphore()): asyncio.Semaphore, + } + + def _get_thread_safe_primitive_type(self, value: Any) -> Optional[Type[LockProtocol]]: """Get the type of a thread-safe primitive for recreation. Args: @@ -579,37 +629,21 @@ class Flow(Generic[T], metaclass=FlowMeta): Returns: The type of the thread-safe primitive, or None if not a primitive """ - if hasattr(value, '_is_owned') and hasattr(value, 'acquire'): - # Get the actual types since some are factory functions - rlock_type = type(threading.RLock()) - lock_type = type(threading.Lock()) - semaphore_type = type(threading.Semaphore()) - event_type = type(threading.Event()) - condition_type = type(threading.Condition()) - async_lock_type = type(asyncio.Lock()) - async_event_type = type(asyncio.Event()) - async_condition_type = type(asyncio.Condition()) - async_semaphore_type = type(asyncio.Semaphore()) + return (self.THREAD_SAFE_TYPES.get(type(value)) + if hasattr(value, '_is_owned') and hasattr(value, 'acquire') + else None) - if isinstance(value, rlock_type): - return threading.RLock - elif isinstance(value, lock_type): - return threading.Lock - elif isinstance(value, semaphore_type): - return threading.Semaphore - elif isinstance(value, event_type): - return threading.Event - elif isinstance(value, condition_type): - return threading.Condition - elif isinstance(value, async_lock_type): - return asyncio.Lock - elif isinstance(value, async_event_type): - return asyncio.Event - elif isinstance(value, async_condition_type): - return asyncio.Condition - elif isinstance(value, async_semaphore_type): - return asyncio.Semaphore - return None + @functools.lru_cache(maxsize=128) + def _get_dataclass_fields(self, cls): + """Get cached dataclass fields. + + Args: + cls: Dataclass type + + Returns: + Dict mapping field names to Field objects + """ + return {field.name: field for field in dataclasses.fields(cls)} def _serialize_dataclass(self, value: Any) -> Union[Dict[str, Any], Any]: """Serialize a dataclass instance. @@ -620,24 +654,28 @@ class Flow(Generic[T], metaclass=FlowMeta): Returns: A new instance of the dataclass with thread-safe primitives recreated """ - if not hasattr(value, '__class__'): - return value - - if hasattr(value, '__pydantic_validate__'): - return value.__pydantic_validate__() - - # Get field values, handling thread-safe primitives - field_values = {} - for field in dataclasses.fields(value): - field_value = getattr(value, field.name) - primitive_type = self._get_thread_safe_primitive_type(field_value) - if primitive_type is not None: - field_values[field.name] = primitive_type() - else: - field_values[field.name] = self._serialize_value(field_value) + try: + if not hasattr(value, '__class__'): + return value - # Create new instance - return value.__class__(**field_values) + if hasattr(value, '__pydantic_validate__'): + return value.__pydantic_validate__() + + # Get field values, handling thread-safe primitives + field_values = {} + for field_name, field in self._get_dataclass_fields(value.__class__).items(): + field_value = getattr(value, field_name) + primitive_type = self._get_thread_safe_primitive_type(field_value) + if primitive_type is not None: + field_values[field_name] = primitive_type() + else: + field_values[field_name] = self._serialize_value(field_value) + + # Create new instance + return value.__class__(**field_values) + except Exception as e: + logger.error(f"Dataclass serialization error for {type(value)}: {str(e)}") + raise SerializationError(f"Failed to serialize dataclass {type(value)}") from e def _serialize_value(self, value: Any) -> Any: """Recursively serialize a value, handling thread locks. @@ -647,34 +685,42 @@ class Flow(Generic[T], metaclass=FlowMeta): Returns: Serialized version of the value with thread-safe primitives handled + + Raises: + SerializationError: If serialization fails """ - # Handle None - if value is None: - return None - - # Handle thread-safe primitives - primitive_type = self._get_thread_safe_primitive_type(value) - if primitive_type is not None: - return primitive_type() - - # Handle Pydantic models - if isinstance(value, BaseModel): - model_class = type(value) - model_data = value.model_dump(exclude_none=True) - - # Create new instance - instance = model_class(**model_data) - - # Copy excluded fields that are thread-safe primitives - for field_name, field in value.__class__.model_fields.items(): - if field.exclude: - field_value = getattr(value, field_name, None) - if field_value is not None: - primitive_type = self._get_thread_safe_primitive_type(field_value) - if primitive_type is not None: - setattr(instance, field_name, primitive_type()) - - return instance + with self._performance_monitor(f"serialize_{type(value).__name__}"): + try: + # Handle None + if value is None: + return None + + # Handle thread-safe primitives + primitive_type = self._get_thread_safe_primitive_type(value) + if primitive_type is not None: + return primitive_type() + + # Handle Pydantic models + if isinstance(value, BaseModel): + model_class = type(value) + model_data = value.model_dump(exclude_none=True) + + # Create new instance + instance = model_class(**model_data) + + # Copy excluded fields that are thread-safe primitives + for field_name, field in value.__class__.model_fields.items(): + if field.exclude: + field_value = getattr(value, field_name, None) + if field_value is not None: + primitive_type = self._get_thread_safe_primitive_type(field_value) + if primitive_type is not None: + setattr(instance, field_name, primitive_type()) + + return instance + except Exception as e: + logger.error(f"Serialization error for {type(value)}: {str(e)}") + raise SerializationError(f"Failed to serialize {type(value)}") from e # Handle dataclasses if dataclasses.is_dataclass(value): diff --git a/tests/test_flow_thread_locks.py b/tests/test_flow_thread_locks.py index 974d6ac7e..12717ffad 100644 --- a/tests/test_flow_thread_locks.py +++ b/tests/test_flow_thread_locks.py @@ -157,3 +157,26 @@ def test_flow_with_async_locks(): result = asyncio.run(flow.kickoff_async()) assert result == "step 1 -> step 2" assert flow.state.value == "step 1 -> step 2" + + +def test_flow_concurrent_access(): + """Test Flow with concurrent access.""" + flow = LockFlow() + results = [] + errors = [] + + async def run_flow(): + try: + result = await flow.kickoff_async() + results.append(result) + except Exception as e: + errors.append(e) + + async def test(): + tasks = [run_flow() for _ in range(10)] + await asyncio.gather(*tasks) + + asyncio.run(test()) + assert len(results) == 10 + assert not errors + assert all(result == "step 1 -> step 2" for result in results)