merge: resolve conflict in unified_memory.py with main

This commit is contained in:
Greyson LaLonde
2026-03-09 02:01:46 -04:00
8 changed files with 196 additions and 180 deletions

View File

@@ -30,7 +30,7 @@ class CrewAgentExecutorMixin:
memory = getattr(self.agent, "memory", None) or ( memory = getattr(self.agent, "memory", None) or (
getattr(self.crew, "_memory", None) if self.crew else None getattr(self.crew, "_memory", None) if self.crew else None
) )
if memory is None or not self.task or getattr(memory, "_read_only", False): if memory is None or not self.task or memory.read_only:
return return
if f"Action: {sanitize_tool_name('Delegate work to coworker')}" in output.text: if f"Action: {sanitize_tool_name('Delegate work to coworker')}" in output.text:
return return

View File

@@ -600,7 +600,7 @@ class LiteAgent(FlowTrackable, BaseModel):
def _save_to_memory(self, output_text: str) -> None: def _save_to_memory(self, output_text: str) -> None:
"""Extract discrete memories from the run and remember each. No-op if _memory is None or read-only.""" """Extract discrete memories from the run and remember each. No-op if _memory is None or read-only."""
if self._memory is None or getattr(self._memory, "_read_only", False): if self._memory is None or self._memory.read_only:
return return
input_str = self._get_last_user_content() or "User request" input_str = self._get_last_user_content() or "User request"
try: try:

View File

