mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 08:12:39 +00:00
fix: handle unpickleable values in flow state
Some checks failed
Some checks failed
This commit is contained in:
@@ -15,7 +15,6 @@ import logging
|
|||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Final,
|
|
||||||
Generic,
|
Generic,
|
||||||
Literal,
|
Literal,
|
||||||
ParamSpec,
|
ParamSpec,
|
||||||
@@ -45,7 +44,7 @@ from crewai.events.types.flow_events import (
|
|||||||
MethodExecutionFinishedEvent,
|
MethodExecutionFinishedEvent,
|
||||||
MethodExecutionStartedEvent,
|
MethodExecutionStartedEvent,
|
||||||
)
|
)
|
||||||
from crewai.flow.visualization import build_flow_structure, render_interactive
|
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
|
||||||
from crewai.flow.flow_wrappers import (
|
from crewai.flow.flow_wrappers import (
|
||||||
FlowCondition,
|
FlowCondition,
|
||||||
FlowConditions,
|
FlowConditions,
|
||||||
@@ -58,18 +57,16 @@ from crewai.flow.flow_wrappers import (
|
|||||||
from crewai.flow.persistence.base import FlowPersistence
|
from crewai.flow.persistence.base import FlowPersistence
|
||||||
from crewai.flow.types import FlowExecutionData, FlowMethodName, PendingListenerKey
|
from crewai.flow.types import FlowExecutionData, FlowMethodName, PendingListenerKey
|
||||||
from crewai.flow.utils import (
|
from crewai.flow.utils import (
|
||||||
|
_extract_all_methods,
|
||||||
|
_normalize_condition,
|
||||||
get_possible_return_constants,
|
get_possible_return_constants,
|
||||||
is_flow_condition_dict,
|
is_flow_condition_dict,
|
||||||
is_flow_condition_list,
|
|
||||||
is_flow_method,
|
is_flow_method,
|
||||||
is_flow_method_callable,
|
is_flow_method_callable,
|
||||||
is_flow_method_name,
|
is_flow_method_name,
|
||||||
is_simple_flow_condition,
|
is_simple_flow_condition,
|
||||||
_extract_all_methods,
|
|
||||||
_extract_all_methods_recursive,
|
|
||||||
_normalize_condition,
|
|
||||||
)
|
)
|
||||||
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
|
from crewai.flow.visualization import build_flow_structure, render_interactive
|
||||||
from crewai.utilities.printer import Printer, PrinterColor
|
from crewai.utilities.printer import Printer, PrinterColor
|
||||||
|
|
||||||
|
|
||||||
@@ -495,7 +492,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
or should_auto_collect_first_time_traces()
|
or should_auto_collect_first_time_traces()
|
||||||
):
|
):
|
||||||
trace_listener = TraceCollectionListener()
|
trace_listener = TraceCollectionListener()
|
||||||
trace_listener.setup_listeners(crewai_event_bus) # type: ignore[no-untyped-call]
|
trace_listener.setup_listeners(crewai_event_bus)
|
||||||
# Apply any additional kwargs
|
# Apply any additional kwargs
|
||||||
if kwargs:
|
if kwargs:
|
||||||
self._initialize_state(kwargs)
|
self._initialize_state(kwargs)
|
||||||
@@ -601,7 +598,26 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _copy_state(self) -> T:
|
def _copy_state(self) -> T:
|
||||||
|
"""Create a copy of the current state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A copy of the current state
|
||||||
|
"""
|
||||||
|
if isinstance(self._state, BaseModel):
|
||||||
|
try:
|
||||||
|
return self._state.model_copy(deep=True)
|
||||||
|
except (TypeError, AttributeError):
|
||||||
|
try:
|
||||||
|
state_dict = self._state.model_dump()
|
||||||
|
model_class = type(self._state)
|
||||||
|
return model_class(**state_dict)
|
||||||
|
except Exception:
|
||||||
|
return self._state.model_copy(deep=False)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
return copy.deepcopy(self._state)
|
return copy.deepcopy(self._state)
|
||||||
|
except (TypeError, AttributeError):
|
||||||
|
return cast(T, self._state.copy())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self) -> T:
|
def state(self) -> T:
|
||||||
@@ -926,8 +942,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
trace_listener = TraceCollectionListener()
|
trace_listener = TraceCollectionListener()
|
||||||
if trace_listener.batch_manager.batch_owner_type == "flow":
|
if trace_listener.batch_manager.batch_owner_type == "flow":
|
||||||
if trace_listener.first_time_handler.is_first_time:
|
if trace_listener.first_time_handler.is_first_time:
|
||||||
trace_listener.first_time_handler.mark_events_collected() # type: ignore[no-untyped-call]
|
trace_listener.first_time_handler.mark_events_collected()
|
||||||
trace_listener.first_time_handler.handle_execution_completion() # type: ignore[no-untyped-call]
|
trace_listener.first_time_handler.handle_execution_completion()
|
||||||
else:
|
else:
|
||||||
trace_listener.batch_manager.finalize_batch()
|
trace_listener.batch_manager.finalize_batch()
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -1384,3 +1385,110 @@ def test_mixed_sync_async_execution_order():
|
|||||||
]
|
]
|
||||||
|
|
||||||
assert execution_order == expected_order
|
assert execution_order == expected_order
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_copy_state_with_unpickleable_objects():
|
||||||
|
"""Test that _copy_state handles unpickleable objects like RLock.
|
||||||
|
|
||||||
|
Regression test for issue #3828: Flow should not crash when state contains
|
||||||
|
objects that cannot be deep copied (like threading.RLock).
|
||||||
|
"""
|
||||||
|
|
||||||
|
class StateWithRLock(BaseModel):
|
||||||
|
counter: int = 0
|
||||||
|
lock: Optional[threading.RLock] = None
|
||||||
|
|
||||||
|
class FlowWithRLock(Flow[StateWithRLock]):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
self.state.counter += 1
|
||||||
|
|
||||||
|
@listen(step_1)
|
||||||
|
def step_2(self):
|
||||||
|
self.state.counter += 1
|
||||||
|
|
||||||
|
flow = FlowWithRLock(initial_state=StateWithRLock())
|
||||||
|
flow._state.lock = threading.RLock()
|
||||||
|
|
||||||
|
copied_state = flow._copy_state()
|
||||||
|
assert copied_state.counter == 0
|
||||||
|
assert copied_state.lock is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_copy_state_with_nested_unpickleable_objects():
|
||||||
|
"""Test that _copy_state handles unpickleable objects nested in containers.
|
||||||
|
|
||||||
|
Regression test for issue #3828: Verifies that unpickleable objects
|
||||||
|
nested inside dicts/lists in state don't cause crashes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class NestedState(BaseModel):
|
||||||
|
data: dict = {}
|
||||||
|
items: list = []
|
||||||
|
|
||||||
|
class FlowWithNestedUnpickleable(Flow[NestedState]):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
self.state.data["lock"] = threading.RLock()
|
||||||
|
self.state.data["value"] = 42
|
||||||
|
|
||||||
|
@listen(step_1)
|
||||||
|
def step_2(self):
|
||||||
|
self.state.items.append(threading.Lock())
|
||||||
|
self.state.items.append("normal_value")
|
||||||
|
|
||||||
|
flow = FlowWithNestedUnpickleable(initial_state=NestedState())
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert flow.state.data["value"] == 42
|
||||||
|
assert len(flow.state.items) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_copy_state_without_unpickleable_objects():
|
||||||
|
"""Test that _copy_state still works normally with pickleable objects.
|
||||||
|
|
||||||
|
Ensures that the fallback logic doesn't break normal deep copy behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class NormalState(BaseModel):
|
||||||
|
counter: int = 0
|
||||||
|
data: str = ""
|
||||||
|
nested: dict = {}
|
||||||
|
|
||||||
|
class NormalFlow(Flow[NormalState]):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
self.state.counter = 5
|
||||||
|
self.state.data = "test"
|
||||||
|
self.state.nested = {"key": "value"}
|
||||||
|
|
||||||
|
flow = NormalFlow(initial_state=NormalState())
|
||||||
|
flow.state.counter = 10
|
||||||
|
flow.state.data = "modified"
|
||||||
|
flow.state.nested["key"] = "modified"
|
||||||
|
|
||||||
|
copied_state = flow._copy_state()
|
||||||
|
assert copied_state.counter == 10
|
||||||
|
assert copied_state.data == "modified"
|
||||||
|
assert copied_state.nested["key"] == "modified"
|
||||||
|
|
||||||
|
flow.state.nested["key"] = "changed_after_copy"
|
||||||
|
assert copied_state.nested["key"] == "modified"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_copy_state_with_dict_state():
|
||||||
|
"""Test that _copy_state works with dict-based states."""
|
||||||
|
|
||||||
|
class DictFlow(Flow[dict]):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
self.state["counter"] = 1
|
||||||
|
|
||||||
|
flow = DictFlow()
|
||||||
|
flow.state["test"] = "value"
|
||||||
|
|
||||||
|
copied_state = flow._copy_state()
|
||||||
|
assert copied_state["test"] == "value"
|
||||||
|
|
||||||
|
flow.state["test"] = "modified"
|
||||||
|
assert copied_state["test"] == "value"
|
||||||
|
|||||||
Reference in New Issue
Block a user