diff --git a/lib/crewai/src/crewai/experimental/agent_executor.py b/lib/crewai/src/crewai/experimental/agent_executor.py index 303330dc6..0ecd8e63a 100644 --- a/lib/crewai/src/crewai/experimental/agent_executor.py +++ b/lib/crewai/src/crewai/experimental/agent_executor.py @@ -54,7 +54,7 @@ from crewai.events.types.tool_usage_events import ( ToolUsageFinishedEvent, ToolUsageStartedEvent, ) -from crewai.flow.flow import Flow, StateProxy, listen, or_, router, start +from crewai.flow.flow import Flow, listen, or_, router, start from crewai.flow.types import FlowMethodName from crewai.hooks.llm_hooks import ( get_after_llm_call_hooks, @@ -276,11 +276,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): """ return self.llm.supports_stop_words() if self.llm else False - @property - def state(self) -> AgentExecutorState: - """Get thread-safe state proxy.""" - return StateProxy(self._state, self._state_lock) # type: ignore[return-value] - @property # type: ignore[misc] def iterations(self) -> int: """Compatibility property for mixin - returns state iterations.""" diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index 19c161ffb..8635e2953 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -24,9 +24,6 @@ from crewai.flow.runtime import ( Flow as RuntimeFlow, FlowMeta, FlowState, - LockedDictProxy, - LockedListProxy, - StateProxy, ) @@ -42,9 +39,6 @@ __all__ = [ "Flow", "FlowMeta", "FlowState", - "LockedDictProxy", - "LockedListProxy", - "StateProxy", "and_", "listen", "or_", diff --git a/lib/crewai/src/crewai/flow/runtime/__init__.py b/lib/crewai/src/crewai/flow/runtime/__init__.py index 4bb67a269..8de5be409 100644 --- a/lib/crewai/src/crewai/flow/runtime/__init__.py +++ b/lib/crewai/src/crewai/flow/runtime/__init__.py @@ -1,8 +1,8 @@ """Flow Runtime: the engine that executes a Flow. -Provides the ``Flow`` class (kickoff/resume/listener dispatch), the -``FlowMeta`` metaclass, and the thread-safe state proxies. Flows -authored with the Python DSL (see ``dsl``) are described by a Flow +Provides the ``Flow`` class (kickoff/resume/listener dispatch) and the +``FlowMeta`` metaclass. Flows authored with the Python DSL (see ``dsl``) +are described by a Flow Structure (see ``flow_definition``) and executed here. """ @@ -11,12 +11,8 @@ from __future__ import annotations import asyncio from collections.abc import ( Callable, - ItemsView, - Iterable, Iterator, - KeysView, Sequence, - ValuesView, ) from concurrent.futures import Future, ThreadPoolExecutor import contextvars @@ -35,10 +31,8 @@ from typing import ( Generic, Literal, ParamSpec, - SupportsIndex, TypeVar, cast, - overload, ) from uuid import uuid4 @@ -383,304 +377,6 @@ R = TypeVar("R") F = TypeVar("F", bound=Callable[..., Any]) -class LockedListProxy(list, Generic[T]): # type: ignore[type-arg] - """Thread-safe proxy for list operations. - - Subclasses ``list`` so that ``isinstance(proxy, list)`` returns True, - which is required by libraries like LanceDB and Pydantic that do strict - type checks. All mutations go through the lock; reads delegate to the - underlying list. - """ - - def __init__(self, lst: list[T], lock: threading.Lock) -> None: - super().__init__() # empty builtin list; all access goes through self._list - self._list = lst - self._lock = lock - - def append(self, item: T) -> None: - with self._lock: - self._list.append(item) - - def extend(self, items: Iterable[T]) -> None: - with self._lock: - self._list.extend(items) - - def insert(self, index: SupportsIndex, item: T) -> None: - with self._lock: - self._list.insert(index, item) - - def remove(self, item: T) -> None: - with self._lock: - self._list.remove(item) - - def pop(self, index: SupportsIndex = -1) -> T: - with self._lock: - return self._list.pop(index) - - def clear(self) -> None: - with self._lock: - self._list.clear() - - @overload - def __setitem__(self, index: SupportsIndex, value: T) -> None: ... - @overload - def __setitem__(self, index: slice, value: Iterable[T]) -> None: ... - def __setitem__(self, index: Any, value: Any) -> None: - with self._lock: - self._list[index] = value - - def __delitem__(self, index: SupportsIndex | slice) -> None: - with self._lock: - del self._list[index] - - @overload - def __getitem__(self, index: SupportsIndex) -> T: ... - @overload - def __getitem__(self, index: slice) -> list[T]: ... - def __getitem__(self, index: Any) -> Any: - return self._list[index] - - def __len__(self) -> int: - return len(self._list) - - def __iter__(self) -> Iterator[T]: - return iter(self._list) - - def __contains__(self, item: object) -> bool: - return item in self._list - - def __repr__(self) -> str: - return repr(self._list) - - def __bool__(self) -> bool: - return bool(self._list) - - def index( - self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None - ) -> int: - if stop is None: - return self._list.index(value, start) - return self._list.index(value, start, stop) - - def count(self, value: T) -> int: - return self._list.count(value) - - def sort(self, *, key: Any = None, reverse: bool = False) -> None: - with self._lock: - self._list.sort(key=key, reverse=reverse) - - def reverse(self) -> None: - with self._lock: - self._list.reverse() - - def copy(self) -> list[T]: - return self._list.copy() - - def __add__(self, other: list[T]) -> list[T]: # type: ignore[override] - return self._list + other - - def __radd__(self, other: list[T]) -> list[T]: - return other + self._list - - def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]: # type: ignore[override] - with self._lock: - self._list += list(other) - return self - - def __mul__(self, n: SupportsIndex) -> list[T]: - return self._list * n - - def __rmul__(self, n: SupportsIndex) -> list[T]: - return self._list * n - - def __imul__(self, n: SupportsIndex) -> LockedListProxy[T]: - with self._lock: - self._list *= n - return self - - def __reversed__(self) -> Iterator[T]: - return reversed(self._list) - - def __eq__(self, other: object) -> bool: - """Compare based on the underlying list contents.""" - if isinstance(other, LockedListProxy): - # Avoid deadlocks by acquiring locks in a consistent order. - first, second = (self, other) if id(self) <= id(other) else (other, self) - with first._lock: - with second._lock: - return first._list == second._list - with self._lock: - return self._list == other - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) - - -class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg] - """Thread-safe proxy for dict operations. - - Subclasses ``dict`` so that ``isinstance(proxy, dict)`` returns True, - which is required by libraries like Pydantic that do strict type checks. - All mutations go through the lock; reads delegate to the underlying dict. - """ - - def __init__(self, d: dict[str, T], lock: threading.Lock) -> None: - super().__init__() # empty builtin dict; all access goes through self._dict - self._dict = d - self._lock = lock - - def __setitem__(self, key: str, value: T) -> None: - with self._lock: - self._dict[key] = value - - def __delitem__(self, key: str) -> None: - with self._lock: - del self._dict[key] - - def pop(self, key: str, *default: T) -> T: # type: ignore[override] - with self._lock: - return self._dict.pop(key, *default) - - def update(self, other: dict[str, T]) -> None: # type: ignore[override] - with self._lock: - self._dict.update(other) - - def clear(self) -> None: - with self._lock: - self._dict.clear() - - def setdefault(self, key: str, default: T) -> T: # type: ignore[override] - with self._lock: - return self._dict.setdefault(key, default) - - def __getitem__(self, key: str) -> T: - return self._dict[key] - - def __len__(self) -> int: - return len(self._dict) - - def __iter__(self) -> Iterator[str]: - return iter(self._dict) - - def __contains__(self, key: object) -> bool: - return key in self._dict - - def keys(self) -> KeysView[str]: # type: ignore[override] - return self._dict.keys() - - def values(self) -> ValuesView[T]: # type: ignore[override] - return self._dict.values() - - def items(self) -> ItemsView[str, T]: # type: ignore[override] - return self._dict.items() - - def get(self, key: str, default: T | None = None) -> T | None: # type: ignore[override] - return self._dict.get(key, default) - - def __repr__(self) -> str: - return repr(self._dict) - - def __bool__(self) -> bool: - return bool(self._dict) - - def copy(self) -> dict[str, T]: - return self._dict.copy() - - def __or__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override] - return self._dict | other - - def __ror__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override] - return other | self._dict - - def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]: # type: ignore[override] - with self._lock: - self._dict |= other - return self - - def __reversed__(self) -> Iterator[str]: - return reversed(self._dict) - - def __eq__(self, other: object) -> bool: - """Compare based on the underlying dict contents.""" - if isinstance(other, LockedDictProxy): - # Avoid deadlocks by acquiring locks in a consistent order. - first, second = (self, other) if id(self) <= id(other) else (other, self) - with first._lock: - with second._lock: - return first._dict == second._dict - with self._lock: - return self._dict == other - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) - - -class StateProxy(Generic[T]): - """Proxy that provides thread-safe access to flow state. - - Wraps state objects (dict or BaseModel) and uses a lock for all write - operations to prevent race conditions when parallel listeners modify state. - """ - - __slots__ = ("_proxy_lock", "_proxy_state") - - def __init__(self, state: T, lock: threading.Lock) -> None: - object.__setattr__(self, "_proxy_state", state) - object.__setattr__(self, "_proxy_lock", lock) - - def __getattr__(self, name: str) -> Any: - value = getattr(object.__getattribute__(self, "_proxy_state"), name) - lock = object.__getattribute__(self, "_proxy_lock") - if isinstance(value, list): - return LockedListProxy(value, lock) - if isinstance(value, dict): - return LockedDictProxy(value, lock) - return value - - def __setattr__(self, name: str, value: Any) -> None: - if name in ("_proxy_state", "_proxy_lock"): - object.__setattr__(self, name, value) - else: - if isinstance(value, LockedListProxy): - value = value._list - elif isinstance(value, LockedDictProxy): - value = value._dict - with object.__getattribute__(self, "_proxy_lock"): - setattr(object.__getattribute__(self, "_proxy_state"), name, value) - - def __getitem__(self, key: str) -> Any: - return object.__getattribute__(self, "_proxy_state")[key] - - def __setitem__(self, key: str, value: Any) -> None: - with object.__getattribute__(self, "_proxy_lock"): - object.__getattribute__(self, "_proxy_state")[key] = value - - def __delitem__(self, key: str) -> None: - with object.__getattribute__(self, "_proxy_lock"): - del object.__getattribute__(self, "_proxy_state")[key] - - def __contains__(self, key: str) -> bool: - return key in object.__getattribute__(self, "_proxy_state") - - def __repr__(self) -> str: - return repr(object.__getattribute__(self, "_proxy_state")) - - def _unwrap(self) -> T: - """Return the underlying state object.""" - return cast(T, object.__getattribute__(self, "_proxy_state")) - - def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]: - """Return state as a dictionary. - - Works for both dict and BaseModel underlying states. - """ - state = object.__getattribute__(self, "_proxy_state") - if isinstance(state, dict): - return state - result: dict[str, Any] = state.model_dump(*args, **kwargs) - return result - - class FlowMeta(ModelMetaclass): def __new__( mcs, @@ -1025,7 +721,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): ) _method_outputs: list[Any] = PrivateAttr(default_factory=list) _definition: FlowDefinition = PrivateAttr() - _state_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) _or_listeners_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) _completed_methods: set[FlowMethodName] = PrivateAttr(default_factory=set) _method_call_counts: dict[FlowMethodName, int] = PrivateAttr(default_factory=dict) @@ -1947,7 +1642,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): @property def state(self) -> T: - return StateProxy(self._state, self._state_lock) # type: ignore[return-value] + return cast(T, self._state) @property def method_outputs(self) -> list[Any]: diff --git a/lib/crewai/tests/agents/test_agent_executor.py b/lib/crewai/tests/agents/test_agent_executor.py index 992f7460b..e4de4a484 100644 --- a/lib/crewai/tests/agents/test_agent_executor.py +++ b/lib/crewai/tests/agents/test_agent_executor.py @@ -7,6 +7,7 @@ flow methods, routing logic, and error handling. from __future__ import annotations import asyncio +import threading from types import SimpleNamespace import time from typing import Any @@ -39,8 +40,6 @@ def _build_executor(**kwargs: Any) -> AgentExecutor: executor._human_feedback_method_outputs = {} executor._input_history = [] executor._is_execution_resuming = False - import threading - executor._state_lock = threading.Lock() executor._or_listeners_lock = threading.Lock() executor._execution_lock = threading.Lock() executor._finalize_lock = threading.Lock() diff --git a/lib/crewai/tests/test_flow.py b/lib/crewai/tests/test_flow.py index d0d0045b9..4b8a66671 100644 --- a/lib/crewai/tests/test_flow.py +++ b/lib/crewai/tests/test_flow.py @@ -1510,42 +1510,36 @@ def test_conditional_router_events_exclusivity(): assert "handle_event_c" not in execution_order -def test_state_consistency_across_parallel_branches(): - """Test that state remains consistent when branches execute in parallel. +def test_and_join_waits_for_parallel_branches(): + """Test that sibling branches complete before a joined listener runs. - Note: Branches triggered by the same parent execute in parallel for efficiency. - Thread-safe state access via StateProxy ensures no race conditions. - We check the execution order to ensure the branches execute in parallel. + Branches triggered by the same parent execute in parallel for efficiency. + Shared state updates are not guaranteed to be atomic, so this test uses a + locked local recorder instead of branch state mutation. """ execution_order = [] + execution_order_lock = threading.Lock() + + def record(method_name: str) -> None: + with execution_order_lock: + execution_order.append(method_name) class StateConsistencyFlow(Flow): - def __init__(self): - super().__init__() - self.state["counter"] = 0 - self.state["branch_a_value"] = None - self.state["branch_b_value"] = None - @start() def init(self): - execution_order.append("init") - self.state["counter"] = 10 + record("init") @listen(init) def branch_a(self): - execution_order.append("branch_a") - self.state["branch_a_value"] = self.state["counter"] - self.state["counter"] += 1 + record("branch_a") @listen(init) def branch_b(self): - execution_order.append("branch_b") - self.state["branch_b_value"] = self.state["counter"] - self.state["counter"] += 5 + record("branch_b") @listen(and_(branch_a, branch_b)) def verify_state(self): - execution_order.append("verify_state") + record("verify_state") flow = StateConsistencyFlow() flow.kickoff() @@ -1554,10 +1548,8 @@ def test_state_consistency_across_parallel_branches(): assert "branch_b" in execution_order assert "verify_state" in execution_order - assert flow.state["branch_a_value"] is not None - assert flow.state["branch_b_value"] is not None - - assert flow.state["counter"] == 16 + assert execution_order.index("branch_a") < execution_order.index("verify_state") + assert execution_order.index("branch_b") < execution_order.index("verify_state") def test_deeply_nested_conditions(): diff --git a/lib/crewai/tests/test_flow_conversation.py b/lib/crewai/tests/test_flow_conversation.py index 3fea6b471..d8cc0bd37 100644 --- a/lib/crewai/tests/test_flow_conversation.py +++ b/lib/crewai/tests/test_flow_conversation.py @@ -928,8 +928,6 @@ class TestConversationalFlow: conversational = True flow = BareChat() - # ``flow.state`` returns a ``StateProxy``; the underlying state is - # on ``flow._state``. Both views expose the same chat-shaped fields. assert isinstance(flow._state, ConversationState) assert flow.state.messages == [] assert flow.state.current_user_message is None diff --git a/lib/crewai/tests/test_flow_from_definition.py b/lib/crewai/tests/test_flow_from_definition.py index 16160f3cf..0c822c483 100644 --- a/lib/crewai/tests/test_flow_from_definition.py +++ b/lib/crewai/tests/test_flow_from_definition.py @@ -466,7 +466,8 @@ def _run_with_events(flow, inputs=None): def _state_without_id(flow): - snapshot = dict(flow.state.model_dump()) + state = flow.state + snapshot = dict(state if isinstance(state, dict) else state.model_dump()) snapshot.pop("id", None) return snapshot diff --git a/lib/crewai/tests/test_flow_persistence.py b/lib/crewai/tests/test_flow_persistence.py index e5331f7c0..b405cc64d 100644 --- a/lib/crewai/tests/test_flow_persistence.py +++ b/lib/crewai/tests/test_flow_persistence.py @@ -233,7 +233,7 @@ def test_persistence_with_base_model(tmp_path): assert message.role == "user" assert message.type == "text" assert message.content == "Hello, World!" - assert isinstance(flow.state._unwrap(), State) + assert isinstance(flow.state, State) def test_fork_with_restore_from_state_id(tmp_path):