From aadb2d6694bf66b5710a358e81abbea4fdf4bc0f Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 19 Mar 2026 19:37:14 -0400 Subject: [PATCH] refactor: type Flow memory and input_provider fields --- lib/crewai/src/crewai/flow/flow.py | 27 +++++++++++------------ lib/crewai/src/crewai/flow/flow_config.py | 10 ++++----- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index 71bd31915..b7e51edfb 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -81,6 +81,7 @@ from crewai.flow.flow_wrappers import ( SimpleFlowCondition, StartMethod, ) +from crewai.flow.input_provider import InputProvider from crewai.flow.persistence.base import FlowPersistence from crewai.flow.types import ( FlowExecutionData, @@ -99,6 +100,8 @@ from crewai.flow.utils import ( is_flow_method_name, is_simple_flow_condition, ) +from crewai.memory.memory_scope import MemoryScope, MemorySlice +from crewai.memory.unified_memory import Memory if TYPE_CHECKING: @@ -501,7 +504,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg] def index( self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None - ) -> int: # type: ignore[override] + ) -> int: if stop is None: return self._list.index(value, start) return self._list.index(value, start, stop) @@ -520,13 +523,13 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg] def copy(self) -> list[T]: return self._list.copy() - def __add__(self, other: list[T]) -> list[T]: + def __add__(self, other: list[T]) -> list[T]: # type: ignore[override] return self._list + other def __radd__(self, other: list[T]) -> list[T]: return other + self._list - def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]: + def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]: # type: ignore[override] with self._lock: self._list += list(other) return self @@ -630,13 +633,13 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg] def copy(self) -> dict[str, T]: return self._dict.copy() - def __or__(self, other: dict[str, T]) -> dict[str, T]: + def __or__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override] return self._dict | other - def __ror__(self, other: dict[str, T]) -> dict[str, T]: + def __ror__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override] return other | self._dict - def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]: + def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]: # type: ignore[override] with self._lock: self._dict |= other return self @@ -822,10 +825,8 @@ class Flow(Generic[T], metaclass=FlowMeta): name: str | None = None tracing: bool | None = None stream: bool = False - memory: Any = ( - None # Memory | MemoryScope | MemorySlice | None; auto-created if not set - ) - input_provider: Any = None # InputProvider | None; per-flow override for self.ask() + memory: Memory | MemoryScope | MemorySlice | None = None + input_provider: InputProvider | None = None def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: class _FlowGeneric(cls): # type: ignore @@ -904,8 +905,6 @@ class Flow(Generic[T], metaclass=FlowMeta): # Internal flows (RecallFlow, EncodingFlow) set _skip_auto_memory # to avoid creating a wasteful standalone Memory instance. if self.memory is None and not getattr(self, "_skip_auto_memory", False): - from crewai.memory.unified_memory import Memory - self.memory = Memory() # Register all flow-related methods @@ -955,7 +954,7 @@ class Flow(Generic[T], metaclass=FlowMeta): if self.memory is None: raise ValueError("No memory configured for this flow") if isinstance(content, list): - return self.memory.remember_many(content, **kwargs) + return self.memory.remember_many(content, **kwargs) # type: ignore[union-attr] return self.memory.remember(content, **kwargs) def extract_memories(self, content: str) -> list[str]: @@ -2725,7 +2724,7 @@ class Flow(Generic[T], metaclass=FlowMeta): # ── User Input (self.ask) ──────────────────────────────────────── - def _resolve_input_provider(self) -> Any: + def _resolve_input_provider(self) -> InputProvider: """Resolve the input provider using the priority chain. Resolution order: diff --git a/lib/crewai/src/crewai/flow/flow_config.py b/lib/crewai/src/crewai/flow/flow_config.py index a4a6bfbe4..7cb838b42 100644 --- a/lib/crewai/src/crewai/flow/flow_config.py +++ b/lib/crewai/src/crewai/flow/flow_config.py @@ -6,7 +6,7 @@ customize Flow behavior at runtime. from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -32,17 +32,17 @@ class FlowConfig: self._input_provider: InputProvider | None = None @property - def hitl_provider(self) -> Any: + def hitl_provider(self) -> HumanFeedbackProvider | None: """Get the configured HITL provider.""" return self._hitl_provider @hitl_provider.setter - def hitl_provider(self, provider: Any) -> None: + def hitl_provider(self, provider: HumanFeedbackProvider | None) -> None: """Set the HITL provider.""" self._hitl_provider = provider @property - def input_provider(self) -> Any: + def input_provider(self) -> InputProvider | None: """Get the configured input provider for ``Flow.ask()``. Returns: @@ -52,7 +52,7 @@ class FlowConfig: return self._input_provider @input_provider.setter - def input_provider(self, provider: Any) -> None: + def input_provider(self, provider: InputProvider | None) -> None: """Set the input provider for ``Flow.ask()``. Args: