mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
fix: preserve enum type in router result; improve types
This commit is contained in:
@@ -10,6 +10,7 @@ import asyncio
|
|||||||
from collections.abc import (
|
from collections.abc import (
|
||||||
Callable,
|
Callable,
|
||||||
ItemsView,
|
ItemsView,
|
||||||
|
Iterable,
|
||||||
Iterator,
|
Iterator,
|
||||||
KeysView,
|
KeysView,
|
||||||
Sequence,
|
Sequence,
|
||||||
@@ -17,6 +18,7 @@ from collections.abc import (
|
|||||||
)
|
)
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
import copy
|
import copy
|
||||||
|
import enum
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
@@ -27,8 +29,10 @@ from typing import (
|
|||||||
Generic,
|
Generic,
|
||||||
Literal,
|
Literal,
|
||||||
ParamSpec,
|
ParamSpec,
|
||||||
|
SupportsIndex,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
cast,
|
cast,
|
||||||
|
overload,
|
||||||
)
|
)
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
@@ -77,7 +81,12 @@ from crewai.flow.flow_wrappers import (
|
|||||||
StartMethod,
|
StartMethod,
|
||||||
)
|
)
|
||||||
from crewai.flow.persistence.base import FlowPersistence
|
from crewai.flow.persistence.base import FlowPersistence
|
||||||
from crewai.flow.types import FlowExecutionData, FlowMethodName, InputHistoryEntry, PendingListenerKey
|
from crewai.flow.types import (
|
||||||
|
FlowExecutionData,
|
||||||
|
FlowMethodName,
|
||||||
|
InputHistoryEntry,
|
||||||
|
PendingListenerKey,
|
||||||
|
)
|
||||||
from crewai.flow.utils import (
|
from crewai.flow.utils import (
|
||||||
_extract_all_methods,
|
_extract_all_methods,
|
||||||
_extract_all_methods_recursive,
|
_extract_all_methods_recursive,
|
||||||
@@ -426,8 +435,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, lst: list[T], lock: threading.Lock) -> None:
|
def __init__(self, lst: list[T], lock: threading.Lock) -> None:
|
||||||
# Do NOT call super().__init__() -- we don't want to copy data into
|
super().__init__() # empty builtin list; all access goes through self._list
|
||||||
# the builtin list storage. All access goes through self._list.
|
|
||||||
self._list = lst
|
self._list = lst
|
||||||
self._lock = lock
|
self._lock = lock
|
||||||
|
|
||||||
@@ -435,11 +443,11 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
self._list.append(item)
|
self._list.append(item)
|
||||||
|
|
||||||
def extend(self, items: list[T]) -> None:
|
def extend(self, items: Iterable[T]) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._list.extend(items)
|
self._list.extend(items)
|
||||||
|
|
||||||
def insert(self, index: int, item: T) -> None:
|
def insert(self, index: SupportsIndex, item: T) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._list.insert(index, item)
|
self._list.insert(index, item)
|
||||||
|
|
||||||
@@ -447,7 +455,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
self._list.remove(item)
|
self._list.remove(item)
|
||||||
|
|
||||||
def pop(self, index: int = -1) -> T:
|
def pop(self, index: SupportsIndex = -1) -> T:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return self._list.pop(index)
|
return self._list.pop(index)
|
||||||
|
|
||||||
@@ -455,15 +463,23 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
self._list.clear()
|
self._list.clear()
|
||||||
|
|
||||||
def __setitem__(self, index: int, value: T) -> None:
|
@overload
|
||||||
|
def __setitem__(self, index: SupportsIndex, value: T) -> None: ...
|
||||||
|
@overload
|
||||||
|
def __setitem__(self, index: slice, value: Iterable[T]) -> None: ...
|
||||||
|
def __setitem__(self, index: Any, value: Any) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._list[index] = value
|
self._list[index] = value
|
||||||
|
|
||||||
def __delitem__(self, index: int) -> None:
|
def __delitem__(self, index: SupportsIndex | slice) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
del self._list[index]
|
del self._list[index]
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> T:
|
@overload
|
||||||
|
def __getitem__(self, index: SupportsIndex) -> T: ...
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, index: slice) -> list[T]: ...
|
||||||
|
def __getitem__(self, index: Any) -> Any:
|
||||||
return self._list[index]
|
return self._list[index]
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
@@ -481,7 +497,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
|||||||
def __bool__(self) -> bool:
|
def __bool__(self) -> bool:
|
||||||
return bool(self._list)
|
return bool(self._list)
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool: # type: ignore[override]
|
def __eq__(self, other: object) -> bool:
|
||||||
"""Compare based on the underlying list contents."""
|
"""Compare based on the underlying list contents."""
|
||||||
if isinstance(other, LockedListProxy):
|
if isinstance(other, LockedListProxy):
|
||||||
# Avoid deadlocks by acquiring locks in a consistent order.
|
# Avoid deadlocks by acquiring locks in a consistent order.
|
||||||
@@ -492,7 +508,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
return self._list == other
|
return self._list == other
|
||||||
|
|
||||||
def __ne__(self, other: object) -> bool: # type: ignore[override]
|
def __ne__(self, other: object) -> bool:
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
|
||||||
@@ -505,8 +521,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, d: dict[str, T], lock: threading.Lock) -> None:
|
def __init__(self, d: dict[str, T], lock: threading.Lock) -> None:
|
||||||
# Do NOT call super().__init__() -- we don't want to copy data into
|
super().__init__() # empty builtin dict; all access goes through self._dict
|
||||||
# the builtin dict storage. All access goes through self._dict.
|
|
||||||
self._dict = d
|
self._dict = d
|
||||||
self._lock = lock
|
self._lock = lock
|
||||||
|
|
||||||
@@ -518,11 +533,11 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
del self._dict[key]
|
del self._dict[key]
|
||||||
|
|
||||||
def pop(self, key: str, *default: T) -> T:
|
def pop(self, key: str, *default: T) -> T: # type: ignore[override]
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return self._dict.pop(key, *default)
|
return self._dict.pop(key, *default)
|
||||||
|
|
||||||
def update(self, other: dict[str, T]) -> None:
|
def update(self, other: dict[str, T]) -> None: # type: ignore[override]
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._dict.update(other)
|
self._dict.update(other)
|
||||||
|
|
||||||
@@ -530,7 +545,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
self._dict.clear()
|
self._dict.clear()
|
||||||
|
|
||||||
def setdefault(self, key: str, default: T) -> T:
|
def setdefault(self, key: str, default: T) -> T: # type: ignore[override]
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return self._dict.setdefault(key, default)
|
return self._dict.setdefault(key, default)
|
||||||
|
|
||||||
@@ -546,16 +561,16 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
|
|||||||
def __contains__(self, key: object) -> bool:
|
def __contains__(self, key: object) -> bool:
|
||||||
return key in self._dict
|
return key in self._dict
|
||||||
|
|
||||||
def keys(self) -> KeysView[str]:
|
def keys(self) -> KeysView[str]: # type: ignore[override]
|
||||||
return self._dict.keys()
|
return self._dict.keys()
|
||||||
|
|
||||||
def values(self) -> ValuesView[T]:
|
def values(self) -> ValuesView[T]: # type: ignore[override]
|
||||||
return self._dict.values()
|
return self._dict.values()
|
||||||
|
|
||||||
def items(self) -> ItemsView[str, T]:
|
def items(self) -> ItemsView[str, T]: # type: ignore[override]
|
||||||
return self._dict.items()
|
return self._dict.items()
|
||||||
|
|
||||||
def get(self, key: str, default: T | None = None) -> T | None:
|
def get(self, key: str, default: T | None = None) -> T | None: # type: ignore[override]
|
||||||
return self._dict.get(key, default)
|
return self._dict.get(key, default)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@@ -564,7 +579,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
|
|||||||
def __bool__(self) -> bool:
|
def __bool__(self) -> bool:
|
||||||
return bool(self._dict)
|
return bool(self._dict)
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool: # type: ignore[override]
|
def __eq__(self, other: object) -> bool:
|
||||||
"""Compare based on the underlying dict contents."""
|
"""Compare based on the underlying dict contents."""
|
||||||
if isinstance(other, LockedDictProxy):
|
if isinstance(other, LockedDictProxy):
|
||||||
# Avoid deadlocks by acquiring locks in a consistent order.
|
# Avoid deadlocks by acquiring locks in a consistent order.
|
||||||
@@ -575,7 +590,7 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
return self._dict == other
|
return self._dict == other
|
||||||
|
|
||||||
def __ne__(self, other: object) -> bool: # type: ignore[override]
|
def __ne__(self, other: object) -> bool:
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
|
||||||
@@ -737,7 +752,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
name: str | None = None
|
name: str | None = None
|
||||||
tracing: bool | None = None
|
tracing: bool | None = None
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
memory: Any = None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
|
memory: Any = (
|
||||||
|
None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
|
||||||
|
)
|
||||||
input_provider: Any = None # InputProvider | None; per-flow override for self.ask()
|
input_provider: Any = None # InputProvider | None; per-flow override for self.ask()
|
||||||
|
|
||||||
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
|
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
|
||||||
@@ -881,7 +898,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
"""
|
"""
|
||||||
if self.memory is None:
|
if self.memory is None:
|
||||||
raise ValueError("No memory configured for this flow")
|
raise ValueError("No memory configured for this flow")
|
||||||
return self.memory.extract_memories(content)
|
result: list[str] = self.memory.extract_memories(content)
|
||||||
|
return result
|
||||||
|
|
||||||
def _mark_or_listener_fired(self, listener_name: FlowMethodName) -> bool:
|
def _mark_or_listener_fired(self, listener_name: FlowMethodName) -> bool:
|
||||||
"""Mark an OR listener as fired atomically.
|
"""Mark an OR listener as fired atomically.
|
||||||
@@ -1352,8 +1370,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
ValueError: If structured state model lacks 'id' field
|
ValueError: If structured state model lacks 'id' field
|
||||||
TypeError: If state is neither BaseModel nor dictionary
|
TypeError: If state is neither BaseModel nor dictionary
|
||||||
"""
|
"""
|
||||||
|
init_state = self.initial_state
|
||||||
|
|
||||||
# Handle case where initial_state is None but we have a type parameter
|
# Handle case where initial_state is None but we have a type parameter
|
||||||
if self.initial_state is None and hasattr(self, "_initial_state_t"):
|
if init_state is None and hasattr(self, "_initial_state_t"):
|
||||||
state_type = self._initial_state_t
|
state_type = self._initial_state_t
|
||||||
if isinstance(state_type, type):
|
if isinstance(state_type, type):
|
||||||
if issubclass(state_type, FlowState):
|
if issubclass(state_type, FlowState):
|
||||||
@@ -1377,12 +1397,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
return cast(T, {"id": str(uuid4())})
|
return cast(T, {"id": str(uuid4())})
|
||||||
|
|
||||||
# Handle case where no initial state is provided
|
# Handle case where no initial state is provided
|
||||||
if self.initial_state is None:
|
if init_state is None:
|
||||||
return cast(T, {"id": str(uuid4())})
|
return cast(T, {"id": str(uuid4())})
|
||||||
|
|
||||||
# Handle case where initial_state is a type (class)
|
# Handle case where initial_state is a type (class)
|
||||||
if isinstance(self.initial_state, type):
|
if isinstance(init_state, type):
|
||||||
state_class: type[T] = self.initial_state
|
state_class = init_state
|
||||||
if issubclass(state_class, FlowState):
|
if issubclass(state_class, FlowState):
|
||||||
return state_class()
|
return state_class()
|
||||||
if issubclass(state_class, BaseModel):
|
if issubclass(state_class, BaseModel):
|
||||||
@@ -1393,19 +1413,19 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
if not getattr(model_instance, "id", None):
|
if not getattr(model_instance, "id", None):
|
||||||
object.__setattr__(model_instance, "id", str(uuid4()))
|
object.__setattr__(model_instance, "id", str(uuid4()))
|
||||||
return model_instance
|
return model_instance
|
||||||
if self.initial_state is dict:
|
if init_state is dict:
|
||||||
return cast(T, {"id": str(uuid4())})
|
return cast(T, {"id": str(uuid4())})
|
||||||
|
|
||||||
# Handle dictionary instance case
|
# Handle dictionary instance case
|
||||||
if isinstance(self.initial_state, dict):
|
if isinstance(init_state, dict):
|
||||||
new_state = dict(self.initial_state) # Copy to avoid mutations
|
new_state = dict(init_state) # Copy to avoid mutations
|
||||||
if "id" not in new_state:
|
if "id" not in new_state:
|
||||||
new_state["id"] = str(uuid4())
|
new_state["id"] = str(uuid4())
|
||||||
return cast(T, new_state)
|
return cast(T, new_state)
|
||||||
|
|
||||||
# Handle BaseModel instance case
|
# Handle BaseModel instance case
|
||||||
if isinstance(self.initial_state, BaseModel):
|
if isinstance(init_state, BaseModel):
|
||||||
model = cast(BaseModel, self.initial_state)
|
model = cast(BaseModel, init_state)
|
||||||
if not hasattr(model, "id"):
|
if not hasattr(model, "id"):
|
||||||
raise ValueError("Flow state model must have an 'id' field")
|
raise ValueError("Flow state model must have an 'id' field")
|
||||||
|
|
||||||
@@ -2277,14 +2297,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
router_name, router_input, current_triggering_event_id
|
router_name, router_input, current_triggering_event_id
|
||||||
)
|
)
|
||||||
if router_result: # Only add non-None results
|
if router_result: # Only add non-None results
|
||||||
router_results.append(FlowMethodName(str(router_result)))
|
router_result_str = (
|
||||||
|
router_result.value
|
||||||
|
if isinstance(router_result, enum.Enum)
|
||||||
|
else str(router_result)
|
||||||
|
)
|
||||||
|
router_results.append(FlowMethodName(router_result_str))
|
||||||
# If this was a human_feedback router, map the outcome to the feedback
|
# If this was a human_feedback router, map the outcome to the feedback
|
||||||
if self.last_human_feedback is not None:
|
if self.last_human_feedback is not None:
|
||||||
router_result_to_feedback[str(router_result)] = (
|
router_result_to_feedback[router_result_str] = (
|
||||||
self.last_human_feedback
|
self.last_human_feedback
|
||||||
)
|
)
|
||||||
current_trigger = (
|
current_trigger = (
|
||||||
FlowMethodName(str(router_result))
|
FlowMethodName(
|
||||||
|
router_result.value
|
||||||
|
if isinstance(router_result, enum.Enum)
|
||||||
|
else str(router_result)
|
||||||
|
)
|
||||||
if router_result is not None
|
if router_result is not None
|
||||||
else FlowMethodName("") # Update for next iteration of router chain
|
else FlowMethodName("") # Update for next iteration of router chain
|
||||||
)
|
)
|
||||||
@@ -2701,7 +2730,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
return topic
|
return topic
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
|
from concurrent.futures import (
|
||||||
|
ThreadPoolExecutor,
|
||||||
|
TimeoutError as FuturesTimeoutError,
|
||||||
|
)
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from crewai.events.types.flow_events import (
|
from crewai.events.types.flow_events import (
|
||||||
@@ -2770,14 +2802,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
response = None
|
response = None
|
||||||
|
|
||||||
# Record in history
|
# Record in history
|
||||||
self._input_history.append({
|
self._input_history.append(
|
||||||
"message": message,
|
{
|
||||||
"response": response,
|
"message": message,
|
||||||
"method_name": method_name,
|
"response": response,
|
||||||
"timestamp": datetime.now(),
|
"method_name": method_name,
|
||||||
"metadata": metadata,
|
"timestamp": datetime.now(),
|
||||||
"response_metadata": response_metadata,
|
"metadata": metadata,
|
||||||
})
|
"response_metadata": response_metadata,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Emit input received event
|
# Emit input received event
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
|
|||||||
Reference in New Issue
Block a user