Remove StateProxy from flow state access (#6327)

`StateProxy` looked like a thread-safety boundary, but it only protected
a small slice of state operations. Some examples of operations that were
not covered:

- `self.state.counter += 1`, `self.state["counter"] += 1` (increments)
- `self.state.user.profile.score += 1` (nested object mutations)
- `self.state.config["limits"]["max"] = 10` (mutation through model fields)
- `self.state.items[0].status = "done"` (list/container mutations)

This commit decided to remove it completely for simplicity and
performance:

- Simpler runtime code
- attr read: 24x faster, attr write: 27x faster, list append: 19x faster (local benchmark)
- Clearer concurrency contract (lifecycle locks remain, but arbitrary
  shared state mutation is not presented as thread-safe)
This commit is contained in:
Vinicius Brasil
2026-06-24 16:37:51 -07:00
committed by GitHub
parent 7738a1d30c
commit 340d23ae5d
8 changed files with 25 additions and 351 deletions

View File

@@ -54,7 +54,7 @@ from crewai.events.types.tool_usage_events import (
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.flow.flow import Flow, StateProxy, listen, or_, router, start
from crewai.flow.flow import Flow, listen, or_, router, start
from crewai.flow.types import FlowMethodName
from crewai.hooks.llm_hooks import (
get_after_llm_call_hooks,
@@ -276,11 +276,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
"""
return self.llm.supports_stop_words() if self.llm else False
@property
def state(self) -> AgentExecutorState:
"""Get thread-safe state proxy."""
return StateProxy(self._state, self._state_lock) # type: ignore[return-value]
@property # type: ignore[misc]
def iterations(self) -> int:
"""Compatibility property for mixin - returns state iterations."""

View File

@@ -24,9 +24,6 @@ from crewai.flow.runtime import (
Flow as RuntimeFlow,
FlowMeta,
FlowState,
LockedDictProxy,
LockedListProxy,
StateProxy,
)
@@ -42,9 +39,6 @@ __all__ = [
"Flow",
"FlowMeta",
"FlowState",
"LockedDictProxy",
"LockedListProxy",
"StateProxy",
"and_",
"listen",
"or_",

View File

@@ -1,8 +1,8 @@
"""Flow Runtime: the engine that executes a Flow.
Provides the ``Flow`` class (kickoff/resume/listener dispatch), the
``FlowMeta`` metaclass, and the thread-safe state proxies. Flows
authored with the Python DSL (see ``dsl``) are described by a Flow
Provides the ``Flow`` class (kickoff/resume/listener dispatch) and the
``FlowMeta`` metaclass. Flows authored with the Python DSL (see ``dsl``)
are described by a Flow
Structure (see ``flow_definition``) and executed here.
"""
@@ -11,12 +11,8 @@ from __future__ import annotations
import asyncio
from collections.abc import (
Callable,
ItemsView,
Iterable,
Iterator,
KeysView,
Sequence,
ValuesView,
)
from concurrent.futures import Future, ThreadPoolExecutor
import contextvars
@@ -35,10 +31,8 @@ from typing import (
Generic,
Literal,
ParamSpec,
SupportsIndex,
TypeVar,
cast,
overload,
)
from uuid import uuid4
@@ -383,304 +377,6 @@ R = TypeVar("R")
F = TypeVar("F", bound=Callable[..., Any])
class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
"""Thread-safe proxy for list operations.
Subclasses ``list`` so that ``isinstance(proxy, list)`` returns True,
which is required by libraries like LanceDB and Pydantic that do strict
type checks. All mutations go through the lock; reads delegate to the
underlying list.
"""
def __init__(self, lst: list[T], lock: threading.Lock) -> None:
super().__init__() # empty builtin list; all access goes through self._list
self._list = lst
self._lock = lock
def append(self, item: T) -> None:
with self._lock:
self._list.append(item)
def extend(self, items: Iterable[T]) -> None:
with self._lock:
self._list.extend(items)
def insert(self, index: SupportsIndex, item: T) -> None:
with self._lock:
self._list.insert(index, item)
def remove(self, item: T) -> None:
with self._lock:
self._list.remove(item)
def pop(self, index: SupportsIndex = -1) -> T:
with self._lock:
return self._list.pop(index)
def clear(self) -> None:
with self._lock:
self._list.clear()
@overload
def __setitem__(self, index: SupportsIndex, value: T) -> None: ...
@overload
def __setitem__(self, index: slice, value: Iterable[T]) -> None: ...
def __setitem__(self, index: Any, value: Any) -> None:
with self._lock:
self._list[index] = value
def __delitem__(self, index: SupportsIndex | slice) -> None:
with self._lock:
del self._list[index]
@overload
def __getitem__(self, index: SupportsIndex) -> T: ...
@overload
def __getitem__(self, index: slice) -> list[T]: ...
def __getitem__(self, index: Any) -> Any:
return self._list[index]
def __len__(self) -> int:
return len(self._list)
def __iter__(self) -> Iterator[T]:
return iter(self._list)
def __contains__(self, item: object) -> bool:
return item in self._list
def __repr__(self) -> str:
return repr(self._list)
def __bool__(self) -> bool:
return bool(self._list)
def index(
self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None
) -> int:
if stop is None:
return self._list.index(value, start)
return self._list.index(value, start, stop)
def count(self, value: T) -> int:
return self._list.count(value)
def sort(self, *, key: Any = None, reverse: bool = False) -> None:
with self._lock:
self._list.sort(key=key, reverse=reverse)
def reverse(self) -> None:
with self._lock:
self._list.reverse()
def copy(self) -> list[T]:
return self._list.copy()
def __add__(self, other: list[T]) -> list[T]: # type: ignore[override]
return self._list + other
def __radd__(self, other: list[T]) -> list[T]:
return other + self._list
def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]: # type: ignore[override]
with self._lock:
self._list += list(other)
return self
def __mul__(self, n: SupportsIndex) -> list[T]:
return self._list * n
def __rmul__(self, n: SupportsIndex) -> list[T]:
return self._list * n
def __imul__(self, n: SupportsIndex) -> LockedListProxy[T]:
with self._lock:
self._list *= n
return self
def __reversed__(self) -> Iterator[T]:
return reversed(self._list)
def __eq__(self, other: object) -> bool:
"""Compare based on the underlying list contents."""
if isinstance(other, LockedListProxy):
# Avoid deadlocks by acquiring locks in a consistent order.
first, second = (self, other) if id(self) <= id(other) else (other, self)
with first._lock:
with second._lock:
return first._list == second._list
with self._lock:
return self._list == other
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
"""Thread-safe proxy for dict operations.
Subclasses ``dict`` so that ``isinstance(proxy, dict)`` returns True,
which is required by libraries like Pydantic that do strict type checks.
All mutations go through the lock; reads delegate to the underlying dict.
"""
def __init__(self, d: dict[str, T], lock: threading.Lock) -> None:
super().__init__() # empty builtin dict; all access goes through self._dict
self._dict = d
self._lock = lock
def __setitem__(self, key: str, value: T) -> None:
with self._lock:
self._dict[key] = value
def __delitem__(self, key: str) -> None:
with self._lock:
del self._dict[key]
def pop(self, key: str, *default: T) -> T: # type: ignore[override]
with self._lock:
return self._dict.pop(key, *default)
def update(self, other: dict[str, T]) -> None: # type: ignore[override]
with self._lock:
self._dict.update(other)
def clear(self) -> None:
with self._lock:
self._dict.clear()
def setdefault(self, key: str, default: T) -> T: # type: ignore[override]
with self._lock:
return self._dict.setdefault(key, default)
def __getitem__(self, key: str) -> T:
return self._dict[key]
def __len__(self) -> int:
return len(self._dict)
def __iter__(self) -> Iterator[str]:
return iter(self._dict)
def __contains__(self, key: object) -> bool:
return key in self._dict
def keys(self) -> KeysView[str]: # type: ignore[override]
return self._dict.keys()
def values(self) -> ValuesView[T]: # type: ignore[override]
return self._dict.values()
def items(self) -> ItemsView[str, T]: # type: ignore[override]
return self._dict.items()
def get(self, key: str, default: T | None = None) -> T | None: # type: ignore[override]
return self._dict.get(key, default)
def __repr__(self) -> str:
return repr(self._dict)
def __bool__(self) -> bool:
return bool(self._dict)
def copy(self) -> dict[str, T]:
return self._dict.copy()
def __or__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override]
return self._dict | other
def __ror__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override]
return other | self._dict
def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]: # type: ignore[override]
with self._lock:
self._dict |= other
return self
def __reversed__(self) -> Iterator[str]:
return reversed(self._dict)
def __eq__(self, other: object) -> bool:
"""Compare based on the underlying dict contents."""
if isinstance(other, LockedDictProxy):
# Avoid deadlocks by acquiring locks in a consistent order.
first, second = (self, other) if id(self) <= id(other) else (other, self)
with first._lock:
with second._lock:
return first._dict == second._dict
with self._lock:
return self._dict == other
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
class StateProxy(Generic[T]):
"""Proxy that provides thread-safe access to flow state.
Wraps state objects (dict or BaseModel) and uses a lock for all write
operations to prevent race conditions when parallel listeners modify state.
"""
__slots__ = ("_proxy_lock", "_proxy_state")
def __init__(self, state: T, lock: threading.Lock) -> None:
object.__setattr__(self, "_proxy_state", state)
object.__setattr__(self, "_proxy_lock", lock)
def __getattr__(self, name: str) -> Any:
value = getattr(object.__getattribute__(self, "_proxy_state"), name)
lock = object.__getattribute__(self, "_proxy_lock")
if isinstance(value, list):
return LockedListProxy(value, lock)
if isinstance(value, dict):
return LockedDictProxy(value, lock)
return value
def __setattr__(self, name: str, value: Any) -> None:
if name in ("_proxy_state", "_proxy_lock"):
object.__setattr__(self, name, value)
else:
if isinstance(value, LockedListProxy):
value = value._list
elif isinstance(value, LockedDictProxy):
value = value._dict
with object.__getattribute__(self, "_proxy_lock"):
setattr(object.__getattribute__(self, "_proxy_state"), name, value)
def __getitem__(self, key: str) -> Any:
return object.__getattribute__(self, "_proxy_state")[key]
def __setitem__(self, key: str, value: Any) -> None:
with object.__getattribute__(self, "_proxy_lock"):
object.__getattribute__(self, "_proxy_state")[key] = value
def __delitem__(self, key: str) -> None:
with object.__getattribute__(self, "_proxy_lock"):
del object.__getattribute__(self, "_proxy_state")[key]
def __contains__(self, key: str) -> bool:
return key in object.__getattribute__(self, "_proxy_state")
def __repr__(self) -> str:
return repr(object.__getattribute__(self, "_proxy_state"))
def _unwrap(self) -> T:
"""Return the underlying state object."""
return cast(T, object.__getattribute__(self, "_proxy_state"))
def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
"""Return state as a dictionary.
Works for both dict and BaseModel underlying states.
"""
state = object.__getattribute__(self, "_proxy_state")
if isinstance(state, dict):
return state
result: dict[str, Any] = state.model_dump(*args, **kwargs)
return result
class FlowMeta(ModelMetaclass):
def __new__(
mcs,
@@ -1025,7 +721,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
)
_method_outputs: list[Any] = PrivateAttr(default_factory=list)
_definition: FlowDefinition = PrivateAttr()
_state_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_or_listeners_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_completed_methods: set[FlowMethodName] = PrivateAttr(default_factory=set)
_method_call_counts: dict[FlowMethodName, int] = PrivateAttr(default_factory=dict)
@@ -1947,7 +1642,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
@property
def state(self) -> T:
return StateProxy(self._state, self._state_lock) # type: ignore[return-value]
return cast(T, self._state)
@property
def method_outputs(self) -> list[Any]:

View File

@@ -7,6 +7,7 @@ flow methods, routing logic, and error handling.
from __future__ import annotations
import asyncio
import threading
from types import SimpleNamespace
import time
from typing import Any
@@ -39,8 +40,6 @@ def _build_executor(**kwargs: Any) -> AgentExecutor:
executor._human_feedback_method_outputs = {}
executor._input_history = []
executor._is_execution_resuming = False
import threading
executor._state_lock = threading.Lock()
executor._or_listeners_lock = threading.Lock()
executor._execution_lock = threading.Lock()
executor._finalize_lock = threading.Lock()

View File

@@ -1510,42 +1510,36 @@ def test_conditional_router_events_exclusivity():
assert "handle_event_c" not in execution_order
def test_state_consistency_across_parallel_branches():
"""Test that state remains consistent when branches execute in parallel.
def test_and_join_waits_for_parallel_branches():
"""Test that sibling branches complete before a joined listener runs.
Note: Branches triggered by the same parent execute in parallel for efficiency.
Thread-safe state access via StateProxy ensures no race conditions.
We check the execution order to ensure the branches execute in parallel.
Branches triggered by the same parent execute in parallel for efficiency.
Shared state updates are not guaranteed to be atomic, so this test uses a
locked local recorder instead of branch state mutation.
"""
execution_order = []
execution_order_lock = threading.Lock()
def record(method_name: str) -> None:
with execution_order_lock:
execution_order.append(method_name)
class StateConsistencyFlow(Flow):
def __init__(self):
super().__init__()
self.state["counter"] = 0
self.state["branch_a_value"] = None
self.state["branch_b_value"] = None
@start()
def init(self):
execution_order.append("init")
self.state["counter"] = 10
record("init")
@listen(init)
def branch_a(self):
execution_order.append("branch_a")
self.state["branch_a_value"] = self.state["counter"]
self.state["counter"] += 1
record("branch_a")
@listen(init)
def branch_b(self):
execution_order.append("branch_b")
self.state["branch_b_value"] = self.state["counter"]
self.state["counter"] += 5
record("branch_b")
@listen(and_(branch_a, branch_b))
def verify_state(self):
execution_order.append("verify_state")
record("verify_state")
flow = StateConsistencyFlow()
flow.kickoff()
@@ -1554,10 +1548,8 @@ def test_state_consistency_across_parallel_branches():
assert "branch_b" in execution_order
assert "verify_state" in execution_order
assert flow.state["branch_a_value"] is not None
assert flow.state["branch_b_value"] is not None
assert flow.state["counter"] == 16
assert execution_order.index("branch_a") < execution_order.index("verify_state")
assert execution_order.index("branch_b") < execution_order.index("verify_state")
def test_deeply_nested_conditions():

View File

@@ -928,8 +928,6 @@ class TestConversationalFlow:
conversational = True
flow = BareChat()
# ``flow.state`` returns a ``StateProxy``; the underlying state is
# on ``flow._state``. Both views expose the same chat-shaped fields.
assert isinstance(flow._state, ConversationState)
assert flow.state.messages == []
assert flow.state.current_user_message is None

View File

@@ -466,7 +466,8 @@ def _run_with_events(flow, inputs=None):
def _state_without_id(flow):
snapshot = dict(flow.state.model_dump())
state = flow.state
snapshot = dict(state if isinstance(state, dict) else state.model_dump())
snapshot.pop("id", None)
return snapshot

View File

@@ -233,7 +233,7 @@ def test_persistence_with_base_model(tmp_path):
assert message.role == "user"
assert message.type == "text"
assert message.content == "Hello, World!"
assert isinstance(flow.state._unwrap(), State)
assert isinstance(flow.state, State)
def test_fork_with_restore_from_state_id(tmp_path):