fix: preserve enum type in router result; improve types

This commit is contained in:
Greyson LaLonde
2026-02-19 17:30:47 -05:00
committed by GitHub
parent d09656664d
commit 28a6b855a2

View File

@@ -10,6 +10,7 @@ import asyncio
from collections.abc import ( from collections.abc import (
Callable, Callable,
ItemsView, ItemsView,
Iterable,
Iterator, Iterator,
KeysView, KeysView,
Sequence, Sequence,
@@ -17,6 +18,7 @@ from collections.abc import (
) )
from concurrent.futures import Future from concurrent.futures import Future
import copy import copy
import enum
import inspect import inspect
import logging import logging
import threading import threading
@@ -27,8 +29,10 @@ from typing import (
Generic, Generic,
Literal, Literal,
ParamSpec, ParamSpec,
SupportsIndex,
TypeVar, TypeVar,
cast, cast,
overload,
) )
from uuid import uuid4 from uuid import uuid4
@@ -77,7 +81,12 @@ from crewai.flow.flow_wrappers import (
StartMethod, StartMethod,
) )
from crewai.flow.persistence.base import FlowPersistence from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData, FlowMethodName, InputHistoryEntry, PendingListenerKey from crewai.flow.types import (
FlowExecutionData,
FlowMethodName,
InputHistoryEntry,
PendingListenerKey,
)
from crewai.flow.utils import ( from crewai.flow.utils import (
_extract_all_methods, _extract_all_methods,
_extract_all_methods_recursive, _extract_all_methods_recursive,
@@ -426,8 +435,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
""" """
def __init__(self, lst: list[T], lock: threading.Lock) -> None: def __init__(self, lst: list[T], lock: threading.Lock) -> None:
# Do NOT call super().__init__() -- we don't want to copy data into super().__init__() # empty builtin list; all access goes through self._list
# the builtin list storage. All access goes through self._list.
self._list = lst self._list = lst
self._lock = lock self._lock = lock
@@ -435,11 +443,11 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
with self._lock: with self._lock:
self._list.append(item) self._list.append(item)
def extend(self, items: list[T]) -> None: def extend(self, items: Iterable[T]) -> None:
with self._lock: with self._lock:
self._list.extend(items) self._list.extend(items)
def insert(self, index: int, item: T) -> None: def insert(self, index: SupportsIndex, item: T) -> None:
with self._lock: with self._lock:
self._list.insert(index, item) self._list.insert(index, item)
@@ -447,7 +455,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
with self._lock: with self._lock:
self._list.remove(item) self._list.remove(item)
def pop(self, index: int = -1) -> T: def pop(self, index: SupportsIndex = -1) -> T:
with self._lock: with self._lock:
return self._list.pop(index) return self._list.pop(index)
@@ -455,15 +463,23 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
with self._lock: with self._lock:
self._list.clear() self._list.clear()
def __setitem__(self, index: int, value: T) -> None: @overload
def __setitem__(self, index: SupportsIndex, value: T) -> None: ...
@overload
def __setitem__(self, index: slice, value: Iterable[T]) -> None: ...
def __setitem__(self, index: Any, value: Any) -> None:
with self._lock: with self._lock:
self._list[index] = value self._list[index] = value
def __delitem__(self, index: int) -> None: def __delitem__(self, index: SupportsIndex | slice) -> None:
with self._lock: with self._lock:
del self._list[index] del self._list[index]
def __getitem__(self, index: int) -> T: @overload
def __getitem__(self, index: SupportsIndex) -> T: ...
@overload
def __getitem__(self, index: slice) -> list[T]: ...
def __getitem__(self, index: Any) -> Any:
return self._list[index] return self._list[index]
def __len__(self) -> int: def __len__(self) -> int:
@@ -481,7 +497,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
def __bool__(self) -> bool: def __bool__(self) -> bool:
return bool(self._list) return bool(self._list)
def __eq__(self, other: object) -> bool: # type: ignore[override] def __eq__(self, other: object) -> bool:
"""Compare based on the underlying list contents.""" """Compare based on the underlying list contents."""
if isinstance(other, LockedListProxy): if isinstance(other, LockedListProxy):
# Avoid deadlocks by acquiring locks in a consistent order. # Avoid deadlocks by acquiring locks in a consistent order.
@@ -492,7 +508,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
with self._lock: with self._lock:
return self._list == other return self._list == other
def __ne__(self, other: object) -> bool: # type: ignore[override] def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
@@ -505,8 +521,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
""" """
def __init__(self, d: dict[str, T], lock: threading.Lock) -> None: def __init__(self, d: dict[str, T], lock: threading.Lock) -> None:
# Do NOT call super().__init__() -- we don't want to copy data into super().__init__() # empty builtin dict; all access goes through self._dict
# the builtin dict storage. All access goes through self._dict.
self._dict = d self._dict = d
self._lock = lock self._lock = lock
@@ -518,11 +533,11 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
with self._lock: with self._lock:
del self._dict[key] del self._dict[key]
def pop(self, key: str, *default: T) -> T: def pop(self, key: str, *default: T) -> T: # type: ignore[override]
with self._lock: with self._lock:
return self._dict.pop(key, *default) return self._dict.pop(key, *default)
def update(self, other: dict[str, T]) -> None: def update(self, other: dict[str, T]) -> None: # type: ignore[override]
with self._lock: with self._lock:
self._dict.update(other) self._dict.update(other)
@@ -530,7 +545,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
with self._lock: with self._lock:
self._dict.clear() self._dict.clear()
def setdefault(self, key: str, default: T) -> T: def setdefault(self, key: str, default: T) -> T: # type: ignore[override]
with self._lock: with self._lock:
return self._dict.setdefault(key, default) return self._dict.setdefault(key, default)
@@ -546,16 +561,16 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
def __contains__(self, key: object) -> bool: def __contains__(self, key: object) -> bool:
return key in self._dict return key in self._dict
def keys(self) -> KeysView[str]: def keys(self) -> KeysView[str]: # type: ignore[override]
return self._dict.keys() return self._dict.keys()
def values(self) -> ValuesView[T]: def values(self) -> ValuesView[T]: # type: ignore[override]
return self._dict.values() return self._dict.values()
def items(self) -> ItemsView[str, T]: def items(self) -> ItemsView[str, T]: # type: ignore[override]
return self._dict.items() return self._dict.items()
def get(self, key: str, default: T | None = None) -> T | None: def get(self, key: str, default: T | None = None) -> T | None: # type: ignore[override]
return self._dict.get(key, default) return self._dict.get(key, default)
def __repr__(self) -> str: def __repr__(self) -> str:
@@ -564,7 +579,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
def __bool__(self) -> bool: def __bool__(self) -> bool:
return bool(self._dict) return bool(self._dict)
def __eq__(self, other: object) -> bool: # type: ignore[override] def __eq__(self, other: object) -> bool:
"""Compare based on the underlying dict contents.""" """Compare based on the underlying dict contents."""
if isinstance(other, LockedDictProxy): if isinstance(other, LockedDictProxy):
# Avoid deadlocks by acquiring locks in a consistent order. # Avoid deadlocks by acquiring locks in a consistent order.
@@ -575,7 +590,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
with self._lock: with self._lock:
return self._dict == other return self._dict == other
def __ne__(self, other: object) -> bool: # type: ignore[override] def __ne__(self, other: object) -> bool:
return not self.__eq__(other) return not self.__eq__(other)
@@ -737,7 +752,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
name: str | None = None name: str | None = None
tracing: bool | None = None tracing: bool | None = None
stream: bool = False stream: bool = False
memory: Any = None # Memory | MemoryScope | MemorySlice | None; auto-created if not set memory: Any = (
None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
)
input_provider: Any = None # InputProvider | None; per-flow override for self.ask() input_provider: Any = None # InputProvider | None; per-flow override for self.ask()
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
@@ -881,7 +898,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
""" """
if self.memory is None: if self.memory is None:
raise ValueError("No memory configured for this flow") raise ValueError("No memory configured for this flow")
return self.memory.extract_memories(content) result: list[str] = self.memory.extract_memories(content)
return result
def _mark_or_listener_fired(self, listener_name: FlowMethodName) -> bool: def _mark_or_listener_fired(self, listener_name: FlowMethodName) -> bool:
"""Mark an OR listener as fired atomically. """Mark an OR listener as fired atomically.
@@ -1352,8 +1370,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
ValueError: If structured state model lacks 'id' field ValueError: If structured state model lacks 'id' field
TypeError: If state is neither BaseModel nor dictionary TypeError: If state is neither BaseModel nor dictionary
""" """
init_state = self.initial_state
# Handle case where initial_state is None but we have a type parameter # Handle case where initial_state is None but we have a type parameter
if self.initial_state is None and hasattr(self, "_initial_state_t"): if init_state is None and hasattr(self, "_initial_state_t"):
state_type = self._initial_state_t state_type = self._initial_state_t
if isinstance(state_type, type): if isinstance(state_type, type):
if issubclass(state_type, FlowState): if issubclass(state_type, FlowState):
@@ -1377,12 +1397,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
return cast(T, {"id": str(uuid4())}) return cast(T, {"id": str(uuid4())})
# Handle case where no initial state is provided # Handle case where no initial state is provided
if self.initial_state is None: if init_state is None:
return cast(T, {"id": str(uuid4())}) return cast(T, {"id": str(uuid4())})
# Handle case where initial_state is a type (class) # Handle case where initial_state is a type (class)
if isinstance(self.initial_state, type): if isinstance(init_state, type):
state_class: type[T] = self.initial_state state_class = init_state
if issubclass(state_class, FlowState): if issubclass(state_class, FlowState):
return state_class() return state_class()
if issubclass(state_class, BaseModel): if issubclass(state_class, BaseModel):
@@ -1393,19 +1413,19 @@ class Flow(Generic[T], metaclass=FlowMeta):
if not getattr(model_instance, "id", None): if not getattr(model_instance, "id", None):
object.__setattr__(model_instance, "id", str(uuid4())) object.__setattr__(model_instance, "id", str(uuid4()))
return model_instance return model_instance
if self.initial_state is dict: if init_state is dict:
return cast(T, {"id": str(uuid4())}) return cast(T, {"id": str(uuid4())})
# Handle dictionary instance case # Handle dictionary instance case
if isinstance(self.initial_state, dict): if isinstance(init_state, dict):
new_state = dict(self.initial_state) # Copy to avoid mutations new_state = dict(init_state) # Copy to avoid mutations
if "id" not in new_state: if "id" not in new_state:
new_state["id"] = str(uuid4()) new_state["id"] = str(uuid4())
return cast(T, new_state) return cast(T, new_state)
# Handle BaseModel instance case # Handle BaseModel instance case
if isinstance(self.initial_state, BaseModel): if isinstance(init_state, BaseModel):
model = cast(BaseModel, self.initial_state) model = cast(BaseModel, init_state)
if not hasattr(model, "id"): if not hasattr(model, "id"):
raise ValueError("Flow state model must have an 'id' field") raise ValueError("Flow state model must have an 'id' field")
@@ -2277,14 +2297,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
router_name, router_input, current_triggering_event_id router_name, router_input, current_triggering_event_id
) )
if router_result: # Only add non-None results if router_result: # Only add non-None results
router_results.append(FlowMethodName(str(router_result))) router_result_str = (
router_result.value
if isinstance(router_result, enum.Enum)
else str(router_result)
)
router_results.append(FlowMethodName(router_result_str))
# If this was a human_feedback router, map the outcome to the feedback # If this was a human_feedback router, map the outcome to the feedback
if self.last_human_feedback is not None: if self.last_human_feedback is not None:
router_result_to_feedback[str(router_result)] = ( router_result_to_feedback[router_result_str] = (
self.last_human_feedback self.last_human_feedback
) )
current_trigger = ( current_trigger = (
FlowMethodName(str(router_result)) FlowMethodName(
router_result.value
if isinstance(router_result, enum.Enum)
else str(router_result)
)
if router_result is not None if router_result is not None
else FlowMethodName("") # Update for next iteration of router chain else FlowMethodName("") # Update for next iteration of router chain
) )
@@ -2701,7 +2730,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
return topic return topic
``` ```
""" """
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError from concurrent.futures import (
ThreadPoolExecutor,
TimeoutError as FuturesTimeoutError,
)
from datetime import datetime from datetime import datetime
from crewai.events.types.flow_events import ( from crewai.events.types.flow_events import (
@@ -2770,14 +2802,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
response = None response = None
# Record in history # Record in history
self._input_history.append({ self._input_history.append(
"message": message, {
"response": response, "message": message,
"method_name": method_name, "response": response,
"timestamp": datetime.now(), "method_name": method_name,
"metadata": metadata, "timestamp": datetime.now(),
"response_metadata": response_metadata, "metadata": metadata,
}) "response_metadata": response_metadata,
}
)
# Emit input received event # Emit input received event
crewai_event_bus.emit( crewai_event_bus.emit(