fix: preserve enum type in router result; improve types

This commit is contained in:
Greyson LaLonde
2026-02-19 17:30:47 -05:00
committed by GitHub
parent d09656664d
commit 28a6b855a2

View File

@@ -10,6 +10,7 @@ import asyncio
from collections.abc import (
Callable,
ItemsView,
Iterable,
Iterator,
KeysView,
Sequence,
@@ -17,6 +18,7 @@ from collections.abc import (
)
from concurrent.futures import Future
import copy
import enum
import inspect
import logging
import threading
@@ -27,8 +29,10 @@ from typing import (
Generic,
Literal,
ParamSpec,
SupportsIndex,
TypeVar,
cast,
overload,
)
from uuid import uuid4
@@ -77,7 +81,12 @@ from crewai.flow.flow_wrappers import (
StartMethod,
)
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData, FlowMethodName, InputHistoryEntry, PendingListenerKey
from crewai.flow.types import (
FlowExecutionData,
FlowMethodName,
InputHistoryEntry,
PendingListenerKey,
)
from crewai.flow.utils import (
_extract_all_methods,
_extract_all_methods_recursive,
@@ -426,8 +435,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
"""
def __init__(self, lst: list[T], lock: threading.Lock) -> None:
# Do NOT call super().__init__() -- we don't want to copy data into
# the builtin list storage. All access goes through self._list.
super().__init__() # empty builtin list; all access goes through self._list
self._list = lst
self._lock = lock
@@ -435,11 +443,11 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
with self._lock:
self._list.append(item)
def extend(self, items: list[T]) -> None:
def extend(self, items: Iterable[T]) -> None:
with self._lock:
self._list.extend(items)
def insert(self, index: int, item: T) -> None:
def insert(self, index: SupportsIndex, item: T) -> None:
with self._lock:
self._list.insert(index, item)
@@ -447,7 +455,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
with self._lock:
self._list.remove(item)
def pop(self, index: int = -1) -> T:
def pop(self, index: SupportsIndex = -1) -> T:
with self._lock:
return self._list.pop(index)
@@ -455,15 +463,23 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
with self._lock:
self._list.clear()
def __setitem__(self, index: int, value: T) -> None:
@overload
def __setitem__(self, index: SupportsIndex, value: T) -> None: ...
@overload
def __setitem__(self, index: slice, value: Iterable[T]) -> None: ...
def __setitem__(self, index: Any, value: Any) -> None:
with self._lock:
self._list[index] = value
def __delitem__(self, index: int) -> None:
def __delitem__(self, index: SupportsIndex | slice) -> None:
with self._lock:
del self._list[index]
def __getitem__(self, index: int) -> T:
@overload
def __getitem__(self, index: SupportsIndex) -> T: ...
@overload
def __getitem__(self, index: slice) -> list[T]: ...
def __getitem__(self, index: Any) -> Any:
return self._list[index]
def __len__(self) -> int:
@@ -481,7 +497,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
def __bool__(self) -> bool:
return bool(self._list)
def __eq__(self, other: object) -> bool: # type: ignore[override]
def __eq__(self, other: object) -> bool:
"""Compare based on the underlying list contents."""
if isinstance(other, LockedListProxy):
# Avoid deadlocks by acquiring locks in a consistent order.
@@ -492,7 +508,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
with self._lock:
return self._list == other
def __ne__(self, other: object) -> bool: # type: ignore[override]
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
@@ -505,8 +521,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
"""
def __init__(self, d: dict[str, T], lock: threading.Lock) -> None:
# Do NOT call super().__init__() -- we don't want to copy data into
# the builtin dict storage. All access goes through self._dict.
super().__init__() # empty builtin dict; all access goes through self._dict
self._dict = d
self._lock = lock
@@ -518,11 +533,11 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
with self._lock:
del self._dict[key]
def pop(self, key: str, *default: T) -> T:
def pop(self, key: str, *default: T) -> T: # type: ignore[override]
with self._lock:
return self._dict.pop(key, *default)
def update(self, other: dict[str, T]) -> None:
def update(self, other: dict[str, T]) -> None: # type: ignore[override]
with self._lock:
self._dict.update(other)
@@ -530,7 +545,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
with self._lock:
self._dict.clear()
def setdefault(self, key: str, default: T) -> T:
def setdefault(self, key: str, default: T) -> T: # type: ignore[override]
with self._lock:
return self._dict.setdefault(key, default)
@@ -546,16 +561,16 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
def __contains__(self, key: object) -> bool:
return key in self._dict
def keys(self) -> KeysView[str]:
def keys(self) -> KeysView[str]: # type: ignore[override]
return self._dict.keys()
def values(self) -> ValuesView[T]:
def values(self) -> ValuesView[T]: # type: ignore[override]
return self._dict.values()
def items(self) -> ItemsView[str, T]:
def items(self) -> ItemsView[str, T]: # type: ignore[override]
return self._dict.items()
def get(self, key: str, default: T | None = None) -> T | None:
def get(self, key: str, default: T | None = None) -> T | None: # type: ignore[override]
return self._dict.get(key, default)
def __repr__(self) -> str:
@@ -564,7 +579,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
def __bool__(self) -> bool:
return bool(self._dict)
def __eq__(self, other: object) -> bool: # type: ignore[override]
def __eq__(self, other: object) -> bool:
"""Compare based on the underlying dict contents."""
if isinstance(other, LockedDictProxy):
# Avoid deadlocks by acquiring locks in a consistent order.
@@ -575,7 +590,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
with self._lock:
return self._dict == other
def __ne__(self, other: object) -> bool: # type: ignore[override]
def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
@@ -737,7 +752,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
name: str | None = None
tracing: bool | None = None
stream: bool = False
memory: Any = None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
memory: Any = (
None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
)
input_provider: Any = None # InputProvider | None; per-flow override for self.ask()
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
@@ -881,7 +898,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
"""
if self.memory is None:
raise ValueError("No memory configured for this flow")
return self.memory.extract_memories(content)
result: list[str] = self.memory.extract_memories(content)
return result
def _mark_or_listener_fired(self, listener_name: FlowMethodName) -> bool:
"""Mark an OR listener as fired atomically.
@@ -1352,8 +1370,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
ValueError: If structured state model lacks 'id' field
TypeError: If state is neither BaseModel nor dictionary
"""
init_state = self.initial_state
# Handle case where initial_state is None but we have a type parameter
if self.initial_state is None and hasattr(self, "_initial_state_t"):
if init_state is None and hasattr(self, "_initial_state_t"):
state_type = self._initial_state_t
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
@@ -1377,12 +1397,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
return cast(T, {"id": str(uuid4())})
# Handle case where no initial state is provided
if self.initial_state is None:
if init_state is None:
return cast(T, {"id": str(uuid4())})
# Handle case where initial_state is a type (class)
if isinstance(self.initial_state, type):
state_class: type[T] = self.initial_state
if isinstance(init_state, type):
state_class = init_state
if issubclass(state_class, FlowState):
return state_class()
if issubclass(state_class, BaseModel):
@@ -1393,19 +1413,19 @@ class Flow(Generic[T], metaclass=FlowMeta):
if not getattr(model_instance, "id", None):
object.__setattr__(model_instance, "id", str(uuid4()))
return model_instance
if self.initial_state is dict:
if init_state is dict:
return cast(T, {"id": str(uuid4())})
# Handle dictionary instance case
if isinstance(self.initial_state, dict):
new_state = dict(self.initial_state) # Copy to avoid mutations
if isinstance(init_state, dict):
new_state = dict(init_state) # Copy to avoid mutations
if "id" not in new_state:
new_state["id"] = str(uuid4())
return cast(T, new_state)
# Handle BaseModel instance case
if isinstance(self.initial_state, BaseModel):
model = cast(BaseModel, self.initial_state)
if isinstance(init_state, BaseModel):
model = cast(BaseModel, init_state)
if not hasattr(model, "id"):
raise ValueError("Flow state model must have an 'id' field")
@@ -2277,14 +2297,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
router_name, router_input, current_triggering_event_id
)
if router_result: # Only add non-None results
router_results.append(FlowMethodName(str(router_result)))
router_result_str = (
router_result.value
if isinstance(router_result, enum.Enum)
else str(router_result)
)
router_results.append(FlowMethodName(router_result_str))
# If this was a human_feedback router, map the outcome to the feedback
if self.last_human_feedback is not None:
router_result_to_feedback[str(router_result)] = (
router_result_to_feedback[router_result_str] = (
self.last_human_feedback
)
current_trigger = (
FlowMethodName(str(router_result))
FlowMethodName(
router_result.value
if isinstance(router_result, enum.Enum)
else str(router_result)
)
if router_result is not None
else FlowMethodName("") # Update for next iteration of router chain
)
@@ -2701,7 +2730,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
return topic
```
"""
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
from concurrent.futures import (
ThreadPoolExecutor,
TimeoutError as FuturesTimeoutError,
)
from datetime import datetime
from crewai.events.types.flow_events import (
@@ -2770,14 +2802,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
response = None
# Record in history
self._input_history.append({
"message": message,
"response": response,
"method_name": method_name,
"timestamp": datetime.now(),
"metadata": metadata,
"response_metadata": response_metadata,
})
self._input_history.append(
{
"message": message,
"response": response,
"method_name": method_name,
"timestamp": datetime.now(),
"metadata": metadata,
"response_metadata": response_metadata,
}
)
# Emit input received event
crewai_event_bus.emit(