fix: handle unpickleable values in flow state
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

This commit is contained in:
Greyson LaLonde
2025-11-06 01:29:21 -05:00
committed by GitHub
parent 7e6171d5bc
commit e4cc9a664c
2 changed files with 135 additions and 11 deletions

View File

@@ -15,7 +15,6 @@ import logging
from typing import (
Any,
ClassVar,
Final,
Generic,
Literal,
ParamSpec,
@@ -45,7 +44,7 @@ from crewai.events.types.flow_events import (
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.flow.visualization import build_flow_structure, render_interactive
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
from crewai.flow.flow_wrappers import (
FlowCondition,
FlowConditions,
@@ -58,18 +57,16 @@ from crewai.flow.flow_wrappers import (
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData, FlowMethodName, PendingListenerKey
from crewai.flow.utils import (
_extract_all_methods,
_normalize_condition,
get_possible_return_constants,
is_flow_condition_dict,
is_flow_condition_list,
is_flow_method,
is_flow_method_callable,
is_flow_method_name,
is_simple_flow_condition,
_extract_all_methods,
_extract_all_methods_recursive,
_normalize_condition,
)
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
from crewai.flow.visualization import build_flow_structure, render_interactive
from crewai.utilities.printer import Printer, PrinterColor
@@ -495,7 +492,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
or should_auto_collect_first_time_traces()
):
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus) # type: ignore[no-untyped-call]
trace_listener.setup_listeners(crewai_event_bus)
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)
@@ -601,7 +598,26 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
def _copy_state(self) -> T:
return copy.deepcopy(self._state)
"""Create a copy of the current state.
Returns:
A copy of the current state
"""
if isinstance(self._state, BaseModel):
try:
return self._state.model_copy(deep=True)
except (TypeError, AttributeError):
try:
state_dict = self._state.model_dump()
model_class = type(self._state)
return model_class(**state_dict)
except Exception:
return self._state.model_copy(deep=False)
else:
try:
return copy.deepcopy(self._state)
except (TypeError, AttributeError):
return cast(T, self._state.copy())
@property
def state(self) -> T:
@@ -926,8 +942,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
trace_listener = TraceCollectionListener()
if trace_listener.batch_manager.batch_owner_type == "flow":
if trace_listener.first_time_handler.is_first_time:
trace_listener.first_time_handler.mark_events_collected() # type: ignore[no-untyped-call]
trace_listener.first_time_handler.handle_execution_completion() # type: ignore[no-untyped-call]
trace_listener.first_time_handler.mark_events_collected()
trace_listener.first_time_handler.handle_execution_completion()
else:
trace_listener.batch_manager.finalize_batch()

View File

@@ -3,6 +3,7 @@
import asyncio
import threading
from datetime import datetime
from typing import Optional
import pytest
from pydantic import BaseModel
@@ -1384,3 +1385,110 @@ def test_mixed_sync_async_execution_order():
]
assert execution_order == expected_order
def test_flow_copy_state_with_unpickleable_objects():
"""Test that _copy_state handles unpickleable objects like RLock.
Regression test for issue #3828: Flow should not crash when state contains
objects that cannot be deep copied (like threading.RLock).
"""
class StateWithRLock(BaseModel):
counter: int = 0
lock: Optional[threading.RLock] = None
class FlowWithRLock(Flow[StateWithRLock]):
@start()
def step_1(self):
self.state.counter += 1
@listen(step_1)
def step_2(self):
self.state.counter += 1
flow = FlowWithRLock(initial_state=StateWithRLock())
flow._state.lock = threading.RLock()
copied_state = flow._copy_state()
assert copied_state.counter == 0
assert copied_state.lock is not None
def test_flow_copy_state_with_nested_unpickleable_objects():
"""Test that _copy_state handles unpickleable objects nested in containers.
Regression test for issue #3828: Verifies that unpickleable objects
nested inside dicts/lists in state don't cause crashes.
"""
class NestedState(BaseModel):
data: dict = {}
items: list = []
class FlowWithNestedUnpickleable(Flow[NestedState]):
@start()
def step_1(self):
self.state.data["lock"] = threading.RLock()
self.state.data["value"] = 42
@listen(step_1)
def step_2(self):
self.state.items.append(threading.Lock())
self.state.items.append("normal_value")
flow = FlowWithNestedUnpickleable(initial_state=NestedState())
flow.kickoff()
assert flow.state.data["value"] == 42
assert len(flow.state.items) == 2
def test_flow_copy_state_without_unpickleable_objects():
"""Test that _copy_state still works normally with pickleable objects.
Ensures that the fallback logic doesn't break normal deep copy behavior.
"""
class NormalState(BaseModel):
counter: int = 0
data: str = ""
nested: dict = {}
class NormalFlow(Flow[NormalState]):
@start()
def step_1(self):
self.state.counter = 5
self.state.data = "test"
self.state.nested = {"key": "value"}
flow = NormalFlow(initial_state=NormalState())
flow.state.counter = 10
flow.state.data = "modified"
flow.state.nested["key"] = "modified"
copied_state = flow._copy_state()
assert copied_state.counter == 10
assert copied_state.data == "modified"
assert copied_state.nested["key"] == "modified"
flow.state.nested["key"] = "changed_after_copy"
assert copied_state.nested["key"] == "modified"
def test_flow_copy_state_with_dict_state():
"""Test that _copy_state works with dict-based states."""
class DictFlow(Flow[dict]):
@start()
def step_1(self):
self.state["counter"] = 1
flow = DictFlow()
flow.state["test"] = "value"
copied_state = flow._copy_state()
assert copied_state["test"] == "value"
flow.state["test"] = "modified"
assert copied_state["test"] == "value"