refactor: Improve Flow state serialization with Pydantic core schema

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-13 13:14:31 +00:00
parent 93ec41225b
commit 3348de8db7
2 changed files with 220 additions and 10 deletions

View File

@@ -1,7 +1,9 @@
import asyncio import asyncio
import copy import copy
import dataclasses
import inspect import inspect
import logging import logging
import threading
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@@ -572,27 +574,164 @@ class Flow(Generic[T], metaclass=FlowMeta):
k: v for k, v in model.__dict__.items() if not k.startswith("_") k: v for k, v in model.__dict__.items() if not k.startswith("_")
} }
# Create new instance of the same class # Create new instance of the same class, handling thread locks
model_class = type(model) model_class = type(model)
return cast(T, model_class(**state_dict)) serialized_dict = self._serialize_value(state_dict)
return cast(T, model_class(**serialized_dict))
raise TypeError( raise TypeError(
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
) )
def _get_thread_safe_primitive_type(self, value: Any) -> Optional[Type]:
"""Get the type of a thread-safe primitive for recreation.
Args:
value: Any Python value to check
Returns:
The type of the thread-safe primitive, or None if not a primitive
"""
if hasattr(value, '_is_owned') and hasattr(value, 'acquire'):
if isinstance(value, threading.RLock):
return threading.RLock
elif isinstance(value, threading.Lock):
return threading.Lock
elif isinstance(value, threading.Semaphore):
return threading.Semaphore
elif isinstance(value, threading.Event):
return threading.Event
elif isinstance(value, threading.Condition):
return threading.Condition
elif isinstance(value, asyncio.Lock):
return asyncio.Lock
elif isinstance(value, asyncio.Event):
return asyncio.Event
elif isinstance(value, asyncio.Condition):
return asyncio.Condition
elif isinstance(value, asyncio.Semaphore):
return asyncio.Semaphore
return None
def _serialize_dataclass(self, value: Any) -> Any:
"""Serialize a dataclass instance.
Args:
value: A dataclass instance
Returns:
A new instance of the dataclass with thread-safe primitives recreated
"""
if not hasattr(value, '__class__'):
return value
if hasattr(value, '__pydantic_validate__'):
return value.__pydantic_validate__()
# Get field values, handling thread-safe primitives
field_values = {}
for field in dataclasses.fields(value):
field_value = getattr(value, field.name)
primitive_type = self._get_thread_safe_primitive_type(field_value)
if primitive_type is not None:
field_values[field.name] = primitive_type()
else:
field_values[field.name] = self._serialize_value(field_value)
# Create new instance
return value.__class__(**field_values)
def _copy_state(self) -> T: def _copy_state(self) -> T:
"""Create a deep copy of the current state. """Create a deep copy of the current state.
Returns: Returns:
A deep copy of the current state object A deep copy of the current state object
""" """
return copy.deepcopy(self._state) return self._serialize_value(self._state)
def _serialize_value(self, value: Any) -> Any:
"""Recursively serialize a value, handling nested objects and locks.
Args:
value: Any Python value to serialize
Returns:
Serialized version of the value with thread-safe primitives handled
"""
# Handle None
if value is None:
return None
# Handle thread-safe primitives
primitive_type = self._get_thread_safe_primitive_type(value)
if primitive_type is not None:
return None
# Handle Pydantic models
if isinstance(value, BaseModel):
return type(value)(**{
k: self._serialize_value(v)
for k, v in value.model_dump().items()
})
# Handle dataclasses
if dataclasses.is_dataclass(value):
return self._serialize_dataclass(value)
# Handle dictionaries
if isinstance(value, dict):
return {
k: self._serialize_value(v)
for k, v in value.items()
}
# Handle lists, tuples, and sets
if isinstance(value, (list, tuple, set)):
serialized = [self._serialize_value(item) for item in value]
return (
serialized if isinstance(value, list)
else tuple(serialized) if isinstance(value, tuple)
else set(serialized)
)
# Handle other types
return value
def _serialize_value(self, value: Any) -> Any:
"""Recursively serialize a value, handling nested objects and locks.
Args:
value: Any Python value to serialize
Returns:
Serialized version of the value with locks properly handled
"""
if isinstance(value, BaseModel):
return type(value)(**{
k: self._serialize_value(v)
for k, v in value.model_dump().items()
})
elif isinstance(value, dict):
return {
k: self._serialize_value(v)
for k, v in value.items()
}
elif isinstance(value, list):
return [self._serialize_value(item) for item in value]
elif isinstance(value, tuple):
return tuple(self._serialize_value(item) for item in value)
elif isinstance(value, set):
return {self._serialize_value(item) for item in value}
elif hasattr(value, '_is_owned') and hasattr(value, 'acquire'):
# Skip thread locks and similar synchronization primitives
return None
return value
def _serialize_state(self) -> Union[Dict[str, Any], BaseModel]: def _serialize_state(self) -> Union[Dict[str, Any], BaseModel]:
"""Serialize the current state for event emission. """Serialize the current state for event emission.
This method handles the serialization of both BaseModel and dictionary states, This method handles the serialization of both BaseModel and dictionary states,
ensuring thread-safe copying of state data. Uses caching to improve performance ensuring thread-safe copying of state data. Uses caching to improve performance
when state hasn't changed. when state hasn't changed. Handles nested objects and locks recursively.
Returns: Returns:
Serialized state as either a new BaseModel instance or dictionary Serialized state as either a new BaseModel instance or dictionary
@@ -613,11 +752,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if current_hash == self._last_state_hash: if current_hash == self._last_state_hash:
return self._last_serialized_state return self._last_serialized_state
serialized = ( serialized = self._serialize_value(self._state)
type(self._state)(**self._state.model_dump())
if isinstance(self._state, BaseModel)
else dict(self._state)
)
self._last_state_hash = current_hash self._last_state_hash = current_hash
self._last_serialized_state = serialized self._last_serialized_state = serialized

