From cd42bcf035c7e5b50ca6317da712d99394c75b44 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Sun, 8 Mar 2026 23:08:10 -0400 Subject: [PATCH] refactor(memory): convert memory classes to serializable * refactor(memory): convert Memory, MemoryScope, and MemorySlice to BaseModel * fix(test): update mock memory attribute from _read_only to read_only * fix: handle re-validation in wrap validators and patch BaseModel class in tests --- .../base_agent_executor_mixin.py | 11 +- lib/crewai/src/crewai/lite_agent.py | 2 +- lib/crewai/src/crewai/memory/memory_scope.py | 116 +++++---- .../src/crewai/memory/unified_memory.py | 228 +++++++++--------- lib/crewai/src/crewai/tools/memory_tools.py | 2 +- lib/crewai/tests/agents/test_lite_agent.py | 2 +- .../tests/memory/test_unified_memory.py | 26 +- lib/crewai/tests/test_crew.py | 12 +- 8 files changed, 211 insertions(+), 188 deletions(-) diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor_mixin.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor_mixin.py index 1abfb6e5a..9dd1e2396 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor_mixin.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor_mixin.py @@ -30,12 +30,9 @@ class CrewAgentExecutorMixin: memory = getattr(self.agent, "memory", None) or ( getattr(self.crew, "_memory", None) if self.crew else None ) - if memory is None or not self.task or getattr(memory, "_read_only", False): + if memory is None or not self.task or memory.read_only: return - if ( - f"Action: {sanitize_tool_name('Delegate work to coworker')}" - in output.text - ): + if f"Action: {sanitize_tool_name('Delegate work to coworker')}" in output.text: return try: raw = ( @@ -48,6 +45,4 @@ class CrewAgentExecutorMixin: if extracted: memory.remember_many(extracted, agent_role=self.agent.role) except Exception as e: - self.agent._logger.log( - "error", f"Failed to save to memory: {e}" - ) + self.agent._logger.log("error", f"Failed to save to memory: {e}") diff --git a/lib/crewai/src/crewai/lite_agent.py b/lib/crewai/src/crewai/lite_agent.py index 66b710890..4e7d22280 100644 --- a/lib/crewai/src/crewai/lite_agent.py +++ b/lib/crewai/src/crewai/lite_agent.py @@ -600,7 +600,7 @@ class LiteAgent(FlowTrackable, BaseModel): def _save_to_memory(self, output_text: str) -> None: """Extract discrete memories from the run and remember each. No-op if _memory is None or read-only.""" - if self._memory is None or getattr(self._memory, "_read_only", False): + if self._memory is None or self._memory.read_only: return input_str = self._get_last_user_content() or "User request" try: diff --git a/lib/crewai/src/crewai/memory/memory_scope.py b/lib/crewai/src/crewai/memory/memory_scope.py index 705ec07de..6c252f9f2 100644 --- a/lib/crewai/src/crewai/memory/memory_scope.py +++ b/lib/crewai/src/crewai/memory/memory_scope.py @@ -3,11 +3,9 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING, Any +from typing import Any, Literal - -if TYPE_CHECKING: - from crewai.memory.unified_memory import Memory +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from crewai.memory.types import ( _RECALL_OVERSAMPLE_FACTOR, @@ -15,22 +13,38 @@ from crewai.memory.types import ( MemoryRecord, ScopeInfo, ) +from crewai.memory.unified_memory import Memory -class MemoryScope: +class MemoryScope(BaseModel): """View of Memory restricted to a root path. All operations are scoped under that path.""" - def __init__(self, memory: Memory, root_path: str) -> None: - """Initialize scope. + model_config = ConfigDict(arbitrary_types_allowed=True) - Args: - memory: The underlying Memory instance. - root_path: Root path for this scope (e.g. /agent/1). - """ - self._memory = memory - self._root = root_path.rstrip("/") or "" - if self._root and not self._root.startswith("/"): - self._root = "/" + self._root + root_path: str = Field(default="/") + + _memory: Memory = PrivateAttr() + _root: str = PrivateAttr() + + @model_validator(mode="wrap") + @classmethod + def _accept_memory(cls, data: Any, handler: Any) -> MemoryScope: + """Extract memory dependency and normalize root path before validation.""" + if isinstance(data, MemoryScope): + return data + memory = data.pop("memory") + instance: MemoryScope = handler(data) + instance._memory = memory + root = instance.root_path.rstrip("/") or "" + if root and not root.startswith("/"): + root = "/" + root + instance._root = root + return instance + + @property + def read_only(self) -> bool: + """Whether the underlying memory is read-only.""" + return self._memory.read_only def _scope_path(self, scope: str | None) -> str: if not scope or scope == "/": @@ -52,7 +66,7 @@ class MemoryScope: importance: float | None = None, source: str | None = None, private: bool = False, - ) -> MemoryRecord: + ) -> MemoryRecord | None: """Remember content; scope is relative to this scope's root.""" path = self._scope_path(scope) return self._memory.remember( @@ -71,7 +85,7 @@ class MemoryScope: scope: str | None = None, categories: list[str] | None = None, limit: int = 10, - depth: str = "deep", + depth: Literal["shallow", "deep"] = "deep", source: str | None = None, include_private: bool = False, ) -> list[MemoryMatch]: @@ -138,34 +152,34 @@ class MemoryScope: """Return a narrower scope under this scope.""" child = path.strip("/") if not child: - return MemoryScope(self._memory, self._root or "/") + return MemoryScope(memory=self._memory, root_path=self._root or "/") base = self._root.rstrip("/") or "" new_root = f"{base}/{child}" if base else f"/{child}" - return MemoryScope(self._memory, new_root) + return MemoryScope(memory=self._memory, root_path=new_root) -class MemorySlice: +class MemorySlice(BaseModel): """View over multiple scopes: recall searches all, remember is a no-op when read_only.""" - def __init__( - self, - memory: Memory, - scopes: list[str], - categories: list[str] | None = None, - read_only: bool = True, - ) -> None: - """Initialize slice. + model_config = ConfigDict(arbitrary_types_allowed=True) - Args: - memory: The underlying Memory instance. - scopes: List of scope paths to include. - categories: Optional category filter for recall. - read_only: If True, remember() is a silent no-op. - """ - self._memory = memory - self._scopes = [s.rstrip("/") or "/" for s in scopes] - self._categories = categories - self._read_only = read_only + scopes: list[str] = Field(default_factory=list) + categories: list[str] | None = Field(default=None) + read_only: bool = Field(default=True) + + _memory: Memory = PrivateAttr() + + @model_validator(mode="wrap") + @classmethod + def _accept_memory(cls, data: Any, handler: Any) -> MemorySlice: + """Extract memory dependency and normalize scopes before validation.""" + if isinstance(data, MemorySlice): + return data + memory = data.pop("memory") + data["scopes"] = [s.rstrip("/") or "/" for s in data.get("scopes", [])] + instance: MemorySlice = handler(data) + instance._memory = memory + return instance def remember( self, @@ -178,7 +192,7 @@ class MemorySlice: private: bool = False, ) -> MemoryRecord | None: """Remember into an explicit scope. No-op when read_only=True.""" - if self._read_only: + if self.read_only: return None return self._memory.remember( content, @@ -196,14 +210,14 @@ class MemorySlice: scope: str | None = None, categories: list[str] | None = None, limit: int = 10, - depth: str = "deep", + depth: Literal["shallow", "deep"] = "deep", source: str | None = None, include_private: bool = False, ) -> list[MemoryMatch]: """Recall across all slice scopes; results merged and re-ranked.""" - cats = categories or self._categories + cats = categories or self.categories all_matches: list[MemoryMatch] = [] - for sc in self._scopes: + for sc in self.scopes: matches = self._memory.recall( query, scope=sc, @@ -231,7 +245,7 @@ class MemorySlice: def list_scopes(self, path: str = "/") -> list[str]: """List scopes across all slice roots.""" out: list[str] = [] - for sc in self._scopes: + for sc in self.scopes: full = f"{sc.rstrip('/')}{path}" if sc != "/" else path out.extend(self._memory.list_scopes(full)) return sorted(set(out)) @@ -243,15 +257,23 @@ class MemorySlice: oldest: datetime | None = None newest: datetime | None = None children: list[str] = [] - for sc in self._scopes: + for sc in self.scopes: full = f"{sc.rstrip('/')}{path}" if sc != "/" else path inf = self._memory.info(full) total_records += inf.record_count all_categories.update(inf.categories) if inf.oldest_record: - oldest = inf.oldest_record if oldest is None else min(oldest, inf.oldest_record) + oldest = ( + inf.oldest_record + if oldest is None + else min(oldest, inf.oldest_record) + ) if inf.newest_record: - newest = inf.newest_record if newest is None else max(newest, inf.newest_record) + newest = ( + inf.newest_record + if newest is None + else max(newest, inf.newest_record) + ) children.extend(inf.child_scopes) return ScopeInfo( path=path, @@ -265,7 +287,7 @@ class MemorySlice: def list_categories(self, path: str | None = None) -> dict[str, int]: """Categories and counts across slice scopes.""" counts: dict[str, int] = {} - for sc in self._scopes: + for sc in self.scopes: full = (f"{sc.rstrip('/')}{path}" if sc != "/" else path) if path else sc for k, v in self._memory.list_categories(full).items(): counts[k] = counts.get(k, 0) + v diff --git a/lib/crewai/src/crewai/memory/unified_memory.py b/lib/crewai/src/crewai/memory/unified_memory.py index cae9013bd..cb4954c39 100644 --- a/lib/crewai/src/crewai/memory/unified_memory.py +++ b/lib/crewai/src/crewai/memory/unified_memory.py @@ -6,7 +6,9 @@ from concurrent.futures import Future, ThreadPoolExecutor from datetime import datetime import threading import time -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Annotated, Any, Literal + +from pydantic import BaseModel, ConfigDict, Field, PlainValidator, PrivateAttr from crewai.events.event_bus import crewai_event_bus from crewai.events.types.memory_events import ( @@ -39,13 +41,18 @@ if TYPE_CHECKING: ) +def _passthrough(v: Any) -> Any: + """PlainValidator that accepts any value, bypassing strict union discrimination.""" + return v + + def _default_embedder() -> OpenAIEmbeddingFunction: """Build default OpenAI embedder for memory.""" spec: OpenAIProviderSpec = {"provider": "openai", "config": {}} return build_embedder(spec) -class Memory: +class Memory(BaseModel): """Unified memory: standalone, LLM-analyzed, with intelligent recall flow. Works without agent/crew. Uses LLM to infer scope, categories, importance on save. @@ -53,116 +60,119 @@ class Memory: pluggable storage (LanceDB default). """ - def __init__( - self, - llm: BaseLLM | str = "gpt-4o-mini", - storage: StorageBackend | str = "lancedb", - embedder: Any = None, - # -- Scoring weights -- - # These three weights control how recall results are ranked. - # The composite score is: semantic_weight * similarity + recency_weight * decay + importance_weight * importance. - # They should sum to ~1.0 for intuitive scoring. - recency_weight: float = 0.3, - semantic_weight: float = 0.5, - importance_weight: float = 0.2, - # How quickly old memories lose relevance. The recency score halves every - # N days (exponential decay). Lower = faster forgetting; higher = longer relevance. - recency_half_life_days: int = 30, - # -- Consolidation -- - # When remembering new content, if an existing record has similarity >= this - # threshold, the LLM is asked to merge/update/delete. Set to 1.0 to disable. - consolidation_threshold: float = 0.85, - # Max existing records to compare against when checking for consolidation. - consolidation_limit: int = 5, - # -- Save defaults -- - # Importance assigned to new memories when no explicit value is given and - # the LLM analysis path is skipped (all fields provided by the caller). - default_importance: float = 0.5, - # -- Recall depth control -- - # These thresholds govern the RecallFlow router that decides between - # returning results immediately ("synthesize") vs. doing an extra - # LLM-driven exploration round ("explore_deeper"). - # confidence >= confidence_threshold_high => always synthesize - # confidence < confidence_threshold_low => explore deeper (if budget > 0) - # complex query + confidence < complex_query_threshold => explore deeper - confidence_threshold_high: float = 0.8, - confidence_threshold_low: float = 0.5, - complex_query_threshold: float = 0.7, - # How many LLM-driven exploration rounds the RecallFlow is allowed to run. - # 0 = always shallow (vector search only); higher = more thorough but slower. - exploration_budget: int = 1, - # Queries shorter than this skip LLM analysis (saving ~1-3s). - # Longer queries (full task descriptions) benefit from LLM distillation. - query_analysis_threshold: int = 200, - # When True, all write operations (remember, remember_many) are silently - # skipped. Useful for sharing a read-only view of memory across agents - # without any of them persisting new memories. - read_only: bool = False, - ) -> None: - """Initialize Memory. + model_config = ConfigDict(arbitrary_types_allowed=True) - Args: - llm: LLM for analysis (model name or BaseLLM instance). - storage: Backend: "lancedb" or a StorageBackend instance. - embedder: Embedding callable, provider config dict, or None (default OpenAI). - recency_weight: Weight for recency in the composite relevance score. - semantic_weight: Weight for semantic similarity in the composite relevance score. - importance_weight: Weight for importance in the composite relevance score. - recency_half_life_days: Recency score halves every N days (exponential decay). - consolidation_threshold: Similarity above which consolidation is triggered on save. - consolidation_limit: Max existing records to compare during consolidation. - default_importance: Default importance when not provided or inferred. - confidence_threshold_high: Recall confidence above which results are returned directly. - confidence_threshold_low: Recall confidence below which deeper exploration is triggered. - complex_query_threshold: For complex queries, explore deeper below this confidence. - exploration_budget: Number of LLM-driven exploration rounds during deep recall. - query_analysis_threshold: Queries shorter than this skip LLM analysis during deep recall. - read_only: If True, remember() and remember_many() are silent no-ops. - """ - self._read_only = read_only + llm: Annotated[BaseLLM | str, PlainValidator(_passthrough)] = Field( + default="gpt-4o-mini", + description="LLM for analysis (model name or BaseLLM instance).", + ) + storage: Annotated[StorageBackend | str, PlainValidator(_passthrough)] = Field( + default="lancedb", + description="Storage backend instance or path string.", + ) + embedder: Any = Field( + default=None, + description="Embedding callable, provider config dict, or None for default OpenAI.", + ) + recency_weight: float = Field( + default=0.3, + description="Weight for recency in the composite relevance score.", + ) + semantic_weight: float = Field( + default=0.5, + description="Weight for semantic similarity in the composite relevance score.", + ) + importance_weight: float = Field( + default=0.2, + description="Weight for importance in the composite relevance score.", + ) + recency_half_life_days: int = Field( + default=30, + description="Recency score halves every N days (exponential decay).", + ) + consolidation_threshold: float = Field( + default=0.85, + description="Similarity above which consolidation is triggered on save.", + ) + consolidation_limit: int = Field( + default=5, + description="Max existing records to compare during consolidation.", + ) + default_importance: float = Field( + default=0.5, + description="Default importance when not provided or inferred.", + ) + confidence_threshold_high: float = Field( + default=0.8, + description="Recall confidence above which results are returned directly.", + ) + confidence_threshold_low: float = Field( + default=0.5, + description="Recall confidence below which deeper exploration is triggered.", + ) + complex_query_threshold: float = Field( + default=0.7, + description="For complex queries, explore deeper below this confidence.", + ) + exploration_budget: int = Field( + default=1, + description="Number of LLM-driven exploration rounds during deep recall.", + ) + query_analysis_threshold: int = Field( + default=200, + description="Queries shorter than this skip LLM analysis during deep recall.", + ) + read_only: bool = Field( + default=False, + description="If True, remember() and remember_many() are silent no-ops.", + ) + + _config: MemoryConfig = PrivateAttr() + _llm_instance: BaseLLM | None = PrivateAttr(default=None) + _embedder_instance: Any = PrivateAttr(default=None) + _storage: StorageBackend = PrivateAttr() + _save_pool: ThreadPoolExecutor = PrivateAttr( + default_factory=lambda: ThreadPoolExecutor( + max_workers=1, thread_name_prefix="memory-save" + ) + ) + _pending_saves: list[Future[Any]] = PrivateAttr(default_factory=list) + _pending_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) + + def model_post_init(self, __context: Any) -> None: + """Initialize runtime state from field values.""" self._config = MemoryConfig( - recency_weight=recency_weight, - semantic_weight=semantic_weight, - importance_weight=importance_weight, - recency_half_life_days=recency_half_life_days, - consolidation_threshold=consolidation_threshold, - consolidation_limit=consolidation_limit, - default_importance=default_importance, - confidence_threshold_high=confidence_threshold_high, - confidence_threshold_low=confidence_threshold_low, - complex_query_threshold=complex_query_threshold, - exploration_budget=exploration_budget, - query_analysis_threshold=query_analysis_threshold, + recency_weight=self.recency_weight, + semantic_weight=self.semantic_weight, + importance_weight=self.importance_weight, + recency_half_life_days=self.recency_half_life_days, + consolidation_threshold=self.consolidation_threshold, + consolidation_limit=self.consolidation_limit, + default_importance=self.default_importance, + confidence_threshold_high=self.confidence_threshold_high, + confidence_threshold_low=self.confidence_threshold_low, + complex_query_threshold=self.complex_query_threshold, + exploration_budget=self.exploration_budget, + query_analysis_threshold=self.query_analysis_threshold, ) - # Store raw config for lazy initialization. LLM and embedder are only - # built on first access so that Memory() never fails at construction - # time (e.g. when auto-created by Flow without an API key set). - self._llm_config: BaseLLM | str = llm - self._llm_instance: BaseLLM | None = None if isinstance(llm, str) else llm - self._embedder_config: Any = embedder - self._embedder_instance: Any = ( - embedder - if (embedder is not None and not isinstance(embedder, dict)) + self._llm_instance = None if isinstance(self.llm, str) else self.llm + self._embedder_instance = ( + self.embedder + if (self.embedder is not None and not isinstance(self.embedder, dict)) else None ) - if isinstance(storage, str): + if isinstance(self.storage, str): from crewai.memory.storage.lancedb_storage import LanceDBStorage - self._storage = LanceDBStorage() if storage == "lancedb" else LanceDBStorage(path=storage) + self._storage = ( + LanceDBStorage() + if self.storage == "lancedb" + else LanceDBStorage(path=self.storage) + ) else: - self._storage = storage - - # Background save queue. max_workers=1 serializes saves to avoid - # concurrent storage mutations (two saves finding the same similar - # record and both trying to update/delete it). Within each save, - # the parallel LLM calls still run on their own thread pool. - self._save_pool = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="memory-save" - ) - self._pending_saves: list[Future[Any]] = [] - self._pending_lock = threading.Lock() + self._storage = self.storage _MEMORY_DOCS_URL = "https://docs.crewai.com/concepts/memory" @@ -173,11 +183,7 @@ class Memory: from crewai.llm import LLM try: - model_name = ( - self._llm_config - if isinstance(self._llm_config, str) - else str(self._llm_config) - ) + model_name = self.llm if isinstance(self.llm, str) else str(self.llm) self._llm_instance = LLM(model=model_name) except Exception as e: raise RuntimeError( @@ -197,8 +203,8 @@ class Memory: """Lazy embedder initialization -- only created when first needed.""" if self._embedder_instance is None: try: - if isinstance(self._embedder_config, dict): - self._embedder_instance = build_embedder(self._embedder_config) + if isinstance(self.embedder, dict): + self._embedder_instance = build_embedder(self.embedder) else: self._embedder_instance = _default_embedder() except Exception as e: @@ -356,7 +362,7 @@ class Memory: Raises: Exception: On save failure (events emitted). """ - if self._read_only: + if self.read_only: return None _source_type = "unified_memory" try: @@ -444,7 +450,7 @@ class Memory: Returns: Empty list (records are not available until the background save completes). """ - if not contents or self._read_only: + if not contents or self.read_only: return [] self._submit_save( diff --git a/lib/crewai/src/crewai/tools/memory_tools.py b/lib/crewai/src/crewai/tools/memory_tools.py index 9e4df03e9..c1874a532 100644 --- a/lib/crewai/src/crewai/tools/memory_tools.py +++ b/lib/crewai/src/crewai/tools/memory_tools.py @@ -121,7 +121,7 @@ def create_memory_tools(memory: Any) -> list[BaseTool]: description=i18n.tools("recall_memory"), ), ] - if not getattr(memory, "_read_only", False): + if not memory.read_only: tools.append( RememberTool( memory=memory, diff --git a/lib/crewai/tests/agents/test_lite_agent.py b/lib/crewai/tests/agents/test_lite_agent.py index ac03ffc28..0d7093f82 100644 --- a/lib/crewai/tests/agents/test_lite_agent.py +++ b/lib/crewai/tests/agents/test_lite_agent.py @@ -1136,7 +1136,7 @@ def test_lite_agent_memory_instance_recall_and_save_called(): successful_requests=1, ) mock_memory = Mock() - mock_memory._read_only = False + mock_memory.read_only = False mock_memory.recall.return_value = [] mock_memory.extract_memories.return_value = ["Fact one.", "Fact two."] diff --git a/lib/crewai/tests/memory/test_unified_memory.py b/lib/crewai/tests/memory/test_unified_memory.py index 26e2a1929..98a041086 100644 --- a/lib/crewai/tests/memory/test_unified_memory.py +++ b/lib/crewai/tests/memory/test_unified_memory.py @@ -172,8 +172,8 @@ def test_memory_scope_slice(tmp_path: Path, mock_embedder: MagicMock) -> None: sc = mem.scope("/agent/1") assert sc._root in ("/agent/1", "/agent/1/") sl = mem.slice(["/a", "/b"], read_only=True) - assert sl._read_only is True - assert "/a" in sl._scopes and "/b" in sl._scopes + assert sl.read_only is True + assert "/a" in sl.scopes and "/b" in sl.scopes def test_memory_list_scopes_info_tree(tmp_path: Path, mock_embedder: MagicMock) -> None: @@ -198,7 +198,7 @@ def test_memory_scope_remember_recall(tmp_path: Path, mock_embedder: MagicMock) from crewai.memory.memory_scope import MemoryScope mem = Memory(storage=str(tmp_path / "db5"), llm=MagicMock(), embedder=mock_embedder) - scope = MemoryScope(mem, "/crew/1") + scope = MemoryScope(memory=mem, root_path="/crew/1") scope.remember("Scoped note", scope="/", categories=[], importance=0.5, metadata={}) results = scope.recall("note", limit=5, depth="shallow") assert len(results) >= 1 @@ -213,7 +213,7 @@ def test_memory_slice_recall(tmp_path: Path, mock_embedder: MagicMock) -> None: mem = Memory(storage=str(tmp_path / "db6"), llm=MagicMock(), embedder=mock_embedder) mem.remember("In scope A", scope="/a", categories=[], importance=0.5, metadata={}) - sl = MemorySlice(mem, ["/a"], read_only=True) + sl = MemorySlice(memory=mem, scopes=["/a"], read_only=True) matches = sl.recall("scope", limit=5, depth="shallow") assert isinstance(matches, list) @@ -223,7 +223,7 @@ def test_memory_slice_remember_is_noop_when_read_only(tmp_path: Path, mock_embed from crewai.memory.memory_scope import MemorySlice mem = Memory(storage=str(tmp_path / "db7"), llm=MagicMock(), embedder=mock_embedder) - sl = MemorySlice(mem, ["/a"], read_only=True) + sl = MemorySlice(memory=mem, scopes=["/a"], read_only=True) result = sl.remember("x", scope="/a") assert result is None assert mem.list_records() == [] @@ -319,7 +319,7 @@ def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None: from crewai.agents.parser import AgentFinish mock_memory = MagicMock() - mock_memory._read_only = False + mock_memory.read_only = False mock_memory.extract_memories.return_value = ["Fact A.", "Fact B."] mock_agent = MagicMock() @@ -360,7 +360,7 @@ def test_executor_save_to_memory_skips_delegation_output() -> None: from crewai.utilities.string_utils import sanitize_tool_name mock_memory = MagicMock() - mock_memory._read_only = False + mock_memory.read_only = False mock_agent = MagicMock() mock_agent.memory = mock_memory mock_agent._logger = MagicMock() @@ -393,7 +393,7 @@ def test_memory_scope_extract_memories_delegates() -> None: mock_memory = MagicMock() mock_memory.extract_memories.return_value = ["Scoped fact."] - scope = MemoryScope(mock_memory, "/agent/1") + scope = MemoryScope(memory=mock_memory, root_path="/agent/1") result = scope.extract_memories("Some content") mock_memory.extract_memories.assert_called_once_with("Some content") assert result == ["Scoped fact."] @@ -405,7 +405,7 @@ def test_memory_slice_extract_memories_delegates() -> None: mock_memory = MagicMock() mock_memory.extract_memories.return_value = ["Sliced fact."] - sl = MemorySlice(mock_memory, ["/a", "/b"], read_only=True) + sl = MemorySlice(memory=mock_memory, scopes=["/a", "/b"], read_only=True) result = sl.extract_memories("Some content") mock_memory.extract_memories.assert_called_once_with("Some content") assert result == ["Sliced fact."] @@ -670,10 +670,10 @@ def test_agent_kickoff_memory_recall_and_save(tmp_path: Path, mock_embedder: Mag verbose=False, ) - # Mock recall to verify it's called, but return real results - with patch.object(mem, "recall", wraps=mem.recall) as recall_mock, \ - patch.object(mem, "extract_memories", return_value=["PostgreSQL is used."]) as extract_mock, \ - patch.object(mem, "remember_many", wraps=mem.remember_many) as remember_many_mock: + # Patch on the class to avoid Pydantic BaseModel __delattr__ restriction + with patch.object(Memory, "recall", wraps=mem.recall) as recall_mock, \ + patch.object(Memory, "extract_memories", return_value=["PostgreSQL is used."]) as extract_mock, \ + patch.object(Memory, "remember_many", wraps=mem.remember_many) as remember_many_mock: result = agent.kickoff("What database do we use?") assert result is not None diff --git a/lib/crewai/tests/test_crew.py b/lib/crewai/tests/test_crew.py index 64d122a7c..adcdfda4c 100644 --- a/lib/crewai/tests/test_crew.py +++ b/lib/crewai/tests/test_crew.py @@ -36,7 +36,7 @@ from crewai.flow import Flow, start from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource from crewai.llm import LLM - +from crewai.memory.unified_memory import Memory from crewai.process import Process from crewai.project import CrewBase, agent, before_kickoff, crew, task from crewai.task import Task @@ -2618,9 +2618,9 @@ def test_memory_remember_called_after_task(): ) with patch.object( - crew._memory, "extract_memories", wraps=crew._memory.extract_memories + Memory, "extract_memories", wraps=crew._memory.extract_memories ) as extract_mock, patch.object( - crew._memory, "remember", wraps=crew._memory.remember + Memory, "remember", wraps=crew._memory.remember ) as remember_mock: crew.kickoff() @@ -4773,13 +4773,13 @@ def test_memory_remember_receives_task_content(): # Mock extract_memories to return fake memories and capture the raw input. # No wraps= needed -- the test only checks what args it receives, not the output. patch.object( - crew._memory, "extract_memories", return_value=["Fake memory."] + Memory, "extract_memories", return_value=["Fake memory."] ) as extract_mock, # Mock recall to avoid LLM calls for query analysis (not in cassette). - patch.object(crew._memory, "recall", return_value=[]), + patch.object(Memory, "recall", return_value=[]), # Mock remember_many to prevent the background save from triggering # LLM calls (field resolution) that aren't in the cassette. - patch.object(crew._memory, "remember_many", return_value=[]), + patch.object(Memory, "remember_many", return_value=[]), ): crew.kickoff()