refactor: Improve Flow state serialization

- Add BaseStateEvent class for common state processing
- Add state serialization caching for performance
- Add tests for nested locks and async context
- Improve error handling and validation
- Enhance documentation

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-13 12:14:59 +00:00
parent 252095a668
commit ed877467e1
3 changed files with 147 additions and 18 deletions

View File

@@ -570,7 +570,51 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
def _copy_state(self) -> T:
    """Return an independent deep copy of the flow's current state.

    Returns:
        A new object produced by ``copy.deepcopy`` from ``self._state``.
    """
    state_clone = copy.deepcopy(self._state)
    return state_clone
def _serialize_state(self) -> Union[Dict[str, Any], BaseModel]:
    """Build a detached snapshot of the current state for event emission.

    A ``BaseModel`` state is rebuilt as a new instance of the same class from
    its ``model_dump()`` output. A ``dict`` state is shallow-copied with
    ``dict(...)``, so nested values are still shared with the live state.

    The most recent snapshot is cached on the instance and the same object is
    returned again while the state's fingerprint is unchanged. For
    ``BaseModel`` states the fingerprint is computed from ``model_dump()``
    rather than from ``str(model)``, so a change to a field that is left out
    of the model's string representation (e.g. declared with ``repr=False``)
    still invalidates the cache.

    Returns:
        The snapshot: a new ``BaseModel`` instance or a ``dict``. An empty
        ``dict`` is returned when the state has an unsupported type or when
        serialization fails.

    Raises:
        Nothing. Every failure, including the internal ``ValueError`` for an
        unsupported state type, is logged through ``logger.error`` and
        converted into the empty-dict result.
    """
    try:
        state = self._state
        if not isinstance(state, (dict, BaseModel)):
            # Raised only to reuse the logging path below; it never escapes.
            raise ValueError(f"Invalid state type: {type(state)}")

        # Lazily create the cache slots the first time the method runs.
        if not hasattr(self, '_last_state_hash'):
            self._last_state_hash = None
            self._last_serialized_state = None

        if isinstance(state, BaseModel):
            # Dump once: the dump feeds both the fingerprint and the rebuild.
            dumped = state.model_dump()
            current_hash = hash(str(dumped))
        else:
            dumped = None
            current_hash = hash(str(state))

        if current_hash == self._last_state_hash:
            return self._last_serialized_state

        if dumped is not None:
            serialized = type(state)(**dumped)
        else:
            serialized = dict(state)

        self._last_state_hash = current_hash
        self._last_serialized_state = serialized
        return serialized
    except Exception as e:
        # Best-effort by design: event emission must not break the flow.
        logger.error(f"State serialization failed: {str(e)}")
        return {}
@property
def state(self) -> T:
@@ -808,11 +852,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
) -> Any:
# Serialize state before event emission to avoid pickling issues
state_copy = (
type(self._state)(**self._state.model_dump())
if isinstance(self._state, BaseModel)
else dict(self._state)
)
state_copy = self._serialize_state()
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (kwargs or {})
self.event_emitter.send(
@@ -837,11 +877,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
# Serialize state after execution
state_copy = (
type(self._state)(**self._state.model_dump())
if isinstance(self._state, BaseModel)
else dict(self._state)
)
state_copy = self._serialize_state()
self.event_emitter.send(
self,

View File

@@ -15,35 +15,62 @@ class Event:
self.timestamp = datetime.now()
@dataclass
class BaseStateEvent(Event):
    """Common base for events that carry a ``state`` payload.

    After the parent ``__post_init__`` has run, the state is validated and,
    when it is a ``BaseModel``, replaced by a freshly constructed instance of
    the same class.

    Raises:
        ValueError: If ``state`` is neither a ``dict`` nor a ``BaseModel``.
    """

    state: Union[Dict[str, Any], BaseModel]

    def __post_init__(self):
        super().__post_init__()
        self._process_state()

    def _process_state(self):
        """Validate ``self.state`` and detach ``BaseModel`` states.

        A ``BaseModel`` state is rebuilt from its ``model_dump()`` output so
        the event holds its own instance; a ``dict`` state is left as is.

        Raises:
            ValueError: If ``state`` is neither a ``dict`` nor a ``BaseModel``.
        """
        current = self.state
        if isinstance(current, BaseModel):
            self.state = type(current)(**current.model_dump())
            return
        if not isinstance(current, dict):
            raise ValueError(f"Invalid state type: {type(current)}")
@dataclass
class FlowStartedEvent(Event):
    """Event that optionally carries the ``inputs`` mapping of a flow."""

    inputs: Optional[Dict[str, Any]] = None
@dataclass
class MethodExecutionStartedEvent(BaseStateEvent):
    """State-carrying event with a method name and its call parameters.

    The ``state`` field and its handling are inherited from
    ``BaseStateEvent``: that class's ``__post_init__`` already validates the
    state and rebuilds ``BaseModel`` instances. No ``__post_init__`` override
    is defined here because calling ``_process_state()`` a second time would
    only repeat the copy. Constructor field order is unchanged: inherited
    fields first, then ``method_name`` and ``params``.
    """

    method_name: str
    params: Optional[Dict[str, Any]] = None
@dataclass
class MethodExecutionFinishedEvent(BaseStateEvent):
    """State-carrying event with a method name and the method's result.

    The ``state`` field and its handling are inherited from
    ``BaseStateEvent``: that class's ``__post_init__`` already validates the
    state and rebuilds ``BaseModel`` instances. No ``__post_init__`` override
    is defined here because calling ``_process_state()`` a second time would
    only repeat the copy. Constructor field order is unchanged: inherited
    fields first, then ``method_name`` and ``result``.
    """

    method_name: str
    result: Any = None
@dataclass