View File

@@ -4,7 +4,8 @@ import asyncio
from datetime import datetime from datetime import datetime
import pytest import pytest
from pydantic import BaseModel from uuid import uuid4
from pydantic import BaseModel, Field
from crewai.flow.flow import Flow, and_, listen, or_, router, start from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.utilities.events import ( from crewai.utilities.events import (
@@ -379,6 +380,80 @@ def test_flow_with_thread_lock():
assert flow.counter == 2 assert flow.counter == 2
def test_flow_with_nested_objects_and_locks():
"""Test that Flow properly handles nested objects containing locks."""
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional
@dataclass
class NestedState:
value: str
lock: threading.RLock = None
def __post_init__(self):
if self.lock is None:
self.lock = threading.RLock()
def __pydantic_validate__(self):
return {"value": self.value, "lock": threading.RLock()}
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
from pydantic_core.core_schema import (
with_info_plain_validator_function,
str_schema,
)
def validate(value, _):
if isinstance(value, cls):
return value
if isinstance(value, dict):
return cls(value["value"])
raise ValueError(f"Invalid value type for {cls.__name__}")
return with_info_plain_validator_function(validate)
class ComplexState(BaseModel):
id: str = Field(default_factory=lambda: str(uuid4()))
name: str
nested: NestedState
items: List[NestedState]
mapping: Dict[str, NestedState]
optional: Optional[NestedState] = None
class ComplexStateFlow(Flow[ComplexState]):
def __init__(self):
self.initial_state = ComplexState(
name="test",
nested=NestedState("nested", threading.RLock()),
items=[
NestedState("item1", threading.RLock()),
NestedState("item2", threading.RLock())
],
mapping={
"key1": NestedState("map1", threading.RLock()),
"key2": NestedState("map2", threading.RLock())
},
optional=NestedState("optional", threading.RLock())
)
super().__init__()
@start()
async def step_1(self):
with self.state.nested.lock:
return "step 1"
@listen(step_1)
async def step_2(self, result):
with self.state.items[0].lock:
with self.state.mapping["key1"].lock:
with self.state.optional.lock:
return result + " -> step 2"
flow = ComplexStateFlow()
result = flow.kickoff()
assert result == "step 1 -> step 2"
def test_flow_with_nested_locks(): def test_flow_with_nested_locks():
"""Test that Flow properly handles nested thread locks.""" """Test that Flow properly handles nested thread locks."""
import threading import threading