mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
refactor: Improve Flow state serialization
- Add BaseStateEvent class for common state processing - Add state serialization caching for performance - Add tests for nested locks and async context - Improve error handling and validation - Enhance documentation Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -570,7 +570,51 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
|
||||
def _copy_state(self) -> T:
|
||||
"""Create a deep copy of the current state.
|
||||
|
||||
Returns:
|
||||
A deep copy of the current state object
|
||||
"""
|
||||
return copy.deepcopy(self._state)
|
||||
|
||||
def _serialize_state(self) -> Union[Dict[str, Any], BaseModel]:
|
||||
"""Serialize the current state for event emission.
|
||||
|
||||
This method handles the serialization of both BaseModel and dictionary states,
|
||||
ensuring thread-safe copying of state data. Uses caching to improve performance
|
||||
when state hasn't changed.
|
||||
|
||||
Returns:
|
||||
Serialized state as either a new BaseModel instance or dictionary
|
||||
|
||||
Raises:
|
||||
ValueError: If state has invalid type
|
||||
Exception: If serialization fails, logs error and returns empty dict
|
||||
"""
|
||||
try:
|
||||
if not isinstance(self._state, (dict, BaseModel)):
|
||||
raise ValueError(f"Invalid state type: {type(self._state)}")
|
||||
|
||||
if not hasattr(self, '_last_state_hash'):
|
||||
self._last_state_hash = None
|
||||
self._last_serialized_state = None
|
||||
|
||||
current_hash = hash(str(self._state))
|
||||
if current_hash == self._last_state_hash:
|
||||
return self._last_serialized_state
|
||||
|
||||
serialized = (
|
||||
type(self._state)(**self._state.model_dump())
|
||||
if isinstance(self._state, BaseModel)
|
||||
else dict(self._state)
|
||||
)
|
||||
|
||||
self._last_state_hash = current_hash
|
||||
self._last_serialized_state = serialized
|
||||
return serialized
|
||||
except Exception as e:
|
||||
logger.error(f"State serialization failed: {str(e)}")
|
||||
return {}
|
||||
|
||||
@property
|
||||
def state(self) -> T:
|
||||
@@ -808,11 +852,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self, method_name: str, method: Callable, *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
# Serialize state before event emission to avoid pickling issues
|
||||
state_copy = (
|
||||
type(self._state)(**self._state.model_dump())
|
||||
if isinstance(self._state, BaseModel)
|
||||
else dict(self._state)
|
||||
)
|
||||
state_copy = self._serialize_state()
|
||||
|
||||
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (kwargs or {})
|
||||
self.event_emitter.send(
|
||||
@@ -837,11 +877,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
|
||||
# Serialize state after execution
|
||||
state_copy = (
|
||||
type(self._state)(**self._state.model_dump())
|
||||
if isinstance(self._state, BaseModel)
|
||||
else dict(self._state)
|
||||
)
|
||||
state_copy = self._serialize_state()
|
||||
|
||||
self.event_emitter.send(
|
||||
self,
|
||||
|
||||
@@ -15,35 +15,62 @@ class Event:
|
||||
self.timestamp = datetime.now()
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseStateEvent(Event):
|
||||
"""Base class for events containing state data.
|
||||
|
||||
Handles common state serialization and validation logic to ensure thread-safe
|
||||
state handling and proper type validation.
|
||||
|
||||
Raises:
|
||||
ValueError: If state has invalid type
|
||||
"""
|
||||
state: Union[Dict[str, Any], BaseModel]
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self._process_state()
|
||||
|
||||
def _process_state(self):
|
||||
"""Process and validate state data.
|
||||
|
||||
Ensures state is of valid type and creates a new instance of BaseModel
|
||||
states to avoid thread lock serialization issues.
|
||||
|
||||
Raises:
|
||||
ValueError: If state has invalid type
|
||||
"""
|
||||
if not isinstance(self.state, (dict, BaseModel)):
|
||||
raise ValueError(f"Invalid state type: {type(self.state)}")
|
||||
if isinstance(self.state, BaseModel):
|
||||
self.state = type(self.state)(**self.state.model_dump())
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowStartedEvent(Event):
|
||||
inputs: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MethodExecutionStartedEvent(Event):
|
||||
class MethodExecutionStartedEvent(BaseStateEvent):
|
||||
method_name: str
|
||||
state: Union[Dict[str, Any], BaseModel]
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Create a new instance of BaseModel state to avoid pickling issues
|
||||
if isinstance(self.state, BaseModel):
|
||||
self.state = type(self.state)(**self.state.model_dump())
|
||||
self._process_state()
|
||||
|
||||
|
||||
@dataclass
|
||||
class MethodExecutionFinishedEvent(Event):
|
||||
class MethodExecutionFinishedEvent(BaseStateEvent):
|
||||
method_name: str
|
||||
state: Union[Dict[str, Any], BaseModel]
|
||||
result: Any = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Create a new instance of BaseModel state to avoid pickling issues
|
||||
if isinstance(self.state, BaseModel):
|
||||
self.state = type(self.state)(**self.state.model_dump())
|
||||
self._process_state()
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -377,6 +377,72 @@ def test_flow_with_thread_lock():
|
||||
assert flow.counter == 2
|
||||
|
||||
|
||||
def test_flow_with_nested_locks():
|
||||
"""Test that Flow properly handles nested thread locks."""
|
||||
import threading
|
||||
|
||||
class NestedLockFlow(Flow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.outer_lock = threading.RLock()
|
||||
self.inner_lock = threading.RLock()
|
||||
self.counter = 0
|
||||
|
||||
@start()
|
||||
async def step_1(self):
|
||||
with self.outer_lock:
|
||||
with self.inner_lock:
|
||||
self.counter += 1
|
||||
return "step 1"
|
||||
|
||||
@listen(step_1)
|
||||
async def step_2(self, result):
|
||||
with self.outer_lock:
|
||||
with self.inner_lock:
|
||||
self.counter += 1
|
||||
return result + " -> step 2"
|
||||
|
||||
flow = NestedLockFlow()
|
||||
result = flow.kickoff()
|
||||
|
||||
assert result == "step 1 -> step 2"
|
||||
assert flow.counter == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_with_async_locks():
|
||||
"""Test that Flow properly handles locks in async context."""
|
||||
import asyncio
|
||||
import threading
|
||||
|
||||
class AsyncLockFlow(Flow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.lock = threading.RLock()
|
||||
self.async_lock = asyncio.Lock()
|
||||
self.counter = 0
|
||||
|
||||
@start()
|
||||
async def step_1(self):
|
||||
async with self.async_lock:
|
||||
with self.lock:
|
||||
self.counter += 1
|
||||
return "step 1"
|
||||
|
||||
@listen(step_1)
|
||||
async def step_2(self, result):
|
||||
async with self.async_lock:
|
||||
with self.lock:
|
||||
self.counter += 1
|
||||
return result + " -> step 2"
|
||||
|
||||
flow = AsyncLockFlow()
|
||||
result = await flow.kickoff_async()
|
||||
|
||||
assert result == "step 1 -> step 2"
|
||||
assert flow.counter == 2
|
||||
|
||||
|
||||
def test_router_with_multiple_conditions():
|
||||
"""Test a router that triggers when any of multiple steps complete (OR condition),
|
||||
and another router that triggers only after all specified steps complete (AND condition).
|
||||
|
||||
Reference in New Issue
Block a user