@@ -3,11 +3,9 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
if TYPE_CHECKING:
from crewai.memory.unified_memory import Memory
from crewai.memory.types import ( from crewai.memory.types import (
_RECALL_OVERSAMPLE_FACTOR, _RECALL_OVERSAMPLE_FACTOR,
@@ -15,22 +13,38 @@ from crewai.memory.types import (
MemoryRecord, MemoryRecord,
ScopeInfo, ScopeInfo,
) )
from crewai.memory.unified_memory import Memory
class MemoryScope: class MemoryScope(BaseModel):
"""View of Memory restricted to a root path. All operations are scoped under that path.""" """View of Memory restricted to a root path. All operations are scoped under that path."""
def __init__(self, memory: Memory, root_path: str) -> None: model_config = ConfigDict(arbitrary_types_allowed=True)
"""Initialize scope.
Args: root_path: str = Field(default="/")
memory: The underlying Memory instance.
root_path: Root path for this scope (e.g. /agent/1). _memory: Memory = PrivateAttr()
""" _root: str = PrivateAttr()
self._memory = memory
self._root = root_path.rstrip("/") or "" @model_validator(mode="wrap")
if self._root and not self._root.startswith("/"): @classmethod
self._root = "/" + self._root def _accept_memory(cls, data: Any, handler: Any) -> MemoryScope:
"""Extract memory dependency and normalize root path before validation."""
if isinstance(data, MemoryScope):
return data
memory = data.pop("memory")
instance: MemoryScope = handler(data)
instance._memory = memory
root = instance.root_path.rstrip("/") or ""
if root and not root.startswith("/"):
root = "/" + root
instance._root = root
return instance
@property
def read_only(self) -> bool:
"""Whether the underlying memory is read-only."""
return self._memory.read_only
def _scope_path(self, scope: str | None) -> str: def _scope_path(self, scope: str | None) -> str:
if not scope or scope == "/": if not scope or scope == "/":
@@ -52,7 +66,7 @@ class MemoryScope:
importance: float | None = None, importance: float | None = None,
source: str | None = None, source: str | None = None,
private: bool = False, private: bool = False,
) -> MemoryRecord: ) -> MemoryRecord | None:
"""Remember content; scope is relative to this scope's root.""" """Remember content; scope is relative to this scope's root."""
path = self._scope_path(scope) path = self._scope_path(scope)
return self._memory.remember( return self._memory.remember(
@@ -71,7 +85,7 @@ class MemoryScope:
scope: str | None = None, scope: str | None = None,
categories: list[str] | None = None, categories: list[str] | None = None,
limit: int = 10, limit: int = 10,
depth: str = "deep", depth: Literal["shallow", "deep"] = "deep",
source: str | None = None, source: str | None = None,
include_private: bool = False, include_private: bool = False,
) -> list[MemoryMatch]: ) -> list[MemoryMatch]:
@@ -138,34 +152,34 @@ class MemoryScope:
"""Return a narrower scope under this scope.""" """Return a narrower scope under this scope."""
child = path.strip("/") child = path.strip("/")
if not child: if not child:
return MemoryScope(self._memory, self._root or "/") return MemoryScope(memory=self._memory, root_path=self._root or "/")
base = self._root.rstrip("/") or "" base = self._root.rstrip("/") or ""
new_root = f"{base}/{child}" if base else f"/{child}" new_root = f"{base}/{child}" if base else f"/{child}"
return MemoryScope(self._memory, new_root) return MemoryScope(memory=self._memory, root_path=new_root)
class MemorySlice: class MemorySlice(BaseModel):
"""View over multiple scopes: recall searches all, remember is a no-op when read_only.""" """View over multiple scopes: recall searches all, remember is a no-op when read_only."""
def __init__( model_config = ConfigDict(arbitrary_types_allowed=True)
self,
memory: Memory,
scopes: list[str],
categories: list[str] | None = None,
read_only: bool = True,
) -> None:
"""Initialize slice.
Args: scopes: list[str] = Field(default_factory=list)
memory: The underlying Memory instance. categories: list[str] | None = Field(default=None)
scopes: List of scope paths to include. read_only: bool = Field(default=True)
categories: Optional category filter for recall.
read_only: If True, remember() is a silent no-op. _memory: Memory = PrivateAttr()
"""
self._memory = memory @model_validator(mode="wrap")
self._scopes = [s.rstrip("/") or "/" for s in scopes] @classmethod
self._categories = categories def _accept_memory(cls, data: Any, handler: Any) -> MemorySlice:
self._read_only = read_only """Extract memory dependency and normalize scopes before validation."""
if isinstance(data, MemorySlice):
return data
memory = data.pop("memory")
data["scopes"] = [s.rstrip("/") or "/" for s in data.get("scopes", [])]
instance: MemorySlice = handler(data)
instance._memory = memory
return instance
def remember( def remember(
self, self,
@@ -178,7 +192,7 @@ class MemorySlice:
private: bool = False, private: bool = False,
) -> MemoryRecord | None: ) -> MemoryRecord | None:
"""Remember into an explicit scope. No-op when read_only=True.""" """Remember into an explicit scope. No-op when read_only=True."""
if self._read_only: if self.read_only:
return None return None
return self._memory.remember( return self._memory.remember(
content, content,
@@ -196,14 +210,14 @@ class MemorySlice:
scope: str | None = None, scope: str | None = None,
categories: list[str] | None = None, categories: list[str] | None = None,
limit: int = 10, limit: int = 10,
depth: str = "deep", depth: Literal["shallow", "deep"] = "deep",
source: str | None = None, source: str | None = None,
include_private: bool = False, include_private: bool = False,
) -> list[MemoryMatch]: ) -> list[MemoryMatch]:
"""Recall across all slice scopes; results merged and re-ranked.""" """Recall across all slice scopes; results merged and re-ranked."""
cats = categories or self._categories cats = categories or self.categories
all_matches: list[MemoryMatch] = [] all_matches: list[MemoryMatch] = []
for sc in self._scopes: for sc in self.scopes:
matches = self._memory.recall( matches = self._memory.recall(
query, query,
scope=sc, scope=sc,
@@ -231,7 +245,7 @@ class MemorySlice:
def list_scopes(self, path: str = "/") -> list[str]: def list_scopes(self, path: str = "/") -> list[str]:
"""List scopes across all slice roots.""" """List scopes across all slice roots."""
out: list[str] = [] out: list[str] = []
for sc in self._scopes: for sc in self.scopes:
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
out.extend(self._memory.list_scopes(full)) out.extend(self._memory.list_scopes(full))
return sorted(set(out)) return sorted(set(out))
@@ -243,7 +257,7 @@ class MemorySlice:
oldest: datetime | None = None oldest: datetime | None = None
newest: datetime | None = None newest: datetime | None = None
children: list[str] = [] children: list[str] = []
for sc in self._scopes: for sc in self.scopes:
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
inf = self._memory.info(full) inf = self._memory.info(full)
total_records += inf.record_count total_records += inf.record_count
@@ -273,7 +287,7 @@ class MemorySlice:
def list_categories(self, path: str | None = None) -> dict[str, int]: def list_categories(self, path: str | None = None) -> dict[str, int]:
"""Categories and counts across slice scopes.""" """Categories and counts across slice scopes."""
counts: dict[str, int] = {} counts: dict[str, int] = {}
for sc in self._scopes: for sc in self.scopes:
full = (f"{sc.rstrip('/')}{path}" if sc != "/" else path) if path else sc full = (f"{sc.rstrip('/')}{path}" if sc != "/" else path) if path else sc
for k, v in self._memory.list_categories(full).items(): for k, v in self._memory.list_categories(full).items():
counts[k] = counts.get(k, 0) + v counts[k] = counts.get(k, 0) + v

View File

@@ -6,7 +6,9 @@ from concurrent.futures import Future, ThreadPoolExecutor
from datetime import datetime from datetime import datetime
import threading import threading
import time import time
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Annotated, Any, Literal
from pydantic import BaseModel, ConfigDict, Field, PlainValidator, PrivateAttr
from crewai.events.event_bus import crewai_event_bus from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import ( from crewai.events.types.memory_events import (
@@ -39,13 +41,18 @@ if TYPE_CHECKING:
) )
def _passthrough(v: Any) -> Any:
"""PlainValidator that accepts any value, bypassing strict union discrimination."""
return v
def _default_embedder() -> OpenAIEmbeddingFunction: def _default_embedder() -> OpenAIEmbeddingFunction:
"""Build default OpenAI embedder for memory.""" """Build default OpenAI embedder for memory."""
spec: OpenAIProviderSpec = {"provider": "openai", "config": {}} spec: OpenAIProviderSpec = {"provider": "openai", "config": {}}
return build_embedder(spec) return build_embedder(spec)
class Memory: class Memory(BaseModel):
"""Unified memory: standalone, LLM-analyzed, with intelligent recall flow. """Unified memory: standalone, LLM-analyzed, with intelligent recall flow.
Works without agent/crew. Uses LLM to infer scope, categories, importance on save. Works without agent/crew. Uses LLM to infer scope, categories, importance on save.
@@ -53,120 +60,119 @@ class Memory:
pluggable storage (LanceDB default). pluggable storage (LanceDB default).
""" """
def __init__( model_config = ConfigDict(arbitrary_types_allowed=True)
self,
llm: BaseLLM | str = "gpt-4o-mini",
storage: StorageBackend | str = "lancedb",
embedder: Any = None,
# -- Scoring weights --
# These three weights control how recall results are ranked.
# The composite score is: semantic_weight * similarity + recency_weight * decay + importance_weight * importance.
# They should sum to ~1.0 for intuitive scoring.
recency_weight: float = 0.3,
semantic_weight: float = 0.5,
importance_weight: float = 0.2,
# How quickly old memories lose relevance. The recency score halves every
# N days (exponential decay). Lower = faster forgetting; higher = longer relevance.
recency_half_life_days: int = 30,
# -- Consolidation --
# When remembering new content, if an existing record has similarity >= this
# threshold, the LLM is asked to merge/update/delete. Set to 1.0 to disable.
consolidation_threshold: float = 0.85,
# Max existing records to compare against when checking for consolidation.
consolidation_limit: int = 5,
# -- Save defaults --
# Importance assigned to new memories when no explicit value is given and
# the LLM analysis path is skipped (all fields provided by the caller).
default_importance: float = 0.5,
# -- Recall depth control --
# These thresholds govern the RecallFlow router that decides between
# returning results immediately ("synthesize") vs. doing an extra
# LLM-driven exploration round ("explore_deeper").
# confidence >= confidence_threshold_high => always synthesize
# confidence < confidence_threshold_low => explore deeper (if budget > 0)
# complex query + confidence < complex_query_threshold => explore deeper
confidence_threshold_high: float = 0.8,
confidence_threshold_low: float = 0.5,
complex_query_threshold: float = 0.7,
# How many LLM-driven exploration rounds the RecallFlow is allowed to run.
# 0 = always shallow (vector search only); higher = more thorough but slower.
exploration_budget: int = 1,
# Queries shorter than this skip LLM analysis (saving ~1-3s).
# Longer queries (full task descriptions) benefit from LLM distillation.
query_analysis_threshold: int = 200,
# When True, all write operations (remember, remember_many) are silently
# skipped. Useful for sharing a read-only view of memory across agents
# without any of them persisting new memories.
read_only: bool = False,
) -> None:
"""Initialize Memory.
Args: llm: Annotated[BaseLLM | str, PlainValidator(_passthrough)] = Field(
llm: LLM for analysis (model name or BaseLLM instance). default="gpt-4o-mini",
storage: Backend: "lancedb" or a StorageBackend instance. description="LLM for analysis (model name or BaseLLM instance).",
embedder: Embedding callable, provider config dict, or None (default OpenAI). )
recency_weight: Weight for recency in the composite relevance score. storage: Annotated[StorageBackend | str, PlainValidator(_passthrough)] = Field(
semantic_weight: Weight for semantic similarity in the composite relevance score. default="lancedb",
importance_weight: Weight for importance in the composite relevance score. description="Storage backend instance or path string.",
recency_half_life_days: Recency score halves every N days (exponential decay). )
consolidation_threshold: Similarity above which consolidation is triggered on save. embedder: Any = Field(
consolidation_limit: Max existing records to compare during consolidation. default=None,
default_importance: Default importance when not provided or inferred. description="Embedding callable, provider config dict, or None for default OpenAI.",
confidence_threshold_high: Recall confidence above which results are returned directly. )
confidence_threshold_low: Recall confidence below which deeper exploration is triggered. recency_weight: float = Field(
complex_query_threshold: For complex queries, explore deeper below this confidence. default=0.3,
exploration_budget: Number of LLM-driven exploration rounds during deep recall. description="Weight for recency in the composite relevance score.",
query_analysis_threshold: Queries shorter than this skip LLM analysis during deep recall. )
read_only: If True, remember() and remember_many() are silent no-ops. semantic_weight: float = Field(
""" default=0.5,
self._read_only = read_only description="Weight for semantic similarity in the composite relevance score.",
)
importance_weight: float = Field(
default=0.2,
description="Weight for importance in the composite relevance score.",
)
recency_half_life_days: int = Field(
default=30,
description="Recency score halves every N days (exponential decay).",
)
consolidation_threshold: float = Field(
default=0.85,
description="Similarity above which consolidation is triggered on save.",
)
consolidation_limit: int = Field(
default=5,
description="Max existing records to compare during consolidation.",
)
default_importance: float = Field(
default=0.5,
description="Default importance when not provided or inferred.",
)
confidence_threshold_high: float = Field(
default=0.8,
description="Recall confidence above which results are returned directly.",
)
confidence_threshold_low: float = Field(
default=0.5,
description="Recall confidence below which deeper exploration is triggered.",
)
complex_query_threshold: float = Field(
default=0.7,
description="For complex queries, explore deeper below this confidence.",
)
exploration_budget: int = Field(
default=1,
description="Number of LLM-driven exploration rounds during deep recall.",
)
query_analysis_threshold: int = Field(
default=200,
description="Queries shorter than this skip LLM analysis during deep recall.",
)
read_only: bool = Field(
default=False,
description="If True, remember() and remember_many() are silent no-ops.",
)
_config: MemoryConfig = PrivateAttr()
_llm_instance: BaseLLM | None = PrivateAttr(default=None)
_embedder_instance: Any = PrivateAttr(default=None)
_storage: StorageBackend = PrivateAttr()
_save_pool: ThreadPoolExecutor = PrivateAttr(
default_factory=lambda: ThreadPoolExecutor(
max_workers=1, thread_name_prefix="memory-save"
)
)
_pending_saves: list[Future[Any]] = PrivateAttr(default_factory=list)
_pending_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
def model_post_init(self, __context: Any) -> None:
"""Initialize runtime state from field values."""
self._config = MemoryConfig( self._config = MemoryConfig(
recency_weight=recency_weight, recency_weight=self.recency_weight,
semantic_weight=semantic_weight, semantic_weight=self.semantic_weight,
importance_weight=importance_weight, importance_weight=self.importance_weight,
recency_half_life_days=recency_half_life_days, recency_half_life_days=self.recency_half_life_days,
consolidation_threshold=consolidation_threshold, consolidation_threshold=self.consolidation_threshold,
consolidation_limit=consolidation_limit, consolidation_limit=self.consolidation_limit,
default_importance=default_importance, default_importance=self.default_importance,
confidence_threshold_high=confidence_threshold_high, confidence_threshold_high=self.confidence_threshold_high,
confidence_threshold_low=confidence_threshold_low, confidence_threshold_low=self.confidence_threshold_low,
complex_query_threshold=complex_query_threshold, complex_query_threshold=self.complex_query_threshold,
exploration_budget=exploration_budget, exploration_budget=self.exploration_budget,
query_analysis_threshold=query_analysis_threshold, query_analysis_threshold=self.query_analysis_threshold,
) )
# Store raw config for lazy initialization. LLM and embedder are only self._llm_instance = None if isinstance(self.llm, str) else self.llm
# built on first access so that Memory() never fails at construction self._embedder_instance = (
# time (e.g. when auto-created by Flow without an API key set). self.embedder
self._llm_config: BaseLLM | str = llm if (self.embedder is not None and not isinstance(self.embedder, dict))
self._llm_instance: BaseLLM | None = None if isinstance(llm, str) else llm
self._embedder_config: Any = embedder
self._embedder_instance: Any = (
embedder
if (embedder is not None and not isinstance(embedder, dict))
else None else None
) )
if isinstance(storage, str): if isinstance(self.storage, str):
from crewai.memory.storage.lancedb_storage import LanceDBStorage from crewai.memory.storage.lancedb_storage import LanceDBStorage
self._storage = ( self._storage = (
LanceDBStorage() LanceDBStorage()
if storage == "lancedb" if self.storage == "lancedb"
else LanceDBStorage(path=storage) else LanceDBStorage(path=self.storage)
) )
else: else:
self._storage = storage self._storage = self.storage
# Background save queue. max_workers=1 serializes saves to avoid
# concurrent storage mutations (two saves finding the same similar
# record and both trying to update/delete it). Within each save,
# the parallel LLM calls still run on their own thread pool.
self._save_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="memory-save"
)
self._pending_saves: list[Future[Any]] = []
self._pending_lock = threading.Lock()
_MEMORY_DOCS_URL = "https://docs.crewai.com/concepts/memory" _MEMORY_DOCS_URL = "https://docs.crewai.com/concepts/memory"
@@ -177,11 +183,7 @@ class Memory:
from crewai.llm import LLM from crewai.llm import LLM
try: try:
model_name = ( model_name = self.llm if isinstance(self.llm, str) else str(self.llm)
self._llm_config
if isinstance(self._llm_config, str)
else str(self._llm_config)
)
self._llm_instance = LLM(model=model_name) self._llm_instance = LLM(model=model_name)
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
@@ -201,8 +203,8 @@ class Memory:
"""Lazy embedder initialization -- only created when first needed.""" """Lazy embedder initialization -- only created when first needed."""
if self._embedder_instance is None: if self._embedder_instance is None:
try: try:
if isinstance(self._embedder_config, dict): if isinstance(self.embedder, dict):
self._embedder_instance = build_embedder(self._embedder_config) self._embedder_instance = build_embedder(self.embedder)
else: else:
self._embedder_instance = _default_embedder() self._embedder_instance = _default_embedder()
except Exception as e: except Exception as e:
@@ -360,7 +362,7 @@ class Memory:
Raises: Raises:
Exception: On save failure (events emitted). Exception: On save failure (events emitted).
""" """
if self._read_only: if self.read_only:
return None return None
_source_type = "unified_memory" _source_type = "unified_memory"
try: try:
@@ -448,7 +450,7 @@ class Memory:
Returns: Returns:
Empty list (records are not available until the background save completes). Empty list (records are not available until the background save completes).
""" """
if not contents or self._read_only: if not contents or self.read_only:
return [] return []
self._submit_save( self._submit_save(

View File

@@ -121,7 +121,7 @@ def create_memory_tools(memory: Any) -> list[BaseTool]:
description=i18n.tools("recall_memory"), description=i18n.tools("recall_memory"),
), ),
] ]
if not getattr(memory, "_read_only", False): if not memory.read_only:
tools.append( tools.append(
RememberTool( RememberTool(
memory=memory, memory=memory,

View File

@@ -1136,7 +1136,7 @@ def test_lite_agent_memory_instance_recall_and_save_called():
successful_requests=1, successful_requests=1,
) )
mock_memory = Mock() mock_memory = Mock()
mock_memory._read_only = False mock_memory.read_only = False
mock_memory.recall.return_value = [] mock_memory.recall.return_value = []
mock_memory.extract_memories.return_value = ["Fact one.", "Fact two."] mock_memory.extract_memories.return_value = ["Fact one.", "Fact two."]

View File

@@ -172,8 +172,8 @@ def test_memory_scope_slice(tmp_path: Path, mock_embedder: MagicMock) -> None:
sc = mem.scope("/agent/1") sc = mem.scope("/agent/1")
assert sc._root in ("/agent/1", "/agent/1/") assert sc._root in ("/agent/1", "/agent/1/")
sl = mem.slice(["/a", "/b"], read_only=True) sl = mem.slice(["/a", "/b"], read_only=True)
assert sl._read_only is True assert sl.read_only is True
assert "/a" in sl._scopes and "/b" in sl._scopes assert "/a" in sl.scopes and "/b" in sl.scopes
def test_memory_list_scopes_info_tree(tmp_path: Path, mock_embedder: MagicMock) -> None: def test_memory_list_scopes_info_tree(tmp_path: Path, mock_embedder: MagicMock) -> None:
@@ -198,7 +198,7 @@ def test_memory_scope_remember_recall(tmp_path: Path, mock_embedder: MagicMock)
from crewai.memory.memory_scope import MemoryScope from crewai.memory.memory_scope import MemoryScope
mem = Memory(storage=str(tmp_path / "db5"), llm=MagicMock(), embedder=mock_embedder) mem = Memory(storage=str(tmp_path / "db5"), llm=MagicMock(), embedder=mock_embedder)
scope = MemoryScope(mem, "/crew/1") scope = MemoryScope(memory=mem, root_path="/crew/1")
scope.remember("Scoped note", scope="/", categories=[], importance=0.5, metadata={}) scope.remember("Scoped note", scope="/", categories=[], importance=0.5, metadata={})
results = scope.recall("note", limit=5, depth="shallow") results = scope.recall("note", limit=5, depth="shallow")
assert len(results) >= 1 assert len(results) >= 1
@@ -213,7 +213,7 @@ def test_memory_slice_recall(tmp_path: Path, mock_embedder: MagicMock) -> None:
mem = Memory(storage=str(tmp_path / "db6"), llm=MagicMock(), embedder=mock_embedder) mem = Memory(storage=str(tmp_path / "db6"), llm=MagicMock(), embedder=mock_embedder)
mem.remember("In scope A", scope="/a", categories=[], importance=0.5, metadata={}) mem.remember("In scope A", scope="/a", categories=[], importance=0.5, metadata={})
sl = MemorySlice(mem, ["/a"], read_only=True) sl = MemorySlice(memory=mem, scopes=["/a"], read_only=True)
matches = sl.recall("scope", limit=5, depth="shallow") matches = sl.recall("scope", limit=5, depth="shallow")
assert isinstance(matches, list) assert isinstance(matches, list)
@@ -223,7 +223,7 @@ def test_memory_slice_remember_is_noop_when_read_only(tmp_path: Path, mock_embed
from crewai.memory.memory_scope import MemorySlice from crewai.memory.memory_scope import MemorySlice
mem = Memory(storage=str(tmp_path / "db7"), llm=MagicMock(), embedder=mock_embedder) mem = Memory(storage=str(tmp_path / "db7"), llm=MagicMock(), embedder=mock_embedder)
sl = MemorySlice(mem, ["/a"], read_only=True) sl = MemorySlice(memory=mem, scopes=["/a"], read_only=True)
result = sl.remember("x", scope="/a") result = sl.remember("x", scope="/a")
assert result is None assert result is None
assert mem.list_records() == [] assert mem.list_records() == []
@@ -319,7 +319,7 @@ def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None:
from crewai.agents.parser import AgentFinish from crewai.agents.parser import AgentFinish
mock_memory = MagicMock() mock_memory = MagicMock()
mock_memory._read_only = False mock_memory.read_only = False
mock_memory.extract_memories.return_value = ["Fact A.", "Fact B."] mock_memory.extract_memories.return_value = ["Fact A.", "Fact B."]
mock_agent = MagicMock() mock_agent = MagicMock()
@@ -360,7 +360,7 @@ def test_executor_save_to_memory_skips_delegation_output() -> None:
from crewai.utilities.string_utils import sanitize_tool_name from crewai.utilities.string_utils import sanitize_tool_name
mock_memory = MagicMock() mock_memory = MagicMock()
mock_memory._read_only = False mock_memory.read_only = False
mock_agent = MagicMock() mock_agent = MagicMock()
mock_agent.memory = mock_memory mock_agent.memory = mock_memory
mock_agent._logger = MagicMock() mock_agent._logger = MagicMock()
@@ -393,7 +393,7 @@ def test_memory_scope_extract_memories_delegates() -> None:
mock_memory = MagicMock() mock_memory = MagicMock()
mock_memory.extract_memories.return_value = ["Scoped fact."] mock_memory.extract_memories.return_value = ["Scoped fact."]
scope = MemoryScope(mock_memory, "/agent/1") scope = MemoryScope(memory=mock_memory, root_path="/agent/1")
result = scope.extract_memories("Some content") result = scope.extract_memories("Some content")
mock_memory.extract_memories.assert_called_once_with("Some content") mock_memory.extract_memories.assert_called_once_with("Some content")
assert result == ["Scoped fact."] assert result == ["Scoped fact."]
@@ -405,7 +405,7 @@ def test_memory_slice_extract_memories_delegates() -> None:
mock_memory = MagicMock() mock_memory = MagicMock()
mock_memory.extract_memories.return_value = ["Sliced fact."] mock_memory.extract_memories.return_value = ["Sliced fact."]
sl = MemorySlice(mock_memory, ["/a", "/b"], read_only=True) sl = MemorySlice(memory=mock_memory, scopes=["/a", "/b"], read_only=True)
result = sl.extract_memories("Some content") result = sl.extract_memories("Some content")
mock_memory.extract_memories.assert_called_once_with("Some content") mock_memory.extract_memories.assert_called_once_with("Some content")
assert result == ["Sliced fact."] assert result == ["Sliced fact."]
@@ -670,10 +670,10 @@ def test_agent_kickoff_memory_recall_and_save(tmp_path: Path, mock_embedder: Mag
verbose=False, verbose=False,
) )
# Mock recall to verify it's called, but return real results # Patch on the class to avoid Pydantic BaseModel __delattr__ restriction
with patch.object(mem, "recall", wraps=mem.recall) as recall_mock, \ with patch.object(Memory, "recall", wraps=mem.recall) as recall_mock, \
patch.object(mem, "extract_memories", return_value=["PostgreSQL is used."]) as extract_mock, \ patch.object(Memory, "extract_memories", return_value=["PostgreSQL is used."]) as extract_mock, \
patch.object(mem, "remember_many", wraps=mem.remember_many) as remember_many_mock: patch.object(Memory, "remember_many", wraps=mem.remember_many) as remember_many_mock:
result = agent.kickoff("What database do we use?") result = agent.kickoff("What database do we use?")
assert result is not None assert result is not None

View File

@@ -36,7 +36,7 @@ from crewai.flow import Flow, start
from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.llm import LLM from crewai.llm import LLM
from crewai.memory.unified_memory import Memory
from crewai.process import Process from crewai.process import Process
from crewai.project import CrewBase, agent, before_kickoff, crew, task from crewai.project import CrewBase, agent, before_kickoff, crew, task
from crewai.task import Task from crewai.task import Task
@@ -2618,9 +2618,9 @@ def test_memory_remember_called_after_task():
) )
with patch.object( with patch.object(
crew._memory, "extract_memories", wraps=crew._memory.extract_memories Memory, "extract_memories", wraps=crew._memory.extract_memories
) as extract_mock, patch.object( ) as extract_mock, patch.object(
crew._memory, "remember", wraps=crew._memory.remember Memory, "remember", wraps=crew._memory.remember
) as remember_mock: ) as remember_mock:
crew.kickoff() crew.kickoff()
@@ -4773,13 +4773,13 @@ def test_memory_remember_receives_task_content():
# Mock extract_memories to return fake memories and capture the raw input. # Mock extract_memories to return fake memories and capture the raw input.
# No wraps= needed -- the test only checks what args it receives, not the output. # No wraps= needed -- the test only checks what args it receives, not the output.
patch.object( patch.object(
crew._memory, "extract_memories", return_value=["Fake memory."] Memory, "extract_memories", return_value=["Fake memory."]
) as extract_mock, ) as extract_mock,
# Mock recall to avoid LLM calls for query analysis (not in cassette). # Mock recall to avoid LLM calls for query analysis (not in cassette).
patch.object(crew._memory, "recall", return_value=[]), patch.object(Memory, "recall", return_value=[]),
# Mock remember_many to prevent the background save from triggering # Mock remember_many to prevent the background save from triggering
# LLM calls (field resolution) that aren't in the cassette. # LLM calls (field resolution) that aren't in the cassette.
patch.object(crew._memory, "remember_many", return_value=[]), patch.object(Memory, "remember_many", return_value=[]),
): ):
crew.kickoff() crew.kickoff()