mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-07 20:38:13 +00:00
Compare commits
1 Commits
gl/refacto
...
lg-proxy-l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
497400f59d |
@@ -30,9 +30,12 @@ class CrewAgentExecutorMixin:
|
||||
memory = getattr(self.agent, "memory", None) or (
|
||||
getattr(self.crew, "_memory", None) if self.crew else None
|
||||
)
|
||||
if memory is None or not self.task or memory.read_only:
|
||||
if memory is None or not self.task or getattr(memory, "_read_only", False):
|
||||
return
|
||||
if f"Action: {sanitize_tool_name('Delegate work to coworker')}" in output.text:
|
||||
if (
|
||||
f"Action: {sanitize_tool_name('Delegate work to coworker')}"
|
||||
in output.text
|
||||
):
|
||||
return
|
||||
try:
|
||||
raw = (
|
||||
@@ -45,4 +48,6 @@ class CrewAgentExecutorMixin:
|
||||
if extracted:
|
||||
memory.remember_many(extracted, agent_role=self.agent.role)
|
||||
except Exception as e:
|
||||
self.agent._logger.log("error", f"Failed to save to memory: {e}")
|
||||
self.agent._logger.log(
|
||||
"error", f"Failed to save to memory: {e}"
|
||||
)
|
||||
|
||||
@@ -497,6 +497,50 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._list)
|
||||
|
||||
def index(self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None) -> int: # type: ignore[override]
|
||||
if stop is None:
|
||||
return self._list.index(value, start)
|
||||
return self._list.index(value, start, stop)
|
||||
|
||||
def count(self, value: T) -> int:
|
||||
return self._list.count(value)
|
||||
|
||||
def sort(self, *, key: Any = None, reverse: bool = False) -> None:
|
||||
with self._lock:
|
||||
self._list.sort(key=key, reverse=reverse)
|
||||
|
||||
def reverse(self) -> None:
|
||||
with self._lock:
|
||||
self._list.reverse()
|
||||
|
||||
def copy(self) -> list[T]:
|
||||
return self._list.copy()
|
||||
|
||||
def __add__(self, other: list[T]) -> list[T]:
|
||||
return self._list + other
|
||||
|
||||
def __radd__(self, other: list[T]) -> list[T]:
|
||||
return other + self._list
|
||||
|
||||
def __iadd__(self, other: Iterable[T]) -> list[T]:
|
||||
with self._lock:
|
||||
self._list += list(other)
|
||||
return self._list
|
||||
|
||||
def __mul__(self, n: SupportsIndex) -> list[T]:
|
||||
return self._list * n
|
||||
|
||||
def __rmul__(self, n: SupportsIndex) -> list[T]:
|
||||
return self._list * n
|
||||
|
||||
def __imul__(self, n: SupportsIndex) -> list[T]:
|
||||
with self._lock:
|
||||
self._list *= n
|
||||
return self._list
|
||||
|
||||
def __reversed__(self) -> Iterator[T]:
|
||||
return reversed(self._list)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare based on the underlying list contents."""
|
||||
if isinstance(other, LockedListProxy):
|
||||
@@ -579,6 +623,23 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self._dict)
|
||||
|
||||
def copy(self) -> dict[str, T]:
|
||||
return self._dict.copy()
|
||||
|
||||
def __or__(self, other: dict[str, T]) -> dict[str, T]:
|
||||
return self._dict | other
|
||||
|
||||
def __ror__(self, other: dict[str, T]) -> dict[str, T]:
|
||||
return other | self._dict
|
||||
|
||||
def __ior__(self, other: dict[str, T]) -> dict[str, T]:
|
||||
with self._lock:
|
||||
self._dict |= other
|
||||
return self._dict
|
||||
|
||||
def __reversed__(self) -> Iterator[str]:
|
||||
return reversed(self._dict)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Compare based on the underlying dict contents."""
|
||||
if isinstance(other, LockedDictProxy):
|
||||
|
||||
@@ -600,7 +600,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
def _save_to_memory(self, output_text: str) -> None:
|
||||
"""Extract discrete memories from the run and remember each. No-op if _memory is None or read-only."""
|
||||
if self._memory is None or self._memory.read_only:
|
||||
if self._memory is None or getattr(self._memory, "_read_only", False):
|
||||
return
|
||||
input_str = self._get_last_user_content() or "User request"
|
||||
try:
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
from crewai.memory.types import (
|
||||
_RECALL_OVERSAMPLE_FACTOR,
|
||||
@@ -13,36 +15,22 @@ from crewai.memory.types import (
|
||||
MemoryRecord,
|
||||
ScopeInfo,
|
||||
)
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
|
||||
class MemoryScope(BaseModel):
|
||||
class MemoryScope:
|
||||
"""View of Memory restricted to a root path. All operations are scoped under that path."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
def __init__(self, memory: Memory, root_path: str) -> None:
|
||||
"""Initialize scope.
|
||||
|
||||
root_path: str = Field(default="/")
|
||||
|
||||
_memory: Memory = PrivateAttr()
|
||||
_root: str = PrivateAttr()
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
def _accept_memory(cls, data: Any, handler: Any) -> MemoryScope:
|
||||
"""Extract memory dependency and normalize root path before validation."""
|
||||
memory = data.pop("memory")
|
||||
instance: MemoryScope = handler(data)
|
||||
instance._memory = memory
|
||||
root = instance.root_path.rstrip("/") or ""
|
||||
if root and not root.startswith("/"):
|
||||
root = "/" + root
|
||||
instance._root = root
|
||||
return instance
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
"""Whether the underlying memory is read-only."""
|
||||
return self._memory.read_only
|
||||
Args:
|
||||
memory: The underlying Memory instance.
|
||||
root_path: Root path for this scope (e.g. /agent/1).
|
||||
"""
|
||||
self._memory = memory
|
||||
self._root = root_path.rstrip("/") or ""
|
||||
if self._root and not self._root.startswith("/"):
|
||||
self._root = "/" + self._root
|
||||
|
||||
def _scope_path(self, scope: str | None) -> str:
|
||||
if not scope or scope == "/":
|
||||
@@ -64,7 +52,7 @@ class MemoryScope(BaseModel):
|
||||
importance: float | None = None,
|
||||
source: str | None = None,
|
||||
private: bool = False,
|
||||
) -> MemoryRecord | None:
|
||||
) -> MemoryRecord:
|
||||
"""Remember content; scope is relative to this scope's root."""
|
||||
path = self._scope_path(scope)
|
||||
return self._memory.remember(
|
||||
@@ -83,7 +71,7 @@ class MemoryScope(BaseModel):
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
limit: int = 10,
|
||||
depth: Literal["shallow", "deep"] = "deep",
|
||||
depth: str = "deep",
|
||||
source: str | None = None,
|
||||
include_private: bool = False,
|
||||
) -> list[MemoryMatch]:
|
||||
@@ -150,32 +138,34 @@ class MemoryScope(BaseModel):
|
||||
"""Return a narrower scope under this scope."""
|
||||
child = path.strip("/")
|
||||
if not child:
|
||||
return MemoryScope(memory=self._memory, root_path=self._root or "/")
|
||||
return MemoryScope(self._memory, self._root or "/")
|
||||
base = self._root.rstrip("/") or ""
|
||||
new_root = f"{base}/{child}" if base else f"/{child}"
|
||||
return MemoryScope(memory=self._memory, root_path=new_root)
|
||||
return MemoryScope(self._memory, new_root)
|
||||
|
||||
|
||||
class MemorySlice(BaseModel):
|
||||
class MemorySlice:
|
||||
"""View over multiple scopes: recall searches all, remember is a no-op when read_only."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
def __init__(
|
||||
self,
|
||||
memory: Memory,
|
||||
scopes: list[str],
|
||||
categories: list[str] | None = None,
|
||||
read_only: bool = True,
|
||||
) -> None:
|
||||
"""Initialize slice.
|
||||
|
||||
scopes: list[str] = Field(default_factory=list)
|
||||
categories: list[str] | None = Field(default=None)
|
||||
read_only: bool = Field(default=True)
|
||||
|
||||
_memory: Memory = PrivateAttr()
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
def _accept_memory(cls, data: Any, handler: Any) -> MemorySlice:
|
||||
"""Extract memory dependency and normalize scopes before validation."""
|
||||
memory = data.pop("memory")
|
||||
data["scopes"] = [s.rstrip("/") or "/" for s in data.get("scopes", [])]
|
||||
instance: MemorySlice = handler(data)
|
||||
instance._memory = memory
|
||||
return instance
|
||||
Args:
|
||||
memory: The underlying Memory instance.
|
||||
scopes: List of scope paths to include.
|
||||
categories: Optional category filter for recall.
|
||||
read_only: If True, remember() is a silent no-op.
|
||||
"""
|
||||
self._memory = memory
|
||||
self._scopes = [s.rstrip("/") or "/" for s in scopes]
|
||||
self._categories = categories
|
||||
self._read_only = read_only
|
||||
|
||||
def remember(
|
||||
self,
|
||||
@@ -188,7 +178,7 @@ class MemorySlice(BaseModel):
|
||||
private: bool = False,
|
||||
) -> MemoryRecord | None:
|
||||
"""Remember into an explicit scope. No-op when read_only=True."""
|
||||
if self.read_only:
|
||||
if self._read_only:
|
||||
return None
|
||||
return self._memory.remember(
|
||||
content,
|
||||
@@ -206,14 +196,14 @@ class MemorySlice(BaseModel):
|
||||
scope: str | None = None,
|
||||
categories: list[str] | None = None,
|
||||
limit: int = 10,
|
||||
depth: Literal["shallow", "deep"] = "deep",
|
||||
depth: str = "deep",
|
||||
source: str | None = None,
|
||||
include_private: bool = False,
|
||||
) -> list[MemoryMatch]:
|
||||
"""Recall across all slice scopes; results merged and re-ranked."""
|
||||
cats = categories or self.categories
|
||||
cats = categories or self._categories
|
||||
all_matches: list[MemoryMatch] = []
|
||||
for sc in self.scopes:
|
||||
for sc in self._scopes:
|
||||
matches = self._memory.recall(
|
||||
query,
|
||||
scope=sc,
|
||||
@@ -241,7 +231,7 @@ class MemorySlice(BaseModel):
|
||||
def list_scopes(self, path: str = "/") -> list[str]:
|
||||
"""List scopes across all slice roots."""
|
||||
out: list[str] = []
|
||||
for sc in self.scopes:
|
||||
for sc in self._scopes:
|
||||
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
|
||||
out.extend(self._memory.list_scopes(full))
|
||||
return sorted(set(out))
|
||||
@@ -253,23 +243,15 @@ class MemorySlice(BaseModel):
|
||||
oldest: datetime | None = None
|
||||
newest: datetime | None = None
|
||||
children: list[str] = []
|
||||
for sc in self.scopes:
|
||||
for sc in self._scopes:
|
||||
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
|
||||
inf = self._memory.info(full)
|
||||
total_records += inf.record_count
|
||||
all_categories.update(inf.categories)
|
||||
if inf.oldest_record:
|
||||
oldest = (
|
||||
inf.oldest_record
|
||||
if oldest is None
|
||||
else min(oldest, inf.oldest_record)
|
||||
)
|
||||
oldest = inf.oldest_record if oldest is None else min(oldest, inf.oldest_record)
|
||||
if inf.newest_record:
|
||||
newest = (
|
||||
inf.newest_record
|
||||
if newest is None
|
||||
else max(newest, inf.newest_record)
|
||||
)
|
||||
newest = inf.newest_record if newest is None else max(newest, inf.newest_record)
|
||||
children.extend(inf.child_scopes)
|
||||
return ScopeInfo(
|
||||
path=path,
|
||||
@@ -283,7 +265,7 @@ class MemorySlice(BaseModel):
|
||||
def list_categories(self, path: str | None = None) -> dict[str, int]:
|
||||
"""Categories and counts across slice scopes."""
|
||||
counts: dict[str, int] = {}
|
||||
for sc in self.scopes:
|
||||
for sc in self._scopes:
|
||||
full = (f"{sc.rstrip('/')}{path}" if sc != "/" else path) if path else sc
|
||||
for k, v in self._memory.list_categories(full).items():
|
||||
counts[k] = counts.get(k, 0) + v
|
||||
|
||||
@@ -6,9 +6,7 @@ from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PlainValidator, PrivateAttr
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
@@ -41,18 +39,13 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
def _passthrough(v: Any) -> Any:
|
||||
"""PlainValidator that accepts any value, bypassing strict union discrimination."""
|
||||
return v
|
||||
|
||||
|
||||
def _default_embedder() -> OpenAIEmbeddingFunction:
|
||||
"""Build default OpenAI embedder for memory."""
|
||||
spec: OpenAIProviderSpec = {"provider": "openai", "config": {}}
|
||||
return build_embedder(spec)
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
class Memory:
|
||||
"""Unified memory: standalone, LLM-analyzed, with intelligent recall flow.
|
||||
|
||||
Works without agent/crew. Uses LLM to infer scope, categories, importance on save.
|
||||
@@ -60,119 +53,116 @@ class Memory(BaseModel):
|
||||
pluggable storage (LanceDB default).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM | str = "gpt-4o-mini",
|
||||
storage: StorageBackend | str = "lancedb",
|
||||
embedder: Any = None,
|
||||
# -- Scoring weights --
|
||||
# These three weights control how recall results are ranked.
|
||||
# The composite score is: semantic_weight * similarity + recency_weight * decay + importance_weight * importance.
|
||||
# They should sum to ~1.0 for intuitive scoring.
|
||||
recency_weight: float = 0.3,
|
||||
semantic_weight: float = 0.5,
|
||||
importance_weight: float = 0.2,
|
||||
# How quickly old memories lose relevance. The recency score halves every
|
||||
# N days (exponential decay). Lower = faster forgetting; higher = longer relevance.
|
||||
recency_half_life_days: int = 30,
|
||||
# -- Consolidation --
|
||||
# When remembering new content, if an existing record has similarity >= this
|
||||
# threshold, the LLM is asked to merge/update/delete. Set to 1.0 to disable.
|
||||
consolidation_threshold: float = 0.85,
|
||||
# Max existing records to compare against when checking for consolidation.
|
||||
consolidation_limit: int = 5,
|
||||
# -- Save defaults --
|
||||
# Importance assigned to new memories when no explicit value is given and
|
||||
# the LLM analysis path is skipped (all fields provided by the caller).
|
||||
default_importance: float = 0.5,
|
||||
# -- Recall depth control --
|
||||
# These thresholds govern the RecallFlow router that decides between
|
||||
# returning results immediately ("synthesize") vs. doing an extra
|
||||
# LLM-driven exploration round ("explore_deeper").
|
||||
# confidence >= confidence_threshold_high => always synthesize
|
||||
# confidence < confidence_threshold_low => explore deeper (if budget > 0)
|
||||
# complex query + confidence < complex_query_threshold => explore deeper
|
||||
confidence_threshold_high: float = 0.8,
|
||||
confidence_threshold_low: float = 0.5,
|
||||
complex_query_threshold: float = 0.7,
|
||||
# How many LLM-driven exploration rounds the RecallFlow is allowed to run.
|
||||
# 0 = always shallow (vector search only); higher = more thorough but slower.
|
||||
exploration_budget: int = 1,
|
||||
# Queries shorter than this skip LLM analysis (saving ~1-3s).
|
||||
# Longer queries (full task descriptions) benefit from LLM distillation.
|
||||
query_analysis_threshold: int = 200,
|
||||
# When True, all write operations (remember, remember_many) are silently
|
||||
# skipped. Useful for sharing a read-only view of memory across agents
|
||||
# without any of them persisting new memories.
|
||||
read_only: bool = False,
|
||||
) -> None:
|
||||
"""Initialize Memory.
|
||||
|
||||
llm: Annotated[BaseLLM | str, PlainValidator(_passthrough)] = Field(
|
||||
default="gpt-4o-mini",
|
||||
description="LLM for analysis (model name or BaseLLM instance).",
|
||||
)
|
||||
storage: Annotated[StorageBackend | str, PlainValidator(_passthrough)] = Field(
|
||||
default="lancedb",
|
||||
description="Storage backend instance or path string.",
|
||||
)
|
||||
embedder: Any = Field(
|
||||
default=None,
|
||||
description="Embedding callable, provider config dict, or None for default OpenAI.",
|
||||
)
|
||||
recency_weight: float = Field(
|
||||
default=0.3,
|
||||
description="Weight for recency in the composite relevance score.",
|
||||
)
|
||||
semantic_weight: float = Field(
|
||||
default=0.5,
|
||||
description="Weight for semantic similarity in the composite relevance score.",
|
||||
)
|
||||
importance_weight: float = Field(
|
||||
default=0.2,
|
||||
description="Weight for importance in the composite relevance score.",
|
||||
)
|
||||
recency_half_life_days: int = Field(
|
||||
default=30,
|
||||
description="Recency score halves every N days (exponential decay).",
|
||||
)
|
||||
consolidation_threshold: float = Field(
|
||||
default=0.85,
|
||||
description="Similarity above which consolidation is triggered on save.",
|
||||
)
|
||||
consolidation_limit: int = Field(
|
||||
default=5,
|
||||
description="Max existing records to compare during consolidation.",
|
||||
)
|
||||
default_importance: float = Field(
|
||||
default=0.5,
|
||||
description="Default importance when not provided or inferred.",
|
||||
)
|
||||
confidence_threshold_high: float = Field(
|
||||
default=0.8,
|
||||
description="Recall confidence above which results are returned directly.",
|
||||
)
|
||||
confidence_threshold_low: float = Field(
|
||||
default=0.5,
|
||||
description="Recall confidence below which deeper exploration is triggered.",
|
||||
)
|
||||
complex_query_threshold: float = Field(
|
||||
default=0.7,
|
||||
description="For complex queries, explore deeper below this confidence.",
|
||||
)
|
||||
exploration_budget: int = Field(
|
||||
default=1,
|
||||
description="Number of LLM-driven exploration rounds during deep recall.",
|
||||
)
|
||||
query_analysis_threshold: int = Field(
|
||||
default=200,
|
||||
description="Queries shorter than this skip LLM analysis during deep recall.",
|
||||
)
|
||||
read_only: bool = Field(
|
||||
default=False,
|
||||
description="If True, remember() and remember_many() are silent no-ops.",
|
||||
)
|
||||
|
||||
_config: MemoryConfig = PrivateAttr()
|
||||
_llm_instance: BaseLLM | None = PrivateAttr(default=None)
|
||||
_embedder_instance: Any = PrivateAttr(default=None)
|
||||
_storage: StorageBackend = PrivateAttr()
|
||||
_save_pool: ThreadPoolExecutor = PrivateAttr(
|
||||
default_factory=lambda: ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="memory-save"
|
||||
)
|
||||
)
|
||||
_pending_saves: list[Future[Any]] = PrivateAttr(default_factory=list)
|
||||
_pending_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Initialize runtime state from field values."""
|
||||
Args:
|
||||
llm: LLM for analysis (model name or BaseLLM instance).
|
||||
storage: Backend: "lancedb" or a StorageBackend instance.
|
||||
embedder: Embedding callable, provider config dict, or None (default OpenAI).
|
||||
recency_weight: Weight for recency in the composite relevance score.
|
||||
semantic_weight: Weight for semantic similarity in the composite relevance score.
|
||||
importance_weight: Weight for importance in the composite relevance score.
|
||||
recency_half_life_days: Recency score halves every N days (exponential decay).
|
||||
consolidation_threshold: Similarity above which consolidation is triggered on save.
|
||||
consolidation_limit: Max existing records to compare during consolidation.
|
||||
default_importance: Default importance when not provided or inferred.
|
||||
confidence_threshold_high: Recall confidence above which results are returned directly.
|
||||
confidence_threshold_low: Recall confidence below which deeper exploration is triggered.
|
||||
complex_query_threshold: For complex queries, explore deeper below this confidence.
|
||||
exploration_budget: Number of LLM-driven exploration rounds during deep recall.
|
||||
query_analysis_threshold: Queries shorter than this skip LLM analysis during deep recall.
|
||||
read_only: If True, remember() and remember_many() are silent no-ops.
|
||||
"""
|
||||
self._read_only = read_only
|
||||
self._config = MemoryConfig(
|
||||
recency_weight=self.recency_weight,
|
||||
semantic_weight=self.semantic_weight,
|
||||
importance_weight=self.importance_weight,
|
||||
recency_half_life_days=self.recency_half_life_days,
|
||||
consolidation_threshold=self.consolidation_threshold,
|
||||
consolidation_limit=self.consolidation_limit,
|
||||
default_importance=self.default_importance,
|
||||
confidence_threshold_high=self.confidence_threshold_high,
|
||||
confidence_threshold_low=self.confidence_threshold_low,
|
||||
complex_query_threshold=self.complex_query_threshold,
|
||||
exploration_budget=self.exploration_budget,
|
||||
query_analysis_threshold=self.query_analysis_threshold,
|
||||
recency_weight=recency_weight,
|
||||
semantic_weight=semantic_weight,
|
||||
importance_weight=importance_weight,
|
||||
recency_half_life_days=recency_half_life_days,
|
||||
consolidation_threshold=consolidation_threshold,
|
||||
consolidation_limit=consolidation_limit,
|
||||
default_importance=default_importance,
|
||||
confidence_threshold_high=confidence_threshold_high,
|
||||
confidence_threshold_low=confidence_threshold_low,
|
||||
complex_query_threshold=complex_query_threshold,
|
||||
exploration_budget=exploration_budget,
|
||||
query_analysis_threshold=query_analysis_threshold,
|
||||
)
|
||||
|
||||
self._llm_instance = None if isinstance(self.llm, str) else self.llm
|
||||
self._embedder_instance = (
|
||||
self.embedder
|
||||
if (self.embedder is not None and not isinstance(self.embedder, dict))
|
||||
# Store raw config for lazy initialization. LLM and embedder are only
|
||||
# built on first access so that Memory() never fails at construction
|
||||
# time (e.g. when auto-created by Flow without an API key set).
|
||||
self._llm_config: BaseLLM | str = llm
|
||||
self._llm_instance: BaseLLM | None = None if isinstance(llm, str) else llm
|
||||
self._embedder_config: Any = embedder
|
||||
self._embedder_instance: Any = (
|
||||
embedder
|
||||
if (embedder is not None and not isinstance(embedder, dict))
|
||||
else None
|
||||
)
|
||||
|
||||
if isinstance(self.storage, str):
|
||||
if isinstance(storage, str):
|
||||
from crewai.memory.storage.lancedb_storage import LanceDBStorage
|
||||
|
||||
self._storage = (
|
||||
LanceDBStorage()
|
||||
if self.storage == "lancedb"
|
||||
else LanceDBStorage(path=self.storage)
|
||||
)
|
||||
self._storage = LanceDBStorage() if storage == "lancedb" else LanceDBStorage(path=storage)
|
||||
else:
|
||||
self._storage = self.storage
|
||||
self._storage = storage
|
||||
|
||||
# Background save queue. max_workers=1 serializes saves to avoid
|
||||
# concurrent storage mutations (two saves finding the same similar
|
||||
# record and both trying to update/delete it). Within each save,
|
||||
# the parallel LLM calls still run on their own thread pool.
|
||||
self._save_pool = ThreadPoolExecutor(
|
||||
max_workers=1, thread_name_prefix="memory-save"
|
||||
)
|
||||
self._pending_saves: list[Future[Any]] = []
|
||||
self._pending_lock = threading.Lock()
|
||||
|
||||
_MEMORY_DOCS_URL = "https://docs.crewai.com/concepts/memory"
|
||||
|
||||
@@ -183,7 +173,11 @@ class Memory(BaseModel):
|
||||
from crewai.llm import LLM
|
||||
|
||||
try:
|
||||
model_name = self.llm if isinstance(self.llm, str) else str(self.llm)
|
||||
model_name = (
|
||||
self._llm_config
|
||||
if isinstance(self._llm_config, str)
|
||||
else str(self._llm_config)
|
||||
)
|
||||
self._llm_instance = LLM(model=model_name)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
@@ -203,8 +197,8 @@ class Memory(BaseModel):
|
||||
"""Lazy embedder initialization -- only created when first needed."""
|
||||
if self._embedder_instance is None:
|
||||
try:
|
||||
if isinstance(self.embedder, dict):
|
||||
self._embedder_instance = build_embedder(self.embedder)
|
||||
if isinstance(self._embedder_config, dict):
|
||||
self._embedder_instance = build_embedder(self._embedder_config)
|
||||
else:
|
||||
self._embedder_instance = _default_embedder()
|
||||
except Exception as e:
|
||||
@@ -362,7 +356,7 @@ class Memory(BaseModel):
|
||||
Raises:
|
||||
Exception: On save failure (events emitted).
|
||||
"""
|
||||
if self.read_only:
|
||||
if self._read_only:
|
||||
return None
|
||||
_source_type = "unified_memory"
|
||||
try:
|
||||
@@ -450,7 +444,7 @@ class Memory(BaseModel):
|
||||
Returns:
|
||||
Empty list (records are not available until the background save completes).
|
||||
"""
|
||||
if not contents or self.read_only:
|
||||
if not contents or self._read_only:
|
||||
return []
|
||||
|
||||
self._submit_save(
|
||||
|
||||
@@ -121,7 +121,7 @@ def create_memory_tools(memory: Any) -> list[BaseTool]:
|
||||
description=i18n.tools("recall_memory"),
|
||||
),
|
||||
]
|
||||
if not memory.read_only:
|
||||
if not getattr(memory, "_read_only", False):
|
||||
tools.append(
|
||||
RememberTool(
|
||||
memory=memory,
|
||||
|
||||
@@ -172,8 +172,8 @@ def test_memory_scope_slice(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
sc = mem.scope("/agent/1")
|
||||
assert sc._root in ("/agent/1", "/agent/1/")
|
||||
sl = mem.slice(["/a", "/b"], read_only=True)
|
||||
assert sl.read_only is True
|
||||
assert "/a" in sl.scopes and "/b" in sl.scopes
|
||||
assert sl._read_only is True
|
||||
assert "/a" in sl._scopes and "/b" in sl._scopes
|
||||
|
||||
|
||||
def test_memory_list_scopes_info_tree(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
@@ -198,7 +198,7 @@ def test_memory_scope_remember_recall(tmp_path: Path, mock_embedder: MagicMock)
|
||||
from crewai.memory.memory_scope import MemoryScope
|
||||
|
||||
mem = Memory(storage=str(tmp_path / "db5"), llm=MagicMock(), embedder=mock_embedder)
|
||||
scope = MemoryScope(memory=mem, root_path="/crew/1")
|
||||
scope = MemoryScope(mem, "/crew/1")
|
||||
scope.remember("Scoped note", scope="/", categories=[], importance=0.5, metadata={})
|
||||
results = scope.recall("note", limit=5, depth="shallow")
|
||||
assert len(results) >= 1
|
||||
@@ -213,7 +213,7 @@ def test_memory_slice_recall(tmp_path: Path, mock_embedder: MagicMock) -> None:
|
||||
|
||||
mem = Memory(storage=str(tmp_path / "db6"), llm=MagicMock(), embedder=mock_embedder)
|
||||
mem.remember("In scope A", scope="/a", categories=[], importance=0.5, metadata={})
|
||||
sl = MemorySlice(memory=mem, scopes=["/a"], read_only=True)
|
||||
sl = MemorySlice(mem, ["/a"], read_only=True)
|
||||
matches = sl.recall("scope", limit=5, depth="shallow")
|
||||
assert isinstance(matches, list)
|
||||
|
||||
@@ -223,7 +223,7 @@ def test_memory_slice_remember_is_noop_when_read_only(tmp_path: Path, mock_embed
|
||||
from crewai.memory.memory_scope import MemorySlice
|
||||
|
||||
mem = Memory(storage=str(tmp_path / "db7"), llm=MagicMock(), embedder=mock_embedder)
|
||||
sl = MemorySlice(memory=mem, scopes=["/a"], read_only=True)
|
||||
sl = MemorySlice(mem, ["/a"], read_only=True)
|
||||
result = sl.remember("x", scope="/a")
|
||||
assert result is None
|
||||
assert mem.list_records() == []
|
||||
@@ -319,7 +319,7 @@ def test_executor_save_to_memory_calls_extract_then_remember_per_item() -> None:
|
||||
from crewai.agents.parser import AgentFinish
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory.read_only = False
|
||||
mock_memory._read_only = False
|
||||
mock_memory.extract_memories.return_value = ["Fact A.", "Fact B."]
|
||||
|
||||
mock_agent = MagicMock()
|
||||
@@ -360,7 +360,7 @@ def test_executor_save_to_memory_skips_delegation_output() -> None:
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory.read_only = False
|
||||
mock_memory._read_only = False
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.memory = mock_memory
|
||||
mock_agent._logger = MagicMock()
|
||||
@@ -393,7 +393,7 @@ def test_memory_scope_extract_memories_delegates() -> None:
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory.extract_memories.return_value = ["Scoped fact."]
|
||||
scope = MemoryScope(memory=mock_memory, root_path="/agent/1")
|
||||
scope = MemoryScope(mock_memory, "/agent/1")
|
||||
result = scope.extract_memories("Some content")
|
||||
mock_memory.extract_memories.assert_called_once_with("Some content")
|
||||
assert result == ["Scoped fact."]
|
||||
@@ -405,7 +405,7 @@ def test_memory_slice_extract_memories_delegates() -> None:
|
||||
|
||||
mock_memory = MagicMock()
|
||||
mock_memory.extract_memories.return_value = ["Sliced fact."]
|
||||
sl = MemorySlice(memory=mock_memory, scopes=["/a", "/b"], read_only=True)
|
||||
sl = MemorySlice(mock_memory, ["/a", "/b"], read_only=True)
|
||||
result = sl.extract_memories("Some content")
|
||||
mock_memory.extract_memories.assert_called_once_with("Some content")
|
||||
assert result == ["Sliced fact."]
|
||||
@@ -670,10 +670,10 @@ def test_agent_kickoff_memory_recall_and_save(tmp_path: Path, mock_embedder: Mag
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Patch on the class to avoid Pydantic BaseModel __delattr__ restriction
|
||||
with patch.object(Memory, "recall", wraps=mem.recall) as recall_mock, \
|
||||
patch.object(Memory, "extract_memories", return_value=["PostgreSQL is used."]) as extract_mock, \
|
||||
patch.object(Memory, "remember_many", wraps=mem.remember_many) as remember_many_mock:
|
||||
# Mock recall to verify it's called, but return real results
|
||||
with patch.object(mem, "recall", wraps=mem.recall) as recall_mock, \
|
||||
patch.object(mem, "extract_memories", return_value=["PostgreSQL is used."]) as extract_mock, \
|
||||
patch.object(mem, "remember_many", wraps=mem.remember_many) as remember_many_mock:
|
||||
result = agent.kickoff("What database do we use?")
|
||||
|
||||
assert result is not None
|
||||
|
||||
@@ -1893,3 +1893,163 @@ def test_or_condition_self_listen_fires_once():
|
||||
flow = OrSelfListenFlow()
|
||||
flow.kickoff()
|
||||
assert call_count == 1
|
||||
|
||||
class ListState(BaseModel):
|
||||
items: list = []
|
||||
|
||||
|
||||
class DictState(BaseModel):
|
||||
data: dict = {}
|
||||
|
||||
|
||||
class _ListFlow(Flow[ListState]):
|
||||
@start()
|
||||
def populate(self):
|
||||
self.state.items = [3, 1, 4, 1, 5, 9, 2, 6]
|
||||
|
||||
|
||||
class _DictFlow(Flow[DictState]):
|
||||
@start()
|
||||
def populate(self):
|
||||
self.state.data = {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
|
||||
def _make_list_flow():
|
||||
flow = _ListFlow()
|
||||
flow.kickoff()
|
||||
return flow
|
||||
|
||||
|
||||
def _make_dict_flow():
|
||||
flow = _DictFlow()
|
||||
flow.kickoff()
|
||||
return flow
|
||||
|
||||
|
||||
def test_locked_list_proxy_index():
|
||||
flow = _make_list_flow()
|
||||
assert flow.state.items.index(4) == 2
|
||||
assert flow.state.items.index(1, 2) == 3
|
||||
|
||||
|
||||
def test_locked_list_proxy_index_missing_raises():
|
||||
flow = _make_list_flow()
|
||||
with pytest.raises(ValueError):
|
||||
flow.state.items.index(999)
|
||||
|
||||
|
||||
def test_locked_list_proxy_count():
|
||||
flow = _make_list_flow()
|
||||
assert flow.state.items.count(1) == 2
|
||||
assert flow.state.items.count(999) == 0
|
||||
|
||||
|
||||
def test_locked_list_proxy_sort():
|
||||
flow = _make_list_flow()
|
||||
flow.state.items.sort()
|
||||
assert list(flow.state.items) == [1, 1, 2, 3, 4, 5, 6, 9]
|
||||
|
||||
|
||||
def test_locked_list_proxy_sort_reverse():
|
||||
flow = _make_list_flow()
|
||||
flow.state.items.sort(reverse=True)
|
||||
assert list(flow.state.items) == [9, 6, 5, 4, 3, 2, 1, 1]
|
||||
|
||||
|
||||
def test_locked_list_proxy_sort_key():
|
||||
flow = _make_list_flow()
|
||||
flow.state.items.sort(key=lambda x: -x)
|
||||
assert list(flow.state.items) == [9, 6, 5, 4, 3, 2, 1, 1]
|
||||
|
||||
|
||||
def test_locked_list_proxy_reverse():
|
||||
flow = _make_list_flow()
|
||||
original = list(flow.state.items)
|
||||
flow.state.items.reverse()
|
||||
assert list(flow.state.items) == list(reversed(original))
|
||||
|
||||
|
||||
def test_locked_list_proxy_copy():
|
||||
flow = _make_list_flow()
|
||||
copied = flow.state.items.copy()
|
||||
assert copied == [3, 1, 4, 1, 5, 9, 2, 6]
|
||||
assert isinstance(copied, list)
|
||||
copied.append(999)
|
||||
assert 999 not in flow.state.items
|
||||
|
||||
|
||||
def test_locked_list_proxy_add():
|
||||
flow = _make_list_flow()
|
||||
result = flow.state.items + [10, 11]
|
||||
assert result == [3, 1, 4, 1, 5, 9, 2, 6, 10, 11]
|
||||
assert len(flow.state.items) == 8
|
||||
|
||||
|
||||
def test_locked_list_proxy_radd():
|
||||
flow = _make_list_flow()
|
||||
result = [0] + flow.state.items
|
||||
assert result[0] == 0
|
||||
assert len(result) == 9
|
||||
|
||||
|
||||
def test_locked_list_proxy_iadd():
|
||||
flow = _make_list_flow()
|
||||
flow.state.items += [10]
|
||||
assert 10 in flow.state.items
|
||||
# Verify no deadlock: mutations must still work after +=
|
||||
flow.state.items.append(99)
|
||||
assert 99 in flow.state.items
|
||||
|
||||
|
||||
def test_locked_list_proxy_mul():
|
||||
flow = _make_list_flow()
|
||||
result = flow.state.items * 2
|
||||
assert len(result) == 16
|
||||
|
||||
|
||||
def test_locked_list_proxy_rmul():
|
||||
flow = _make_list_flow()
|
||||
result = 2 * flow.state.items
|
||||
assert len(result) == 16
|
||||
|
||||
|
||||
def test_locked_list_proxy_reversed():
|
||||
flow = _make_list_flow()
|
||||
original = list(flow.state.items)
|
||||
assert list(reversed(flow.state.items)) == list(reversed(original))
|
||||
|
||||
|
||||
def test_locked_dict_proxy_copy():
|
||||
flow = _make_dict_flow()
|
||||
copied = flow.state.data.copy()
|
||||
assert copied == {"a": 1, "b": 2, "c": 3}
|
||||
assert isinstance(copied, dict)
|
||||
copied["z"] = 99
|
||||
assert "z" not in flow.state.data
|
||||
|
||||
|
||||
def test_locked_dict_proxy_or():
|
||||
flow = _make_dict_flow()
|
||||
result = flow.state.data | {"d": 4}
|
||||
assert result == {"a": 1, "b": 2, "c": 3, "d": 4}
|
||||
assert "d" not in flow.state.data
|
||||
|
||||
|
||||
def test_locked_dict_proxy_ror():
|
||||
flow = _make_dict_flow()
|
||||
result = {"z": 0} | flow.state.data
|
||||
assert result == {"z": 0, "a": 1, "b": 2, "c": 3}
|
||||
|
||||
|
||||
def test_locked_dict_proxy_ior():
|
||||
flow = _make_dict_flow()
|
||||
flow.state.data |= {"d": 4}
|
||||
assert flow.state.data["d"] == 4
|
||||
# Verify no deadlock: mutations must still work after |=
|
||||
flow.state.data["e"] = 5
|
||||
assert flow.state.data["e"] == 5
|
||||
|
||||
|
||||
def test_locked_dict_proxy_reversed():
|
||||
flow = _make_dict_flow()
|
||||
assert list(reversed(flow.state.data)) == ["c", "b", "a"]
|
||||
|
||||
Reference in New Issue
Block a user