mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 14:09:24 +00:00
Remove StateProxy from flow state access (#6327)
`StateProxy` looked like a thread-safety boundary, but it only protected a small slice of state operations. Some examples of operations that were not covered:

- `self.state.counter += 1`, `self.state["counter"] += 1` (increments: the read and the write are separate steps, and only the write took the lock)
- `self.state.user.profile.score += 1` (nested object mutations)
- `self.state.config["limits"]["max"] = 10` (mutation through model fields)
- `self.state.items[0].status = "done"` (list/container mutations)

This commit removes it completely, for simplicity and performance:

- Simpler runtime code
- Faster state access in a local benchmark: attribute reads 24x faster, attribute writes 27x faster, list appends 19x faster
- Clearer concurrency contract: lifecycle locks remain, but arbitrary shared-state mutation is no longer presented as thread-safe
This commit is contained in:
@@ -54,7 +54,7 @@ from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow import Flow, StateProxy, listen, or_, router, start
|
||||
from crewai.flow.flow import Flow, listen, or_, router, start
|
||||
from crewai.flow.types import FlowMethodName
|
||||
from crewai.hooks.llm_hooks import (
|
||||
get_after_llm_call_hooks,
|
||||
@@ -276,11 +276,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
"""
|
||||
return self.llm.supports_stop_words() if self.llm else False
|
||||
|
||||
@property
|
||||
def state(self) -> AgentExecutorState:
|
||||
"""Get thread-safe state proxy."""
|
||||
return StateProxy(self._state, self._state_lock) # type: ignore[return-value]
|
||||
|
||||
@property # type: ignore[misc]
|
||||
def iterations(self) -> int:
|
||||
"""Compatibility property for mixin - returns state iterations."""
|
||||
|
||||
@@ -24,9 +24,6 @@ from crewai.flow.runtime import (
|
||||
Flow as RuntimeFlow,
|
||||
FlowMeta,
|
||||
FlowState,
|
||||
LockedDictProxy,
|
||||
LockedListProxy,
|
||||
StateProxy,
|
||||
)
|
||||
|
||||
|
||||
@@ -42,9 +39,6 @@ __all__ = [
|
||||
"Flow",
|
||||
"FlowMeta",
|
||||
"FlowState",
|
||||
"LockedDictProxy",
|
||||
"LockedListProxy",
|
||||
"StateProxy",
|
||||
"and_",
|
||||
"listen",
|
||||
"or_",
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Flow Runtime: the engine that executes a Flow.
|
||||
|
||||
Provides the ``Flow`` class (kickoff/resume/listener dispatch), the
|
||||
``FlowMeta`` metaclass, and the thread-safe state proxies. Flows
|
||||
authored with the Python DSL (see ``dsl``) are described by a Flow
|
||||
Provides the ``Flow`` class (kickoff/resume/listener dispatch) and the
|
||||
``FlowMeta`` metaclass. Flows authored with the Python DSL (see ``dsl``)
|
||||
are described by a Flow
|
||||
Structure (see ``flow_definition``) and executed here.
|
||||
"""
|
||||
|
||||
@@ -11,12 +11,8 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections.abc import (
|
||||
Callable,
|
||||
ItemsView,
|
||||
Iterable,
|
||||
Iterator,
|
||||
KeysView,
|
||||
Sequence,
|
||||
ValuesView,
|
||||
)
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
import contextvars
|
||||
@@ -35,10 +31,8 @@ from typing import (
|
||||
Generic,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
SupportsIndex,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -383,304 +377,6 @@ R = TypeVar("R")
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
||||
"""Thread-safe proxy for list operations.
|
||||
|
||||
Subclasses ``list`` so that ``isinstance(proxy, list)`` returns True,
|
||||
which is required by libraries like LanceDB and Pydantic that do strict
|
||||
type checks. All mutations go through the lock; reads delegate to the
|
||||
underlying list.
|
||||
"""
|
||||
|
||||
def __init__(self, lst: list[T], lock: threading.Lock) -> None:
|
||||
super().__init__() # empty builtin list; all access goes through self._list
|
||||
self._list = lst
|
||||
self._lock = lock
|
||||
|
||||
def append(self, item: T) -> None:
|
||||
with self._lock:
|
||||
self._list.append(item)
|
||||
|
||||
def extend(self, items: Iterable[T]) -> None:
|
||||
with self._lock:
|
||||
self._list.extend(items)
|
||||
|
||||
def insert(self, index: SupportsIndex, item: T) -> None:
|
||||
with self._lock:
|
||||
self._list.insert(index, item)
|
||||
|
||||
def remove(self, item: T) -> None:
|
||||
with self._lock:
|
||||
self._list.remove(item)
|
||||
|
||||
def pop(self, index: SupportsIndex = -1) -> T:
|
||||
with self._lock:
|
||||
return self._list.pop(index)
|
||||
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
self._list.clear()
|
||||
|
||||
@overload
|
||||
def __setitem__(self, index: SupportsIndex, value: T) -> None: ...
|
||||
@overload
|
||||
def __setitem__(self, index: slice, value: Iterable[T]) -> None: ...
|
||||
def __setitem__(self, index: Any, value: Any) -> None:
|
||||
with self._lock:
|
||||
self._list[index] = value
|
||||
|
||||
def __delitem__(self, index: SupportsIndex | slice) -> None:
|
||||
with self._lock:
|
||||
del self._list[index]
|
||||
|
||||
@overload
|
||||
def __getitem__(self, index: SupportsIndex) -> T: ...
|
||||
@overload
|
||||
def __getitem__(self, index: slice) -> list[T]: ...
|
||||
def __getitem__(self, index: Any) -> Any:
|
||||
return self._list[index]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._list)
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
return iter(self._list)
|
||||
|
||||
def __contains__(self, item: object) -> bool:
|
||||
return item in self._list
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(self._list)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._list)
|
||||
|
||||
def index(
|
||||
self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None
|
||||
) -> int:
|
||||
if stop is None:
|
||||
return self._list.index(value, start)
|
||||
return self._list.index(value, start, stop)
|
||||
|
||||
def count(self, value: T) -> int:
|
||||
return self._list.count(value)
|
||||
|
||||
def sort(self, *, key: Any = None, reverse: bool = False) -> None:
|
||||
with self._lock:
|
||||
self._list.sort(key=key, reverse=reverse)
|
||||
|
||||
def reverse(self) -> None:
|
||||
with self._lock:
|
||||
self._list.reverse()
|
||||
|
||||
def copy(self) -> list[T]:
|
||||
return self._list.copy()
|
||||
|
||||
def __add__(self, other: list[T]) -> list[T]: # type: ignore[override]
|
||||
return self._list + other
|
||||
|
||||
def __radd__(self, other: list[T]) -> list[T]:
|
||||
return other + self._list
|
||||
|
||||
def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]: # type: ignore[override]
|
||||
with self._lock:
|
||||
self._list += list(other)
|
||||
return self
|
||||
|
||||
def __mul__(self, n: SupportsIndex) -> list[T]:
|
||||
return self._list * n
|
||||
|
||||
def __rmul__(self, n: SupportsIndex) -> list[T]:
|
||||
return self._list * n
|
||||
|
||||
def __imul__(self, n: SupportsIndex) -> LockedListProxy[T]:
|
||||
with self._lock:
|
||||
self._list *= n
|
||||
return self
|
||||
|
||||
def __reversed__(self) -> Iterator[T]:
|
||||
return reversed(self._list)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare based on the underlying list contents."""
|
||||
if isinstance(other, LockedListProxy):
|
||||
# Avoid deadlocks by acquiring locks in a consistent order.
|
||||
first, second = (self, other) if id(self) <= id(other) else (other, self)
|
||||
with first._lock:
|
||||
with second._lock:
|
||||
return first._list == second._list
|
||||
with self._lock:
|
||||
return self._list == other
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
|
||||
class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
|
||||
"""Thread-safe proxy for dict operations.
|
||||
|
||||
Subclasses ``dict`` so that ``isinstance(proxy, dict)`` returns True,
|
||||
which is required by libraries like Pydantic that do strict type checks.
|
||||
All mutations go through the lock; reads delegate to the underlying dict.
|
||||
"""
|
||||
|
||||
def __init__(self, d: dict[str, T], lock: threading.Lock) -> None:
|
||||
super().__init__() # empty builtin dict; all access goes through self._dict
|
||||
self._dict = d
|
||||
self._lock = lock
|
||||
|
||||
def __setitem__(self, key: str, value: T) -> None:
|
||||
with self._lock:
|
||||
self._dict[key] = value
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
with self._lock:
|
||||
del self._dict[key]
|
||||
|
||||
def pop(self, key: str, *default: T) -> T: # type: ignore[override]
|
||||
with self._lock:
|
||||
return self._dict.pop(key, *default)
|
||||
|
||||
def update(self, other: dict[str, T]) -> None: # type: ignore[override]
|
||||
with self._lock:
|
||||
self._dict.update(other)
|
||||
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
self._dict.clear()
|
||||
|
||||
def setdefault(self, key: str, default: T) -> T: # type: ignore[override]
|
||||
with self._lock:
|
||||
return self._dict.setdefault(key, default)
|
||||
|
||||
def __getitem__(self, key: str) -> T:
|
||||
return self._dict[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._dict)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._dict)
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
return key in self._dict
|
||||
|
||||
def keys(self) -> KeysView[str]: # type: ignore[override]
|
||||
return self._dict.keys()
|
||||
|
||||
def values(self) -> ValuesView[T]: # type: ignore[override]
|
||||
return self._dict.values()
|
||||
|
||||
def items(self) -> ItemsView[str, T]: # type: ignore[override]
|
||||
return self._dict.items()
|
||||
|
||||
def get(self, key: str, default: T | None = None) -> T | None: # type: ignore[override]
|
||||
return self._dict.get(key, default)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(self._dict)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._dict)
|
||||
|
||||
def copy(self) -> dict[str, T]:
|
||||
return self._dict.copy()
|
||||
|
||||
def __or__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override]
|
||||
return self._dict | other
|
||||
|
||||
def __ror__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override]
|
||||
return other | self._dict
|
||||
|
||||
def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]: # type: ignore[override]
|
||||
with self._lock:
|
||||
self._dict |= other
|
||||
return self
|
||||
|
||||
def __reversed__(self) -> Iterator[str]:
|
||||
return reversed(self._dict)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare based on the underlying dict contents."""
|
||||
if isinstance(other, LockedDictProxy):
|
||||
# Avoid deadlocks by acquiring locks in a consistent order.
|
||||
first, second = (self, other) if id(self) <= id(other) else (other, self)
|
||||
with first._lock:
|
||||
with second._lock:
|
||||
return first._dict == second._dict
|
||||
with self._lock:
|
||||
return self._dict == other
|
||||
|
||||
def __ne__(self, other: object) -> bool:
|
||||
return not self.__eq__(other)
|
||||
|
||||
|
||||
class StateProxy(Generic[T]):
|
||||
"""Proxy that provides thread-safe access to flow state.
|
||||
|
||||
Wraps state objects (dict or BaseModel) and uses a lock for all write
|
||||
operations to prevent race conditions when parallel listeners modify state.
|
||||
"""
|
||||
|
||||
__slots__ = ("_proxy_lock", "_proxy_state")
|
||||
|
||||
def __init__(self, state: T, lock: threading.Lock) -> None:
|
||||
object.__setattr__(self, "_proxy_state", state)
|
||||
object.__setattr__(self, "_proxy_lock", lock)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
value = getattr(object.__getattribute__(self, "_proxy_state"), name)
|
||||
lock = object.__getattribute__(self, "_proxy_lock")
|
||||
if isinstance(value, list):
|
||||
return LockedListProxy(value, lock)
|
||||
if isinstance(value, dict):
|
||||
return LockedDictProxy(value, lock)
|
||||
return value
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in ("_proxy_state", "_proxy_lock"):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
if isinstance(value, LockedListProxy):
|
||||
value = value._list
|
||||
elif isinstance(value, LockedDictProxy):
|
||||
value = value._dict
|
||||
with object.__getattribute__(self, "_proxy_lock"):
|
||||
setattr(object.__getattribute__(self, "_proxy_state"), name, value)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return object.__getattribute__(self, "_proxy_state")[key]
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
with object.__getattribute__(self, "_proxy_lock"):
|
||||
object.__getattribute__(self, "_proxy_state")[key] = value
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
with object.__getattribute__(self, "_proxy_lock"):
|
||||
del object.__getattribute__(self, "_proxy_state")[key]
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in object.__getattribute__(self, "_proxy_state")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(object.__getattribute__(self, "_proxy_state"))
|
||||
|
||||
def _unwrap(self) -> T:
|
||||
"""Return the underlying state object."""
|
||||
return cast(T, object.__getattribute__(self, "_proxy_state"))
|
||||
|
||||
def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
|
||||
"""Return state as a dictionary.
|
||||
|
||||
Works for both dict and BaseModel underlying states.
|
||||
"""
|
||||
state = object.__getattribute__(self, "_proxy_state")
|
||||
if isinstance(state, dict):
|
||||
return state
|
||||
result: dict[str, Any] = state.model_dump(*args, **kwargs)
|
||||
return result
|
||||
|
||||
|
||||
class FlowMeta(ModelMetaclass):
|
||||
def __new__(
|
||||
mcs,
|
||||
@@ -1025,7 +721,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
_method_outputs: list[Any] = PrivateAttr(default_factory=list)
|
||||
_definition: FlowDefinition = PrivateAttr()
|
||||
_state_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
_or_listeners_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
_completed_methods: set[FlowMethodName] = PrivateAttr(default_factory=set)
|
||||
_method_call_counts: dict[FlowMethodName, int] = PrivateAttr(default_factory=dict)
|
||||
@@ -1947,7 +1642,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
@property
|
||||
def state(self) -> T:
|
||||
return StateProxy(self._state, self._state_lock) # type: ignore[return-value]
|
||||
return cast(T, self._state)
|
||||
|
||||
@property
|
||||
def method_outputs(self) -> list[Any]:
|
||||
|
||||
@@ -7,6 +7,7 @@ flow methods, routing logic, and error handling.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
import time
|
||||
from typing import Any
|
||||
@@ -39,8 +40,6 @@ def _build_executor(**kwargs: Any) -> AgentExecutor:
|
||||
executor._human_feedback_method_outputs = {}
|
||||
executor._input_history = []
|
||||
executor._is_execution_resuming = False
|
||||
import threading
|
||||
executor._state_lock = threading.Lock()
|
||||
executor._or_listeners_lock = threading.Lock()
|
||||
executor._execution_lock = threading.Lock()
|
||||
executor._finalize_lock = threading.Lock()
|
||||
|
||||
@@ -1510,42 +1510,36 @@ def test_conditional_router_events_exclusivity():
|
||||
assert "handle_event_c" not in execution_order
|
||||
|
||||
|
||||
def test_state_consistency_across_parallel_branches():
|
||||
"""Test that state remains consistent when branches execute in parallel.
|
||||
def test_and_join_waits_for_parallel_branches():
|
||||
"""Test that sibling branches complete before a joined listener runs.
|
||||
|
||||
Note: Branches triggered by the same parent execute in parallel for efficiency.
|
||||
Thread-safe state access via StateProxy ensures no race conditions.
|
||||
We check the execution order to ensure the branches execute in parallel.
|
||||
Branches triggered by the same parent execute in parallel for efficiency.
|
||||
Shared state updates are not guaranteed to be atomic, so this test uses a
|
||||
locked local recorder instead of branch state mutation.
|
||||
"""
|
||||
execution_order = []
|
||||
execution_order_lock = threading.Lock()
|
||||
|
||||
def record(method_name: str) -> None:
|
||||
with execution_order_lock:
|
||||
execution_order.append(method_name)
|
||||
|
||||
class StateConsistencyFlow(Flow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.state["counter"] = 0
|
||||
self.state["branch_a_value"] = None
|
||||
self.state["branch_b_value"] = None
|
||||
|
||||
@start()
|
||||
def init(self):
|
||||
execution_order.append("init")
|
||||
self.state["counter"] = 10
|
||||
record("init")
|
||||
|
||||
@listen(init)
|
||||
def branch_a(self):
|
||||
execution_order.append("branch_a")
|
||||
self.state["branch_a_value"] = self.state["counter"]
|
||||
self.state["counter"] += 1
|
||||
record("branch_a")
|
||||
|
||||
@listen(init)
|
||||
def branch_b(self):
|
||||
execution_order.append("branch_b")
|
||||
self.state["branch_b_value"] = self.state["counter"]
|
||||
self.state["counter"] += 5
|
||||
record("branch_b")
|
||||
|
||||
@listen(and_(branch_a, branch_b))
|
||||
def verify_state(self):
|
||||
execution_order.append("verify_state")
|
||||
record("verify_state")
|
||||
|
||||
flow = StateConsistencyFlow()
|
||||
flow.kickoff()
|
||||
@@ -1554,10 +1548,8 @@ def test_state_consistency_across_parallel_branches():
|
||||
assert "branch_b" in execution_order
|
||||
assert "verify_state" in execution_order
|
||||
|
||||
assert flow.state["branch_a_value"] is not None
|
||||
assert flow.state["branch_b_value"] is not None
|
||||
|
||||
assert flow.state["counter"] == 16
|
||||
assert execution_order.index("branch_a") < execution_order.index("verify_state")
|
||||
assert execution_order.index("branch_b") < execution_order.index("verify_state")
|
||||
|
||||
|
||||
def test_deeply_nested_conditions():
|
||||
|
||||
@@ -928,8 +928,6 @@ class TestConversationalFlow:
|
||||
conversational = True
|
||||
|
||||
flow = BareChat()
|
||||
# ``flow.state`` returns a ``StateProxy``; the underlying state is
|
||||
# on ``flow._state``. Both views expose the same chat-shaped fields.
|
||||
assert isinstance(flow._state, ConversationState)
|
||||
assert flow.state.messages == []
|
||||
assert flow.state.current_user_message is None
|
||||
|
||||
@@ -466,7 +466,8 @@ def _run_with_events(flow, inputs=None):
|
||||
|
||||
|
||||
def _state_without_id(flow):
|
||||
snapshot = dict(flow.state.model_dump())
|
||||
state = flow.state
|
||||
snapshot = dict(state if isinstance(state, dict) else state.model_dump())
|
||||
snapshot.pop("id", None)
|
||||
return snapshot
|
||||
|
||||
|
||||
@@ -233,7 +233,7 @@ def test_persistence_with_base_model(tmp_path):
|
||||
assert message.role == "user"
|
||||
assert message.type == "text"
|
||||
assert message.content == "Hello, World!"
|
||||
assert isinstance(flow.state._unwrap(), State)
|
||||
assert isinstance(flow.state, State)
|
||||
|
||||
|
||||
def test_fork_with_restore_from_state_id(tmp_path):
|
||||
|
||||
Reference in New Issue
Block a user