mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
refactor: type Flow memory and input_provider fields
This commit is contained in:
@@ -81,6 +81,7 @@ from crewai.flow.flow_wrappers import (
|
|||||||
SimpleFlowCondition,
|
SimpleFlowCondition,
|
||||||
StartMethod,
|
StartMethod,
|
||||||
)
|
)
|
||||||
|
from crewai.flow.input_provider import InputProvider
|
||||||
from crewai.flow.persistence.base import FlowPersistence
|
from crewai.flow.persistence.base import FlowPersistence
|
||||||
from crewai.flow.types import (
|
from crewai.flow.types import (
|
||||||
FlowExecutionData,
|
FlowExecutionData,
|
||||||
@@ -99,6 +100,8 @@ from crewai.flow.utils import (
|
|||||||
is_flow_method_name,
|
is_flow_method_name,
|
||||||
is_simple_flow_condition,
|
is_simple_flow_condition,
|
||||||
)
|
)
|
||||||
|
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||||
|
from crewai.memory.unified_memory import Memory
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -501,7 +504,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
|||||||
|
|
||||||
def index(
|
def index(
|
||||||
self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None
|
self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None
|
||||||
) -> int: # type: ignore[override]
|
) -> int:
|
||||||
if stop is None:
|
if stop is None:
|
||||||
return self._list.index(value, start)
|
return self._list.index(value, start)
|
||||||
return self._list.index(value, start, stop)
|
return self._list.index(value, start, stop)
|
||||||
@@ -520,13 +523,13 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
|||||||
def copy(self) -> list[T]:
|
def copy(self) -> list[T]:
|
||||||
return self._list.copy()
|
return self._list.copy()
|
||||||
|
|
||||||
def __add__(self, other: list[T]) -> list[T]:
|
def __add__(self, other: list[T]) -> list[T]: # type: ignore[override]
|
||||||
return self._list + other
|
return self._list + other
|
||||||
|
|
||||||
def __radd__(self, other: list[T]) -> list[T]:
|
def __radd__(self, other: list[T]) -> list[T]:
|
||||||
return other + self._list
|
return other + self._list
|
||||||
|
|
||||||
def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]:
|
def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]: # type: ignore[override]
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._list += list(other)
|
self._list += list(other)
|
||||||
return self
|
return self
|
||||||
@@ -630,13 +633,13 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
|
|||||||
def copy(self) -> dict[str, T]:
|
def copy(self) -> dict[str, T]:
|
||||||
return self._dict.copy()
|
return self._dict.copy()
|
||||||
|
|
||||||
def __or__(self, other: dict[str, T]) -> dict[str, T]:
|
def __or__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override]
|
||||||
return self._dict | other
|
return self._dict | other
|
||||||
|
|
||||||
def __ror__(self, other: dict[str, T]) -> dict[str, T]:
|
def __ror__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override]
|
||||||
return other | self._dict
|
return other | self._dict
|
||||||
|
|
||||||
def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]:
|
def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]: # type: ignore[override]
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._dict |= other
|
self._dict |= other
|
||||||
return self
|
return self
|
||||||
@@ -822,10 +825,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
name: str | None = None
|
name: str | None = None
|
||||||
tracing: bool | None = None
|
tracing: bool | None = None
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
memory: Any = (
|
memory: Memory | MemoryScope | MemorySlice | None = None
|
||||||
None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
|
input_provider: InputProvider | None = None
|
||||||
)
|
|
||||||
input_provider: Any = None # InputProvider | None; per-flow override for self.ask()
|
|
||||||
|
|
||||||
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
|
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
|
||||||
class _FlowGeneric(cls): # type: ignore
|
class _FlowGeneric(cls): # type: ignore
|
||||||
@@ -904,8 +905,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
# Internal flows (RecallFlow, EncodingFlow) set _skip_auto_memory
|
# Internal flows (RecallFlow, EncodingFlow) set _skip_auto_memory
|
||||||
# to avoid creating a wasteful standalone Memory instance.
|
# to avoid creating a wasteful standalone Memory instance.
|
||||||
if self.memory is None and not getattr(self, "_skip_auto_memory", False):
|
if self.memory is None and not getattr(self, "_skip_auto_memory", False):
|
||||||
from crewai.memory.unified_memory import Memory
|
|
||||||
|
|
||||||
self.memory = Memory()
|
self.memory = Memory()
|
||||||
|
|
||||||
# Register all flow-related methods
|
# Register all flow-related methods
|
||||||
@@ -955,7 +954,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
if self.memory is None:
|
if self.memory is None:
|
||||||
raise ValueError("No memory configured for this flow")
|
raise ValueError("No memory configured for this flow")
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
return self.memory.remember_many(content, **kwargs)
|
return self.memory.remember_many(content, **kwargs) # type: ignore[union-attr]
|
||||||
return self.memory.remember(content, **kwargs)
|
return self.memory.remember(content, **kwargs)
|
||||||
|
|
||||||
def extract_memories(self, content: str) -> list[str]:
|
def extract_memories(self, content: str) -> list[str]:
|
||||||
@@ -2725,7 +2724,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
|
|
||||||
# ── User Input (self.ask) ────────────────────────────────────────
|
# ── User Input (self.ask) ────────────────────────────────────────
|
||||||
|
|
||||||
def _resolve_input_provider(self) -> Any:
|
def _resolve_input_provider(self) -> InputProvider:
|
||||||
"""Resolve the input provider using the priority chain.
|
"""Resolve the input provider using the priority chain.
|
||||||
|
|
||||||
Resolution order:
|
Resolution order:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ customize Flow behavior at runtime.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -32,17 +32,17 @@ class FlowConfig:
|
|||||||
self._input_provider: InputProvider | None = None
|
self._input_provider: InputProvider | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hitl_provider(self) -> Any:
|
def hitl_provider(self) -> HumanFeedbackProvider | None:
|
||||||
"""Get the configured HITL provider."""
|
"""Get the configured HITL provider."""
|
||||||
return self._hitl_provider
|
return self._hitl_provider
|
||||||
|
|
||||||
@hitl_provider.setter
|
@hitl_provider.setter
|
||||||
def hitl_provider(self, provider: Any) -> None:
|
def hitl_provider(self, provider: HumanFeedbackProvider | None) -> None:
|
||||||
"""Set the HITL provider."""
|
"""Set the HITL provider."""
|
||||||
self._hitl_provider = provider
|
self._hitl_provider = provider
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_provider(self) -> Any:
|
def input_provider(self) -> InputProvider | None:
|
||||||
"""Get the configured input provider for ``Flow.ask()``.
|
"""Get the configured input provider for ``Flow.ask()``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -52,7 +52,7 @@ class FlowConfig:
|
|||||||
return self._input_provider
|
return self._input_provider
|
||||||
|
|
||||||
@input_provider.setter
|
@input_provider.setter
|
||||||
def input_provider(self, provider: Any) -> None:
|
def input_provider(self, provider: InputProvider | None) -> None:
|
||||||
"""Set the input provider for ``Flow.ask()``.
|
"""Set the input provider for ``Flow.ask()``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
Reference in New Issue
Block a user