diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index e8ddc4765..64c4059ad 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -497,6 +497,50 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg] def __bool__(self) -> bool: return bool(self._list) + def index(self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None) -> int: # type: ignore[override] + if stop is None: + return self._list.index(value, start) + return self._list.index(value, start, stop) + + def count(self, value: T) -> int: + return self._list.count(value) + + def sort(self, *, key: Any = None, reverse: bool = False) -> None: + with self._lock: + self._list.sort(key=key, reverse=reverse) + + def reverse(self) -> None: + with self._lock: + self._list.reverse() + + def copy(self) -> list[T]: + return self._list.copy() + + def __add__(self, other: list[T]) -> list[T]: + return self._list + other + + def __radd__(self, other: list[T]) -> list[T]: + return other + self._list + + def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]: + with self._lock: + self._list += list(other) + return self + + def __mul__(self, n: SupportsIndex) -> list[T]: + return self._list * n + + def __rmul__(self, n: SupportsIndex) -> list[T]: + return self._list * n + + def __imul__(self, n: SupportsIndex) -> LockedListProxy[T]: + with self._lock: + self._list *= n + return self + + def __reversed__(self) -> Iterator[T]: + return reversed(self._list) + def __eq__(self, other: object) -> bool: """Compare based on the underlying list contents.""" if isinstance(other, LockedListProxy): @@ -579,6 +623,23 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg] def __bool__(self) -> bool: return bool(self._dict) + def copy(self) -> dict[str, T]: + return self._dict.copy() + + def __or__(self, other: dict[str, T]) -> dict[str, T]: + return self._dict | other + + def __ror__(self, other: dict[str, T]) -> dict[str, T]: + return other | self._dict + + def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]: + with self._lock: + self._dict |= other + return self + + def __reversed__(self) -> Iterator[str]: + return reversed(self._dict) + def __eq__(self, other: object) -> bool: """Compare based on the underlying dict contents.""" if isinstance(other, LockedDictProxy): @@ -620,6 +681,10 @@ class StateProxy(Generic[T]): if name in ("_proxy_state", "_proxy_lock"): object.__setattr__(self, name, value) else: + if isinstance(value, LockedListProxy): + value = value._list + elif isinstance(value, LockedDictProxy): + value = value._dict with object.__getattribute__(self, "_proxy_lock"): setattr(object.__getattribute__(self, "_proxy_state"), name, value) diff --git a/lib/crewai/tests/test_flow.py b/lib/crewai/tests/test_flow.py index ccb08cb0a..f214006aa 100644 --- a/lib/crewai/tests/test_flow.py +++ b/lib/crewai/tests/test_flow.py @@ -1893,3 +1893,163 @@ def test_or_condition_self_listen_fires_once(): flow = OrSelfListenFlow() flow.kickoff() assert call_count == 1 + +class ListState(BaseModel): + items: list = [] + + +class DictState(BaseModel): + data: dict = {} + + +class _ListFlow(Flow[ListState]): + @start() + def populate(self): + self.state.items = [3, 1, 4, 1, 5, 9, 2, 6] + + +class _DictFlow(Flow[DictState]): + @start() + def populate(self): + self.state.data = {"a": 1, "b": 2, "c": 3} + + +def _make_list_flow(): + flow = _ListFlow() + flow.kickoff() + return flow + + +def _make_dict_flow(): + flow = _DictFlow() + flow.kickoff() + return flow + + +def test_locked_list_proxy_index(): + flow = _make_list_flow() + assert flow.state.items.index(4) == 2 + assert flow.state.items.index(1, 2) == 3 + + +def test_locked_list_proxy_index_missing_raises(): + flow = _make_list_flow() + with pytest.raises(ValueError): + flow.state.items.index(999) + + +def test_locked_list_proxy_count(): + flow = _make_list_flow() + assert flow.state.items.count(1) == 2 + assert flow.state.items.count(999) == 0 + + +def test_locked_list_proxy_sort(): + flow = _make_list_flow() + flow.state.items.sort() + assert list(flow.state.items) == [1, 1, 2, 3, 4, 5, 6, 9] + + +def test_locked_list_proxy_sort_reverse(): + flow = _make_list_flow() + flow.state.items.sort(reverse=True) + assert list(flow.state.items) == [9, 6, 5, 4, 3, 2, 1, 1] + + +def test_locked_list_proxy_sort_key(): + flow = _make_list_flow() + flow.state.items.sort(key=lambda x: -x) + assert list(flow.state.items) == [9, 6, 5, 4, 3, 2, 1, 1] + + +def test_locked_list_proxy_reverse(): + flow = _make_list_flow() + original = list(flow.state.items) + flow.state.items.reverse() + assert list(flow.state.items) == list(reversed(original)) + + +def test_locked_list_proxy_copy(): + flow = _make_list_flow() + copied = flow.state.items.copy() + assert copied == [3, 1, 4, 1, 5, 9, 2, 6] + assert isinstance(copied, list) + copied.append(999) + assert 999 not in flow.state.items + + +def test_locked_list_proxy_add(): + flow = _make_list_flow() + result = flow.state.items + [10, 11] + assert result == [3, 1, 4, 1, 5, 9, 2, 6, 10, 11] + assert len(flow.state.items) == 8 + + +def test_locked_list_proxy_radd(): + flow = _make_list_flow() + result = [0] + flow.state.items + assert result[0] == 0 + assert len(result) == 9 + + +def test_locked_list_proxy_iadd(): + flow = _make_list_flow() + flow.state.items += [10] + assert 10 in flow.state.items + # Verify no deadlock: mutations must still work after += + flow.state.items.append(99) + assert 99 in flow.state.items + + +def test_locked_list_proxy_mul(): + flow = _make_list_flow() + result = flow.state.items * 2 + assert len(result) == 16 + + +def test_locked_list_proxy_rmul(): + flow = _make_list_flow() + result = 2 * flow.state.items + assert len(result) == 16 + + +def test_locked_list_proxy_reversed(): + flow = _make_list_flow() + original = list(flow.state.items) + assert list(reversed(flow.state.items)) == list(reversed(original)) + + +def test_locked_dict_proxy_copy(): + flow = _make_dict_flow() + copied = flow.state.data.copy() + assert copied == {"a": 1, "b": 2, "c": 3} + assert isinstance(copied, dict) + copied["z"] = 99 + assert "z" not in flow.state.data + + +def test_locked_dict_proxy_or(): + flow = _make_dict_flow() + result = flow.state.data | {"d": 4} + assert result == {"a": 1, "b": 2, "c": 3, "d": 4} + assert "d" not in flow.state.data + + +def test_locked_dict_proxy_ror(): + flow = _make_dict_flow() + result = {"z": 0} | flow.state.data + assert result == {"z": 0, "a": 1, "b": 2, "c": 3} + + +def test_locked_dict_proxy_ior(): + flow = _make_dict_flow() + flow.state.data |= {"d": 4} + assert flow.state.data["d"] == 4 + # Verify no deadlock: mutations must still work after |= + flow.state.data["e"] = 5 + assert flow.state.data["e"] == 5 + + +def test_locked_dict_proxy_reversed(): + flow = _make_dict_flow() + assert list(reversed(flow.state.data)) == ["c", "b", "a"]