fix: handle unpickleable values in flow state
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

This commit is contained in:
Greyson LaLonde
2025-11-06 01:29:21 -05:00
committed by GitHub
parent 7e6171d5bc
commit e4cc9a664c
2 changed files with 135 additions and 11 deletions

View File

@@ -15,7 +15,6 @@ import logging
from typing import ( from typing import (
Any, Any,
ClassVar, ClassVar,
Final,
Generic, Generic,
Literal, Literal,
ParamSpec, ParamSpec,
@@ -45,7 +44,7 @@ from crewai.events.types.flow_events import (
MethodExecutionFinishedEvent, MethodExecutionFinishedEvent,
MethodExecutionStartedEvent, MethodExecutionStartedEvent,
) )
from crewai.flow.visualization import build_flow_structure, render_interactive from crewai.flow.constants import AND_CONDITION, OR_CONDITION
from crewai.flow.flow_wrappers import ( from crewai.flow.flow_wrappers import (
FlowCondition, FlowCondition,
FlowConditions, FlowConditions,
@@ -58,18 +57,16 @@ from crewai.flow.flow_wrappers import (
from crewai.flow.persistence.base import FlowPersistence from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData, FlowMethodName, PendingListenerKey from crewai.flow.types import FlowExecutionData, FlowMethodName, PendingListenerKey
from crewai.flow.utils import ( from crewai.flow.utils import (
_extract_all_methods,
_normalize_condition,
get_possible_return_constants, get_possible_return_constants,
is_flow_condition_dict, is_flow_condition_dict,
is_flow_condition_list,
is_flow_method, is_flow_method,
is_flow_method_callable, is_flow_method_callable,
is_flow_method_name, is_flow_method_name,
is_simple_flow_condition, is_simple_flow_condition,
_extract_all_methods,
_extract_all_methods_recursive,
_normalize_condition,
) )
from crewai.flow.constants import AND_CONDITION, OR_CONDITION from crewai.flow.visualization import build_flow_structure, render_interactive
from crewai.utilities.printer import Printer, PrinterColor from crewai.utilities.printer import Printer, PrinterColor
@@ -495,7 +492,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
or should_auto_collect_first_time_traces() or should_auto_collect_first_time_traces()
): ):
trace_listener = TraceCollectionListener() trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus) # type: ignore[no-untyped-call] trace_listener.setup_listeners(crewai_event_bus)
# Apply any additional kwargs # Apply any additional kwargs
if kwargs: if kwargs:
self._initialize_state(kwargs) self._initialize_state(kwargs)
@@ -601,7 +598,26 @@ class Flow(Generic[T], metaclass=FlowMeta):
) )
def _copy_state(self) -> T: def _copy_state(self) -> T:
"""Create a copy of the current state.
Returns:
A copy of the current state
"""
if isinstance(self._state, BaseModel):
try:
return self._state.model_copy(deep=True)
except (TypeError, AttributeError):
try:
state_dict = self._state.model_dump()
model_class = type(self._state)
return model_class(**state_dict)
except Exception:
return self._state.model_copy(deep=False)
else:
try:
return copy.deepcopy(self._state) return copy.deepcopy(self._state)
except (TypeError, AttributeError):
return cast(T, self._state.copy())
@property @property
def state(self) -> T: def state(self) -> T:
@@ -926,8 +942,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
trace_listener = TraceCollectionListener() trace_listener = TraceCollectionListener()
if trace_listener.batch_manager.batch_owner_type == "flow": if trace_listener.batch_manager.batch_owner_type == "flow":
if trace_listener.first_time_handler.is_first_time: if trace_listener.first_time_handler.is_first_time:
trace_listener.first_time_handler.mark_events_collected() # type: ignore[no-untyped-call] trace_listener.first_time_handler.mark_events_collected()
trace_listener.first_time_handler.handle_execution_completion() # type: ignore[no-untyped-call] trace_listener.first_time_handler.handle_execution_completion()
else: else:
trace_listener.batch_manager.finalize_batch() trace_listener.batch_manager.finalize_batch()

View File

@@ -3,6 +3,7 @@
import asyncio import asyncio
import threading import threading
from datetime import datetime from datetime import datetime
from typing import Optional
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel
@@ -1384,3 +1385,110 @@ def test_mixed_sync_async_execution_order():
] ]
assert execution_order == expected_order assert execution_order == expected_order
def test_flow_copy_state_with_unpickleable_objects():
"""Test that _copy_state handles unpickleable objects like RLock.
Regression test for issue #3828: Flow should not crash when state contains
objects that cannot be deep copied (like threading.RLock).
"""
class StateWithRLock(BaseModel):
counter: int = 0
lock: Optional[threading.RLock] = None
class FlowWithRLock(Flow[StateWithRLock]):
@start()
def step_1(self):
self.state.counter += 1
@listen(step_1)
def step_2(self):
self.state.counter += 1
flow = FlowWithRLock(initial_state=StateWithRLock())
flow._state.lock = threading.RLock()
copied_state = flow._copy_state()
assert copied_state.counter == 0
assert copied_state.lock is not None
def test_flow_copy_state_with_nested_unpickleable_objects():
"""Test that _copy_state handles unpickleable objects nested in containers.
Regression test for issue #3828: Verifies that unpickleable objects
nested inside dicts/lists in state don't cause crashes.
"""
class NestedState(BaseModel):
data: dict = {}
items: list = []
class FlowWithNestedUnpickleable(Flow[NestedState]):
@start()
def step_1(self):
self.state.data["lock"] = threading.RLock()
self.state.data["value"] = 42
@listen(step_1)
def step_2(self):
self.state.items.append(threading.Lock())
self.state.items.append("normal_value")
flow = FlowWithNestedUnpickleable(initial_state=NestedState())
flow.kickoff()
assert flow.state.data["value"] == 42
assert len(flow.state.items) == 2
def test_flow_copy_state_without_unpickleable_objects():
"""Test that _copy_state still works normally with pickleable objects.
Ensures that the fallback logic doesn't break normal deep copy behavior.
"""
class NormalState(BaseModel):
counter: int = 0
data: str = ""
nested: dict = {}
class NormalFlow(Flow[NormalState]):
@start()
def step_1(self):
self.state.counter = 5
self.state.data = "test"
self.state.nested = {"key": "value"}
flow = NormalFlow(initial_state=NormalState())
flow.state.counter = 10
flow.state.data = "modified"
flow.state.nested["key"] = "modified"
copied_state = flow._copy_state()
assert copied_state.counter == 10
assert copied_state.data == "modified"
assert copied_state.nested["key"] == "modified"
flow.state.nested["key"] = "changed_after_copy"
assert copied_state.nested["key"] == "modified"
def test_flow_copy_state_with_dict_state():
"""Test that _copy_state works with dict-based states."""
class DictFlow(Flow[dict]):
@start()
def step_1(self):
self.state["counter"] = 1
flow = DictFlow()
flow.state["test"] = "value"
copied_state = flow._copy_state()
assert copied_state["test"] == "value"
flow.state["test"] = "modified"
assert copied_state["test"] == "value"