From 28a6b855a28bfef5bb5c1ce591095b45f1f6ffc1 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 19 Feb 2026 17:30:47 -0500 Subject: [PATCH] fix: preserve enum type in router result; improve types --- lib/crewai/src/crewai/flow/flow.py | 124 ++++++++++++++++++----------- 1 file changed, 79 insertions(+), 45 deletions(-) diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index 58a59debe..b32321daa 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -10,6 +10,7 @@ import asyncio from collections.abc import ( Callable, ItemsView, + Iterable, Iterator, KeysView, Sequence, @@ -17,6 +18,7 @@ from collections.abc import ( ) from concurrent.futures import Future import copy +import enum import inspect import logging import threading @@ -27,8 +29,10 @@ from typing import ( Generic, Literal, ParamSpec, + SupportsIndex, TypeVar, cast, + overload, ) from uuid import uuid4 @@ -77,7 +81,12 @@ from crewai.flow.flow_wrappers import ( StartMethod, ) from crewai.flow.persistence.base import FlowPersistence -from crewai.flow.types import FlowExecutionData, FlowMethodName, InputHistoryEntry, PendingListenerKey +from crewai.flow.types import ( + FlowExecutionData, + FlowMethodName, + InputHistoryEntry, + PendingListenerKey, +) from crewai.flow.utils import ( _extract_all_methods, _extract_all_methods_recursive, @@ -426,8 +435,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg] """ def __init__(self, lst: list[T], lock: threading.Lock) -> None: - # Do NOT call super().__init__() -- we don't want to copy data into - # the builtin list storage. All access goes through self._list. + super().__init__() # empty builtin list; all access goes through self._list self._list = lst self._lock = lock @@ -435,11 +443,11 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg] with self._lock: self._list.append(item) - def extend(self, items: list[T]) -> None: + def extend(self, items: Iterable[T]) -> None: with self._lock: self._list.extend(items) - def insert(self, index: int, item: T) -> None: + def insert(self, index: SupportsIndex, item: T) -> None: with self._lock: self._list.insert(index, item) @@ -447,7 +455,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg] with self._lock: self._list.remove(item) - def pop(self, index: int = -1) -> T: + def pop(self, index: SupportsIndex = -1) -> T: with self._lock: return self._list.pop(index) @@ -455,15 +463,23 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg] with self._lock: self._list.clear() - def __setitem__(self, index: int, value: T) -> None: + @overload + def __setitem__(self, index: SupportsIndex, value: T) -> None: ... + @overload + def __setitem__(self, index: slice, value: Iterable[T]) -> None: ... + def __setitem__(self, index: Any, value: Any) -> None: with self._lock: self._list[index] = value - def __delitem__(self, index: int) -> None: + def __delitem__(self, index: SupportsIndex | slice) -> None: with self._lock: del self._list[index] - def __getitem__(self, index: int) -> T: + @overload + def __getitem__(self, index: SupportsIndex) -> T: ... + @overload + def __getitem__(self, index: slice) -> list[T]: ... + def __getitem__(self, index: Any) -> Any: return self._list[index] def __len__(self) -> int: @@ -481,7 +497,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg] def __bool__(self) -> bool: return bool(self._list) - def __eq__(self, other: object) -> bool: # type: ignore[override] + def __eq__(self, other: object) -> bool: """Compare based on the underlying list contents.""" if isinstance(other, LockedListProxy): # Avoid deadlocks by acquiring locks in a consistent order. @@ -492,7 +508,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg] with self._lock: return self._list == other - def __ne__(self, other: object) -> bool: # type: ignore[override] + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @@ -505,8 +521,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg] """ def __init__(self, d: dict[str, T], lock: threading.Lock) -> None: - # Do NOT call super().__init__() -- we don't want to copy data into - # the builtin dict storage. All access goes through self._dict. + super().__init__() # empty builtin dict; all access goes through self._dict self._dict = d self._lock = lock @@ -518,11 +533,11 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg] with self._lock: del self._dict[key] - def pop(self, key: str, *default: T) -> T: + def pop(self, key: str, *default: T) -> T: # type: ignore[override] with self._lock: return self._dict.pop(key, *default) - def update(self, other: dict[str, T]) -> None: + def update(self, other: dict[str, T]) -> None: # type: ignore[override] with self._lock: self._dict.update(other) @@ -530,7 +545,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg] with self._lock: self._dict.clear() - def setdefault(self, key: str, default: T) -> T: + def setdefault(self, key: str, default: T) -> T: # type: ignore[override] with self._lock: return self._dict.setdefault(key, default) @@ -546,16 +561,16 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg] def __contains__(self, key: object) -> bool: return key in self._dict - def keys(self) -> KeysView[str]: + def keys(self) -> KeysView[str]: # type: ignore[override] return self._dict.keys() - def values(self) -> ValuesView[T]: + def values(self) -> ValuesView[T]: # type: ignore[override] return self._dict.values() - def items(self) -> ItemsView[str, T]: + def items(self) -> ItemsView[str, T]: # type: ignore[override] return self._dict.items() - def get(self, key: str, default: T | None = None) -> T | None: + def get(self, key: str, default: T | None = None) -> T | None: # type: ignore[override] return self._dict.get(key, default) def __repr__(self) -> str: @@ -564,7 +579,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg] def __bool__(self) -> bool: return bool(self._dict) - def __eq__(self, other: object) -> bool: # type: ignore[override] + def __eq__(self, other: object) -> bool: """Compare based on the underlying dict contents.""" if isinstance(other, LockedDictProxy): # Avoid deadlocks by acquiring locks in a consistent order. @@ -575,7 +590,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg] with self._lock: return self._dict == other - def __ne__(self, other: object) -> bool: # type: ignore[override] + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @@ -737,7 +752,9 @@ class Flow(Generic[T], metaclass=FlowMeta): name: str | None = None tracing: bool | None = None stream: bool = False - memory: Any = None # Memory | MemoryScope | MemorySlice | None; auto-created if not set + memory: Any = ( + None # Memory | MemoryScope | MemorySlice | None; auto-created if not set + ) input_provider: Any = None # InputProvider | None; per-flow override for self.ask() def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: @@ -881,7 +898,8 @@ class Flow(Generic[T], metaclass=FlowMeta): """ if self.memory is None: raise ValueError("No memory configured for this flow") - return self.memory.extract_memories(content) + result: list[str] = self.memory.extract_memories(content) + return result def _mark_or_listener_fired(self, listener_name: FlowMethodName) -> bool: """Mark an OR listener as fired atomically. @@ -1352,8 +1370,10 @@ class Flow(Generic[T], metaclass=FlowMeta): ValueError: If structured state model lacks 'id' field TypeError: If state is neither BaseModel nor dictionary """ + init_state = self.initial_state + # Handle case where initial_state is None but we have a type parameter - if self.initial_state is None and hasattr(self, "_initial_state_t"): + if init_state is None and hasattr(self, "_initial_state_t"): state_type = self._initial_state_t if isinstance(state_type, type): if issubclass(state_type, FlowState): @@ -1377,12 +1397,12 @@ class Flow(Generic[T], metaclass=FlowMeta): return cast(T, {"id": str(uuid4())}) # Handle case where no initial state is provided - if self.initial_state is None: + if init_state is None: return cast(T, {"id": str(uuid4())}) # Handle case where initial_state is a type (class) - if isinstance(self.initial_state, type): - state_class: type[T] = self.initial_state + if isinstance(init_state, type): + state_class = init_state if issubclass(state_class, FlowState): return state_class() if issubclass(state_class, BaseModel): @@ -1393,19 +1413,19 @@ class Flow(Generic[T], metaclass=FlowMeta): if not getattr(model_instance, "id", None): object.__setattr__(model_instance, "id", str(uuid4())) return model_instance - if self.initial_state is dict: + if init_state is dict: return cast(T, {"id": str(uuid4())}) # Handle dictionary instance case - if isinstance(self.initial_state, dict): - new_state = dict(self.initial_state) # Copy to avoid mutations + if isinstance(init_state, dict): + new_state = dict(init_state) # Copy to avoid mutations if "id" not in new_state: new_state["id"] = str(uuid4()) return cast(T, new_state) # Handle BaseModel instance case - if isinstance(self.initial_state, BaseModel): - model = cast(BaseModel, self.initial_state) + if isinstance(init_state, BaseModel): + model = cast(BaseModel, init_state) if not hasattr(model, "id"): raise ValueError("Flow state model must have an 'id' field") @@ -2277,14 +2297,23 @@ class Flow(Generic[T], metaclass=FlowMeta): router_name, router_input, current_triggering_event_id ) if router_result: # Only add non-None results - router_results.append(FlowMethodName(str(router_result))) + router_result_str = ( + router_result.value + if isinstance(router_result, enum.Enum) + else str(router_result) + ) + router_results.append(FlowMethodName(router_result_str)) # If this was a human_feedback router, map the outcome to the feedback if self.last_human_feedback is not None: - router_result_to_feedback[str(router_result)] = ( + router_result_to_feedback[router_result_str] = ( self.last_human_feedback ) current_trigger = ( - FlowMethodName(str(router_result)) + FlowMethodName( + router_result.value + if isinstance(router_result, enum.Enum) + else str(router_result) + ) if router_result is not None else FlowMethodName("") # Update for next iteration of router chain ) @@ -2701,7 +2730,10 @@ class Flow(Generic[T], metaclass=FlowMeta): return topic ``` """ - from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError + from concurrent.futures import ( + ThreadPoolExecutor, + TimeoutError as FuturesTimeoutError, + ) from datetime import datetime from crewai.events.types.flow_events import ( @@ -2770,14 +2802,16 @@ class Flow(Generic[T], metaclass=FlowMeta): response = None # Record in history - self._input_history.append({ - "message": message, - "response": response, - "method_name": method_name, - "timestamp": datetime.now(), - "metadata": metadata, - "response_metadata": response_metadata, - }) + self._input_history.append( + { + "message": message, + "response": response, + "method_name": method_name, + "timestamp": datetime.now(), + "metadata": metadata, + "response_metadata": response_metadata, + } + ) # Emit input received event crewai_event_bus.emit(