Compare commits

...

17 Commits

Author SHA1 Message Date
Devin AI
f0d0511b24 chore: Merge remote changes and remove telemetry dependency
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:26:18 +00:00
Devin AI
9fad174b74 fix: Remove telemetry dependency from Flow plot method
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:25:10 +00:00
Devin AI
dfba35e475 fix: Fix thread lock type checking in Flow state serialization
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:18:20 +00:00
Devin AI
2e01d1029b fix: Fix type error in Flow state serialization
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:17:58 +00:00
Devin AI
84f770aa5d style: Fix import sorting in tests
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:16:54 +00:00
Devin AI
2a2c163c3d test: Add comprehensive test for complex nested objects
- Add test for various thread-safe primitives
- Test nested dataclasses with complex state
- Verify serialization of async primitives

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:16:54 +00:00
Devin AI
3348de8db7 refactor: Improve Flow state serialization with Pydantic core schema
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:16:54 +00:00
Devin AI
93ec41225b refactor: Improve Flow state serialization
- Add BaseStateEvent class for common state processing
- Add state serialization caching for performance
- Add tests for nested locks and async context
- Improve error handling and validation
- Enhance documentation

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:16:35 +00:00
Devin AI
92e1877bf0 fix: Handle thread locks in Flow state serialization
- Add state serialization in Flow events to avoid pickling RLock objects
- Update event emission to use serialized state
- Add test case for Flow with thread locks

Fixes #2120

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 03:15:10 +00:00
João Moura
9d0a08206f Merge branch 'main' into devin/1739448321-fix-flow-state-pickling
2025-02-13 15:15:37 -03:00
Devin AI
b02e952c32 fix: Fix thread lock type checking in Flow state serialization
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:38:07 +00:00
Devin AI
ac703bafc8 fix: Fix type error in Flow state serialization
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:33:31 +00:00
Devin AI
fd70de34cf style: Fix import sorting in tests
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:20:55 +00:00
Devin AI
0e6689c19c test: Add comprehensive test for complex nested objects
- Add test for various thread-safe primitives
- Test nested dataclasses with complex state
- Verify serialization of async primitives

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:15:10 +00:00
Devin AI
cf7a26e009 refactor: Improve Flow state serialization with Pydantic core schema
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 13:14:31 +00:00
Devin AI
ed877467e1 refactor: Improve Flow state serialization
- Add BaseStateEvent class for common state processing
- Add state serialization caching for performance
- Add tests for nested locks and async context
- Improve error handling and validation
- Enhance documentation

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 12:14:59 +00:00
Devin AI
252095a668 fix: Handle thread locks in Flow state serialization
- Add state serialization in Flow events to avoid pickling RLock objects
- Update event emission to use serialized state
- Add test case for Flow with thread locks

Fixes #2120

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-13 12:07:26 +00:00
3 changed files with 587 additions and 14 deletions
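
Commits 252095a668 and 92e1877bf0 describe the core fix: before state is attached to an event (and therefore deep-copied), lock objects are stripped out so the copy never tries to pickle an RLock. As a rough standalone illustration of that idea — not the code in this diff, which also handles Pydantic models, dataclasses and asyncio primitives — the hypothetical helper strip_locks below replaces lock-like values with None:

# Minimal sketch of the approach described in the commits above. The helper
# name strip_locks and the sample state are illustrative only.
import copy
import threading
from typing import Any


def strip_locks(value: Any) -> Any:
    """Return a copy of value with thread-lock objects replaced by None."""
    if isinstance(value, (type(threading.Lock()), type(threading.RLock()))):
        return None
    if isinstance(value, dict):
        return {k: strip_locks(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set)):
        return type(value)(strip_locks(v) for v in value)
    return value


state = {"counter": 2, "lock": threading.RLock(), "nested": {"lock": threading.Lock()}}
safe = strip_locks(state)
print(safe)             # {'counter': 2, 'lock': None, 'nested': {'lock': None}}
copy.deepcopy(safe)     # no TypeError: nothing lock-like is left to pickle

The PR's _serialize_value applies the same recursion, but inside dataclasses it recreates fresh primitives of the original type instead of dropping them.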

View File

@@ -1,9 +1,12 @@
 import asyncio
 import copy
+import dataclasses
 import inspect
 import logging
+import threading
 from typing import (
     Any,
+    Awaitable,
     Callable,
     Dict,
     Generic,
@@ -572,15 +575,173 @@ class Flow(Generic[T], metaclass=FlowMeta):
                 k: v for k, v in model.__dict__.items() if not k.startswith("_")
             }
-            # Create new instance of the same class
+            # Create new instance of the same class, handling thread locks
             model_class = type(model)
-            return cast(T, model_class(**state_dict))
+            serialized_dict = self._serialize_value(state_dict)
+            return cast(T, model_class(**serialized_dict))
         raise TypeError(
             f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
         )

+    def _get_thread_safe_primitive_type(self, value: Any) -> Optional[Type[Union[threading.Lock, threading.RLock, threading.Semaphore, threading.Event, threading.Condition, asyncio.Lock, asyncio.Event, asyncio.Condition, asyncio.Semaphore]]]:
+        """Get the type of a thread-safe primitive for recreation.
+
+        Args:
+            value: Any Python value to check
+
+        Returns:
+            The type of the thread-safe primitive, or None if not a primitive
+        """
+        if hasattr(value, '_is_owned') and hasattr(value, 'acquire'):
+            # Get the actual types since some are factory functions
+            rlock_type = type(threading.RLock())
+            lock_type = type(threading.Lock())
+            semaphore_type = type(threading.Semaphore())
+            event_type = type(threading.Event())
+            condition_type = type(threading.Condition())
+            async_lock_type = type(asyncio.Lock())
+            async_event_type = type(asyncio.Event())
+            async_condition_type = type(asyncio.Condition())
+            async_semaphore_type = type(asyncio.Semaphore())
+
+            if isinstance(value, rlock_type):
+                return threading.RLock
+            elif isinstance(value, lock_type):
+                return threading.Lock
+            elif isinstance(value, semaphore_type):
+                return threading.Semaphore
+            elif isinstance(value, event_type):
+                return threading.Event
+            elif isinstance(value, condition_type):
+                return threading.Condition
+            elif isinstance(value, async_lock_type):
+                return asyncio.Lock
+            elif isinstance(value, async_event_type):
+                return asyncio.Event
+            elif isinstance(value, async_condition_type):
+                return asyncio.Condition
+            elif isinstance(value, async_semaphore_type):
+                return asyncio.Semaphore
+        return None
+
+    def _serialize_dataclass(self, value: Any) -> Union[Dict[str, Any], Any]:
+        """Serialize a dataclass instance.
+
+        Args:
+            value: A dataclass instance
+
+        Returns:
+            A new instance of the dataclass with thread-safe primitives recreated
+        """
+        if not hasattr(value, '__class__'):
+            return value
+
+        if hasattr(value, '__pydantic_validate__'):
+            return value.__pydantic_validate__()
+
+        # Get field values, handling thread-safe primitives
+        field_values = {}
+        for field in dataclasses.fields(value):
+            field_value = getattr(value, field.name)
+            primitive_type = self._get_thread_safe_primitive_type(field_value)
+            if primitive_type is not None:
+                field_values[field.name] = primitive_type()
+            else:
+                field_values[field.name] = self._serialize_value(field_value)
+
+        # Create new instance
+        return value.__class__(**field_values)
+
     def _copy_state(self) -> T:
-        return copy.deepcopy(self._state)
+        """Create a deep copy of the current state.
+
+        Returns:
+            A deep copy of the current state object
+        """
+        return self._serialize_value(self._state)
+
+    def _serialize_value(self, value: Any) -> Any:
+        """Recursively serialize a value, handling nested objects and locks.
+
+        Args:
+            value: Any Python value to serialize
+
+        Returns:
+            Serialized version of the value with thread-safe primitives handled
+        """
+        # Handle None
+        if value is None:
+            return None
+
+        # Handle thread-safe primitives
+        primitive_type = self._get_thread_safe_primitive_type(value)
+        if primitive_type is not None:
+            return None
+
+        # Handle Pydantic models
+        if isinstance(value, BaseModel):
+            return type(value)(**{
+                k: self._serialize_value(v)
+                for k, v in value.model_dump().items()
+            })
+
+        # Handle dataclasses
+        if dataclasses.is_dataclass(value):
+            return self._serialize_dataclass(value)
+
+        # Handle dictionaries
+        if isinstance(value, dict):
+            return {
+                k: self._serialize_value(v)
+                for k, v in value.items()
+            }
+
+        # Handle lists, tuples, and sets
+        if isinstance(value, (list, tuple, set)):
+            serialized = [self._serialize_value(item) for item in value]
+            return (
+                serialized if isinstance(value, list)
+                else tuple(serialized) if isinstance(value, tuple)
+                else set(serialized)
+            )
+
+        # Handle other types
+        return value
+
+    def _serialize_state(self) -> Union[Dict[str, Any], BaseModel]:
+        """Serialize the current state for event emission.
+
+        This method handles the serialization of both BaseModel and dictionary states,
+        ensuring thread-safe copying of state data. Uses caching to improve performance
+        when state hasn't changed. Handles nested objects and locks recursively.
+
+        Returns:
+            Union[Dict[str, Any], BaseModel]: Serialized state as either a new BaseModel instance or dictionary
+
+        Raises:
+            ValueError: If state has invalid type
+            Exception: If serialization fails, logs error and returns empty dict
+        """
+        try:
+            if not isinstance(self._state, (dict, BaseModel)):
+                raise ValueError(f"Invalid state type: {type(self._state)}")
+
+            if not hasattr(self, '_last_state_hash'):
+                self._last_state_hash = None
+                self._last_serialized_state = None
+
+            current_hash = hash(str(self._state))
+            if current_hash == self._last_state_hash:
+                return self._last_serialized_state
+
+            serialized = self._serialize_value(self._state)
+            self._last_state_hash = current_hash
+            self._last_serialized_state = serialized
+            return serialized
+        except Exception as e:
+            logger.error(f"State serialization failed: {str(e)}")
+            return cast(Dict[str, Any], {})
+
     @property
     def state(self) -> T:
@@ -712,7 +873,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
         else:
             raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")

-    def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
+    def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Union[Any, None]:
         """Start the flow execution.

         Args:
@@ -816,12 +977,27 @@ class Flow(Generic[T], metaclass=FlowMeta):
     @trace_flow_step
     async def _execute_method(
-        self, method_name: str, method: Callable, *args: Any, **kwargs: Any
+        self, method_name: str, method: Union[Callable[..., Any], Callable[..., Awaitable[Any]]], *args: Any, **kwargs: Any
     ) -> Any:
+        """Execute a flow method with proper event handling and state management.
+
+        Args:
+            method_name: Name of the method to execute
+            method: The method to execute
+            *args: Positional arguments for the method
+            **kwargs: Keyword arguments for the method
+
+        Returns:
+            The result of the method execution
+
+        Raises:
+            Any exception that occurs during method execution
+        """
         try:
-            dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
-                kwargs or {}
-            )
+            # Serialize state before event emission to avoid pickling issues
+            state_copy = self._serialize_state()
+            dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (kwargs or {})
             crewai_event_bus.emit(
                 self,
                 MethodExecutionStartedEvent(
@@ -829,7 +1005,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
                     method_name=method_name,
                     flow_name=self.__class__.__name__,
                     params=dumped_params,
-                    state=self._copy_state(),
+                    state=state_copy,
                 ),
             )
@@ -844,13 +1020,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
                 self._method_execution_counts.get(method_name, 0) + 1
             )

+            # Serialize state after execution
+            state_copy = self._serialize_state()
+
             crewai_event_bus.emit(
                 self,
                 MethodExecutionFinishedEvent(
                     type="method_execution_finished",
                     method_name=method_name,
                     flow_name=self.__class__.__name__,
-                    state=self._copy_state(),
+                    state=state_copy,
                     result=result,
                 ),
             )
@@ -918,7 +1097,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
         await asyncio.gather(*tasks)

     def _find_triggered_methods(
-        self, trigger_method: str, router_only: bool
+        self, trigger_method: str, router_only: bool = False
     ) -> List[str]:
         """
         Finds all methods that should be triggered based on conditions.
@@ -1028,7 +1207,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
            traceback.print_exc()

     def _log_flow_event(
-        self, message: str, color: str = "yellow", level: str = "info"
+        self, message: str, color: Optional[str] = "yellow", level: Optional[str] = "info"
     ) -> None:
         """Centralized logging method for flow events.
@@ -1053,7 +1232,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
         elif level == "warning":
             logger.warning(message)

-    def plot(self, filename: str = "crewai_flow") -> None:
+    def plot(self, filename: Optional[str] = "crewai_flow") -> None:
+        """Plot the flow graph visualization.
+
+        Args:
+            filename: Optional name for the output file (default: "crewai_flow")
+        """
         crewai_event_bus.emit(
             self,
             FlowPlotEvent(
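
The _serialize_state method in the hunk above avoids re-serializing unchanged state by hashing str(self._state) and returning the cached result when the hash matches. The standalone sketch below isolates that caching pattern; CachedSerializer is a hypothetical name, not part of the diff, and a str()-based fingerprint only detects changes that show up in the string form.

# Sketch of the hash-based caching idea used by _serialize_state above.
from typing import Any, Callable, Optional


class CachedSerializer:
    """Memoize a serializer on a cheap fingerprint of the state."""

    def __init__(self, serialize: Callable[[Any], Any]) -> None:
        self._serialize = serialize
        self._last_hash: Optional[int] = None
        self._last_result: Any = None

    def __call__(self, state: Any) -> Any:
        fingerprint = hash(str(state))        # same fingerprint the diff uses
        if fingerprint != self._last_hash:    # state changed, serialize again
            self._last_result = self._serialize(state)
            self._last_hash = fingerprint
        return self._last_result


cached = CachedSerializer(dict)   # trivial serializer for the demo
cached({"counter": 1})            # serialized
cached({"counter": 1})            # returned from cache (same str(), same hash)
cached({"counter": 2})            # serialized again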

View File

@@ -0,0 +1,78 @@
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any, Dict, Optional, Union
+
+from pydantic import BaseModel
+
+
+@dataclass
+class Event:
+    type: str
+    flow_name: str
+    timestamp: datetime = field(init=False)
+
+    def __post_init__(self):
+        self.timestamp = datetime.now()
+
+
+@dataclass
+class BaseStateEvent(Event):
+    """Base class for events containing state data.
+
+    Handles common state serialization and validation logic to ensure thread-safe
+    state handling and proper type validation.
+
+    Raises:
+        ValueError: If state has invalid type
+    """
+
+    state: Union[Dict[str, Any], BaseModel]
+
+    def __post_init__(self):
+        super().__post_init__()
+        self._process_state()
+
+    def _process_state(self):
+        """Process and validate state data.
+
+        Ensures state is of valid type and creates a new instance of BaseModel
+        states to avoid thread lock serialization issues.
+
+        Raises:
+            ValueError: If state has invalid type
+        """
+        if not isinstance(self.state, (dict, BaseModel)):
+            raise ValueError(f"Invalid state type: {type(self.state)}")
+        if isinstance(self.state, BaseModel):
+            self.state = type(self.state)(**self.state.model_dump())
+
+
+@dataclass
+class FlowStartedEvent(Event):
+    inputs: Optional[Dict[str, Any]] = None
+
+
+@dataclass
+class MethodExecutionStartedEvent(BaseStateEvent):
+    method_name: str
+    state: Union[Dict[str, Any], BaseModel]
+    params: Optional[Dict[str, Any]] = None
+
+    def __post_init__(self):
+        super().__post_init__()
+        self._process_state()
+
+
+@dataclass
+class MethodExecutionFinishedEvent(BaseStateEvent):
+    method_name: str
+    state: Union[Dict[str, Any], BaseModel]
+    result: Any = None
+
+    def __post_init__(self):
+        super().__post_init__()
+        self._process_state()
+
+
+@dataclass
+class FlowFinishedEvent(Event):
+    result: Optional[Any] = None
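
The _process_state hook above re-instantiates BaseModel states from model_dump() to avoid thread-lock serialization issues, so an event carries a detached snapshot rather than the flow's live state object. A hypothetical usage sketch follows: CounterState and the field values are illustrative, and the import for MethodExecutionStartedEvent is omitted because the module path is not shown in this view.

# Hypothetical usage of the event classes defined above.
from pydantic import BaseModel


class CounterState(BaseModel):
    counter: int = 0


live_state = CounterState(counter=1)
event = MethodExecutionStartedEvent(
    type="method_execution_started",
    flow_name="MyFlow",
    state=live_state,
    method_name="step_1",
    params={},
)

# _process_state rebuilt the BaseModel from model_dump(), so the event holds
# a snapshot that is independent of the live state object.
assert event.state is not live_state
assert event.state.counter == 1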

View File

@@ -2,9 +2,11 @@
 import asyncio
 from datetime import datetime
 from typing import Dict, List, Optional, Set, Tuple
+from uuid import uuid4

 import pytest
-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 from crewai.flow.flow import Flow, and_, listen, or_, router, start
 from crewai.utilities.events import (
@@ -350,6 +352,315 @@ def test_flow_uuid_structured():
     assert flow.state.message == "final"


+def test_flow_with_thread_lock():
+    """Test that Flow properly handles thread locks in state."""
+    import threading
+
+    class LockFlow(Flow):
+        def __init__(self):
+            super().__init__()
+            self.lock = threading.RLock()
+            self.counter = 0
+
+        @start()
+        async def step_1(self):
+            with self.lock:
+                self.counter += 1
+            return "step 1"
+
+        @listen(step_1)
+        async def step_2(self, result):
+            with self.lock:
+                self.counter += 1
+            return result + " -> step 2"
+
+    flow = LockFlow()
+    result = flow.kickoff()
+    assert result == "step 1 -> step 2"
+    assert flow.counter == 2
+
+
+def test_flow_with_nested_objects_and_locks():
+    """Test that Flow properly handles nested objects containing locks."""
+    import threading
+    from dataclasses import dataclass
+    from typing import Dict, List, Optional
+
+    @dataclass
+    class NestedState:
+        value: str
+        lock: threading.RLock = None
+
+        def __post_init__(self):
+            if self.lock is None:
+                self.lock = threading.RLock()
+
+        def __pydantic_validate__(self):
+            return {"value": self.value, "lock": threading.RLock()}
+
+        @classmethod
+        def __get_pydantic_core_schema__(cls, source_type, handler):
+            from pydantic_core.core_schema import (
+                str_schema,
+                with_info_plain_validator_function,
+            )
+
+            def validate(value, _):
+                if isinstance(value, cls):
+                    return value
+                if isinstance(value, dict):
+                    return cls(value["value"])
+                raise ValueError(f"Invalid value type for {cls.__name__}")
+
+            return with_info_plain_validator_function(validate)
+
+    class ComplexState(BaseModel):
+        id: str = Field(default_factory=lambda: str(uuid4()))
+        name: str
+        nested: NestedState
+        items: List[NestedState]
+        mapping: Dict[str, NestedState]
+        optional: Optional[NestedState] = None
+
+    class ComplexStateFlow(Flow[ComplexState]):
+        def __init__(self):
+            self.initial_state = ComplexState(
+                name="test",
+                nested=NestedState("nested", threading.RLock()),
+                items=[
+                    NestedState("item1", threading.RLock()),
+                    NestedState("item2", threading.RLock())
+                ],
+                mapping={
+                    "key1": NestedState("map1", threading.RLock()),
+                    "key2": NestedState("map2", threading.RLock())
+                },
+                optional=NestedState("optional", threading.RLock())
+            )
+            super().__init__()
+
+        @start()
+        async def step_1(self):
+            with self.state.nested.lock:
+                return "step 1"
+
+        @listen(step_1)
+        async def step_2(self, result):
+            with self.state.items[0].lock:
+                with self.state.mapping["key1"].lock:
+                    with self.state.optional.lock:
+                        return result + " -> step 2"
+
+    flow = ComplexStateFlow()
+    result = flow.kickoff()
+    assert result == "step 1 -> step 2"
+
+
+def test_flow_with_nested_locks():
+    """Test that Flow properly handles nested thread locks."""
+    import threading
+
+    class NestedLockFlow(Flow):
+        def __init__(self):
+            super().__init__()
+            self.outer_lock = threading.RLock()
+            self.inner_lock = threading.RLock()
+            self.counter = 0
+
+        @start()
+        async def step_1(self):
+            with self.outer_lock:
+                with self.inner_lock:
+                    self.counter += 1
+            return "step 1"
+
+        @listen(step_1)
+        async def step_2(self, result):
+            with self.outer_lock:
+                with self.inner_lock:
+                    self.counter += 1
+            return result + " -> step 2"
+
+    flow = NestedLockFlow()
+    result = flow.kickoff()
+    assert result == "step 1 -> step 2"
+    assert flow.counter == 2
+
+
+@pytest.mark.asyncio
+async def test_flow_with_async_locks():
+    """Test that Flow properly handles locks in async context."""
+    import asyncio
+    import threading
+
+    class AsyncLockFlow(Flow):
+        def __init__(self):
+            super().__init__()
+            self.lock = threading.RLock()
+            self.async_lock = asyncio.Lock()
+            self.counter = 0
+
+        @start()
+        async def step_1(self):
+            async with self.async_lock:
+                with self.lock:
+                    self.counter += 1
+            return "step 1"
+
+        @listen(step_1)
+        async def step_2(self, result):
+            async with self.async_lock:
+                with self.lock:
+                    self.counter += 1
+            return result + " -> step 2"
+
+    flow = AsyncLockFlow()
+    result = await flow.kickoff_async()
+    assert result == "step 1 -> step 2"
+    assert flow.counter == 2
+
+
+def test_flow_with_complex_nested_objects():
+    """Test that Flow properly handles complex nested objects."""
+    import asyncio
+    import threading
+    from dataclasses import dataclass
+
+    @dataclass
+    class ThreadSafePrimitives:
+        thread_lock: threading.Lock
+        rlock: threading.RLock
+        semaphore: threading.Semaphore
+        event: threading.Event
+        async_lock: asyncio.Lock
+        async_event: asyncio.Event
+
+        def __post_init__(self):
+            self.thread_lock = self.thread_lock or threading.Lock()
+            self.rlock = self.rlock or threading.RLock()
+            self.semaphore = self.semaphore or threading.Semaphore()
+            self.event = self.event or threading.Event()
+            self.async_lock = self.async_lock or asyncio.Lock()
+            self.async_event = self.async_event or asyncio.Event()
+
+        def __pydantic_validate__(self):
+            return {
+                "thread_lock": None,
+                "rlock": None,
+                "semaphore": None,
+                "event": None,
+                "async_lock": None,
+                "async_event": None
+            }
+
+        @classmethod
+        def __get_pydantic_core_schema__(cls, source_type, handler):
+            from pydantic_core.core_schema import with_info_plain_validator_function
+
+            def validate(value, _):
+                if isinstance(value, cls):
+                    return value
+                if isinstance(value, dict):
+                    return cls(
+                        thread_lock=None,
+                        rlock=None,
+                        semaphore=None,
+                        event=None,
+                        async_lock=None,
+                        async_event=None
+                    )
+                raise ValueError(f"Invalid value type for {cls.__name__}")
+
+            return with_info_plain_validator_function(validate)
+
+    @dataclass
+    class NestedContainer:
+        name: str
+        primitives: ThreadSafePrimitives
+        items: List[ThreadSafePrimitives]
+        mapping: Dict[str, ThreadSafePrimitives]
+        optional: Optional[ThreadSafePrimitives]
+
+        def __post_init__(self):
+            self.primitives = self.primitives or ThreadSafePrimitives(None, None, None, None, None, None)
+            self.items = self.items or []
+            self.mapping = self.mapping or {}
+
+        def __pydantic_validate__(self):
+            return {
+                "name": self.name,
+                "primitives": self.primitives.__pydantic_validate__(),
+                "items": [item.__pydantic_validate__() for item in self.items],
+                "mapping": {k: v.__pydantic_validate__() for k, v in self.mapping.items()},
+                "optional": self.optional.__pydantic_validate__() if self.optional else None
+            }
+
+        @classmethod
+        def __get_pydantic_core_schema__(cls, source_type, handler):
+            from pydantic_core.core_schema import with_info_plain_validator_function
+
+            def validate(value, _):
+                if isinstance(value, cls):
+                    return value
+                if isinstance(value, dict):
+                    return cls(
+                        name=value["name"],
+                        primitives=ThreadSafePrimitives(None, None, None, None, None, None),
+                        items=[],
+                        mapping={},
+                        optional=None
+                    )
+                raise ValueError(f"Invalid value type for {cls.__name__}")
+
+            return with_info_plain_validator_function(validate)
+
+    class ComplexState(BaseModel):
+        id: str = Field(default_factory=lambda: str(uuid4()))
+        name: str
+        nested: NestedContainer
+        items: List[NestedContainer]
+        mapping: Dict[str, NestedContainer]
+        optional: Optional[NestedContainer] = None
+
+    class ComplexStateFlow(Flow[ComplexState]):
+        def __init__(self):
+            primitives = ThreadSafePrimitives(
+                thread_lock=threading.Lock(),
+                rlock=threading.RLock(),
+                semaphore=threading.Semaphore(),
+                event=threading.Event(),
+                async_lock=asyncio.Lock(),
+                async_event=asyncio.Event()
+            )
+            container = NestedContainer(
+                name="test",
+                primitives=primitives,
+                items=[primitives],
+                mapping={"key": primitives},
+                optional=primitives
+            )
+            self.initial_state = ComplexState(
+                name="test",
+                nested=container,
+                items=[container],
+                mapping={"key": container},
+                optional=container
+            )
+            super().__init__()
+
+        @start()
+        async def step_1(self):
+            with self.state.nested.primitives.rlock:
+                return "step 1"
+
+        @listen(step_1)
+        async def step_2(self, result):
+            with self.state.items[0].primitives.rlock:
+                return result + " -> step 2"
+
+    flow = ComplexStateFlow()
+    result = flow.kickoff()
+    assert result == "step 1 -> step 2"
+
+
 def test_router_with_multiple_conditions():
     """Test a router that triggers when any of multiple steps complete (OR condition),
     and another router that triggers only after all specified steps complete (AND condition).