mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
refactor: Improve Flow state serialization
- Add BaseStateEvent class for common state processing
- Add state serialization caching for performance (sketched below)
- Add tests for nested locks and async context
- Improve error handling and validation
- Enhance documentation

Co-Authored-By: Joe Moura <joao@crewai.com>
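For illustration, the hash-keyed caching this commit adds to `_serialize_state` can be sketched in isolation. A minimal sketch, assuming a standalone wrapper class: the attribute names `_last_state_hash` and `_last_serialized_state` mirror the diff, but `StateSerializer` itself is hypothetical and not part of crewAI.

    from typing import Any, Dict, Optional, Union

    from pydantic import BaseModel


    class StateSerializer:
        """Illustrative hash-keyed cache for serialized flow state (not the crewAI API)."""

        def __init__(self) -> None:
            self._last_state_hash: Optional[int] = None
            self._last_serialized_state: Optional[Union[Dict[str, Any], BaseModel]] = None

        def serialize(self, state: Union[Dict[str, Any], BaseModel]) -> Union[Dict[str, Any], BaseModel]:
            if not isinstance(state, (dict, BaseModel)):
                raise ValueError(f"Invalid state type: {type(state)}")

            # Cheap change detection: if the stringified state hashes the same as
            # last time, reuse the previously serialized copy instead of rebuilding it.
            current_hash = hash(str(state))
            if current_hash == self._last_state_hash and self._last_serialized_state is not None:
                return self._last_serialized_state

            # BaseModel states are rebuilt from model_dump(); dict states get a shallow copy.
            serialized = (
                type(state)(**state.model_dump())
                if isinstance(state, BaseModel)
                else dict(state)
            )
            self._last_state_hash = current_hash
            self._last_serialized_state = serialized
            return serialized

Note the trade-off this approach accepts: `hash(str(state))` only detects changes that alter the string form of the state, and a cache hit returns the same object by reference, so a caller that mutates the returned copy would see that mutation again on later hits.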
@@ -580,7 +580,51 @@ class Flow(Generic[T], metaclass=FlowMeta):
         )
 
     def _copy_state(self) -> T:
+        """Create a deep copy of the current state.
+
+        Returns:
+            A deep copy of the current state object
+        """
         return copy.deepcopy(self._state)
 
+    def _serialize_state(self) -> Union[Dict[str, Any], BaseModel]:
+        """Serialize the current state for event emission.
+
+        This method handles the serialization of both BaseModel and dictionary states,
+        ensuring thread-safe copying of state data. Uses caching to improve performance
+        when state hasn't changed.
+
+        Returns:
+            Serialized state as either a new BaseModel instance or dictionary
+
+        Raises:
+            ValueError: If state has invalid type
+            Exception: If serialization fails, logs error and returns empty dict
+        """
+        try:
+            if not isinstance(self._state, (dict, BaseModel)):
+                raise ValueError(f"Invalid state type: {type(self._state)}")
+
+            if not hasattr(self, '_last_state_hash'):
+                self._last_state_hash = None
+                self._last_serialized_state = None
+
+            current_hash = hash(str(self._state))
+            if current_hash == self._last_state_hash:
+                return self._last_serialized_state
+
+            serialized = (
+                type(self._state)(**self._state.model_dump())
+                if isinstance(self._state, BaseModel)
+                else dict(self._state)
+            )
+
+            self._last_state_hash = current_hash
+            self._last_serialized_state = serialized
+            return serialized
+        except Exception as e:
+            logger.error(f"State serialization failed: {str(e)}")
+            return {}
+
     @property
     def state(self) -> T:
@@ -820,11 +864,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
     ) -> Any:
         try:
             # Serialize state before event emission to avoid pickling issues
-            state_copy = (
-                type(self._state)(**self._state.model_dump())
-                if isinstance(self._state, BaseModel)
-                else dict(self._state)
-            )
+            state_copy = self._serialize_state()
 
             dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (kwargs or {})
             crewai_event_bus.emit(
@@ -849,12 +889,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
                 self._method_execution_counts.get(method_name, 0) + 1
             )
 
-            # Serialize state after execution to avoid pickling issues
-            state_copy = (
-                type(self._state)(**self._state.model_dump())
-                if isinstance(self._state, BaseModel)
-                else dict(self._state)
-            )
+            # Serialize state after execution
+            state_copy = self._serialize_state()
 
             crewai_event_bus.emit(
                 self,
@@ -867,13 +903,83 @@ class Flow(Generic[T], metaclass=FlowMeta):
                 ),
             )
 
+            return result
+        except Exception as e:
+            crewai_event_bus.emit(
+                self,
+                MethodExecutionFailedEvent(
+                    type="method_execution_failed",
+                    method_name=method_name,
+                    flow_name=self.__class__.__name__,
+                    error=e,
+                ),
+            )
+            raise e
+
+            dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (kwargs or {})
+            crewai_event_bus.emit(
+                self,
+                MethodExecutionStartedEvent(
+                    type="method_execution_started",
+                    method_name=method_name,
+                    flow_name=self.__class__.__name__,
+                    params=dumped_params,
+                    state=state_copy,
+                ),
+            )
+
+<<<<<<< HEAD
+            result = (
+                await method(*args, **kwargs)
+                if asyncio.iscoroutinefunction(method)
+                else method(*args, **kwargs)
+            )
+||||||| parent of ed877467 (refactor: Improve Flow state serialization)
+            # Serialize state after execution
+            state_copy = (
+                type(self._state)(**self._state.model_dump())
+                if isinstance(self._state, BaseModel)
+                else dict(self._state)
+            )
+
+            self.event_emitter.send(
+                self,
+                event=MethodExecutionFinishedEvent(
+                    type="method_execution_finished",
+                    method_name=method_name,
+                    flow_name=self.__class__.__name__,
+                    state=state_copy,
+                    result=result,
+                ),
+            )
+=======
+            # Serialize state after execution
+            state_copy = self._serialize_state()
+
+            self.event_emitter.send(
+                self,
+                event=MethodExecutionFinishedEvent(
+                    type="method_execution_finished",
+                    method_name=method_name,
+                    flow_name=self.__class__.__name__,
+                    state=state_copy,
+                    result=result,
+                ),
+            )
+>>>>>>> ed877467 (refactor: Improve Flow state serialization)
+
+            self._method_outputs.append(result)
+            self._method_execution_counts[method_name] = (
+                self._method_execution_counts.get(method_name, 0) + 1
+            )
+
             crewai_event_bus.emit(
                 self,
                 MethodExecutionFinishedEvent(
                     type="method_execution_finished",
                     method_name=method_name,
                     flow_name=self.__class__.__name__,
-                    state=self._copy_state(),
+                    state=state_copy,
                     result=result,
                 ),
             )
@@ -379,6 +379,72 @@ def test_flow_with_thread_lock():
     assert flow.counter == 2
 
 
+def test_flow_with_nested_locks():
+    """Test that Flow properly handles nested thread locks."""
+    import threading
+
+    class NestedLockFlow(Flow):
+        def __init__(self):
+            super().__init__()
+            self.outer_lock = threading.RLock()
+            self.inner_lock = threading.RLock()
+            self.counter = 0
+
+        @start()
+        async def step_1(self):
+            with self.outer_lock:
+                with self.inner_lock:
+                    self.counter += 1
+            return "step 1"
+
+        @listen(step_1)
+        async def step_2(self, result):
+            with self.outer_lock:
+                with self.inner_lock:
+                    self.counter += 1
+            return result + " -> step 2"
+
+    flow = NestedLockFlow()
+    result = flow.kickoff()
+
+    assert result == "step 1 -> step 2"
+    assert flow.counter == 2
+
+
+@pytest.mark.asyncio
+async def test_flow_with_async_locks():
+    """Test that Flow properly handles locks in async context."""
+    import asyncio
+    import threading
+
+    class AsyncLockFlow(Flow):
+        def __init__(self):
+            super().__init__()
+            self.lock = threading.RLock()
+            self.async_lock = asyncio.Lock()
+            self.counter = 0
+
+        @start()
+        async def step_1(self):
+            async with self.async_lock:
+                with self.lock:
+                    self.counter += 1
+            return "step 1"
+
+        @listen(step_1)
+        async def step_2(self, result):
+            async with self.async_lock:
+                with self.lock:
+                    self.counter += 1
+            return result + " -> step 2"
+
+    flow = AsyncLockFlow()
+    result = await flow.kickoff_async()
+
+    assert result == "step 1 -> step 2"
+    assert flow.counter == 2
+
+
 def test_router_with_multiple_conditions():
     """Test a router that triggers when any of multiple steps complete (OR condition),
     and another router that triggers only after all specified steps complete (AND condition).
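As background for the "avoid pickling issues" comments in the diff, the new lock tests point at the same constraint: `copy.deepcopy` falls back to pickling, and thread locks cannot be pickled, whereas the shallow `dict()` / `model_dump()` copies used by `_serialize_state` never touch the lock. A minimal standalone illustration, not part of the commit:

    import copy
    import threading

    state = {"counter": 1, "lock": threading.RLock()}

    try:
        # deepcopy has to reduce (pickle) the RLock, which is not allowed
        copy.deepcopy(state)
    except TypeError as exc:
        print(f"deepcopy failed: {exc}")  # e.g. "cannot pickle '_thread.RLock' object"

    # A shallow copy just re-references the existing values, so it succeeds
    shallow = dict(state)
    print(shallow["counter"])  # 1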