feat: add async ops to memory feat; create tests

This commit is contained in:
Greyson Lalonde
2025-12-02 13:09:52 -05:00
parent 132b6b224a
commit 441591d592
9 changed files with 1456 additions and 38 deletions

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
from crewai.memory import (
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
class ContextualMemory:
"""Aggregates and retrieves context from multiple memory sources."""
def __init__(
self,
stm: ShortTermMemory,
@@ -46,9 +49,14 @@ class ContextualMemory:
self.exm.task = self.task
def build_context_for_task(self, task: Task, context: str) -> str:
"""
Automatically builds a minimal, highly relevant set of contextual information
for a given task.
"""Build contextual information for a task synchronously.
Args:
task: The task to build context for.
context: Additional context string.
Returns:
Formatted context string from all memory sources.
"""
query = f"{task.description} {context}".strip()
@@ -63,6 +71,31 @@ class ContextualMemory:
]
return "\n".join(filter(None, context_parts))
async def abuild_context_for_task(self, task: Task, context: str) -> str:
"""Build contextual information for a task asynchronously.
Args:
task: The task to build context for.
context: Additional context string.
Returns:
Formatted context string from all memory sources.
"""
query = f"{task.description} {context}".strip()
if query == "":
return ""
# Fetch all contexts concurrently
results = await asyncio.gather(
self._afetch_ltm_context(task.description),
self._afetch_stm_context(query),
self._afetch_entity_context(query),
self._afetch_external_context(query),
)
return "\n".join(filter(None, results))
def _fetch_stm_context(self, query: str) -> str:
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
@@ -135,3 +168,87 @@ class ContextualMemory:
f"- {result['content']}" for result in external_memories
)
return f"External memories:\n{formatted_memories}"
async def _afetch_stm_context(self, query: str) -> str:
"""Fetch recent relevant insights from STM asynchronously.
Args:
query: The search query.
Returns:
Formatted insights as bullet points, or empty string if none found.
"""
if self.stm is None:
return ""
stm_results = await self.stm.asearch(query)
formatted_results = "\n".join(
[f"- {result['content']}" for result in stm_results]
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
async def _afetch_ltm_context(self, task: str) -> str | None:
"""Fetch historical data from LTM asynchronously.
Args:
task: The task description to search for.
Returns:
Formatted historical data as bullet points, or None if none found.
"""
if self.ltm is None:
return ""
ltm_results = await self.ltm.asearch(task, latest_n=2)
if not ltm_results:
return None
formatted_results = [
suggestion
for result in ltm_results
for suggestion in result["metadata"]["suggestions"]
]
formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
async def _afetch_entity_context(self, query: str) -> str:
"""Fetch relevant entity information asynchronously.
Args:
query: The search query.
Returns:
Formatted entity information as bullet points, or empty string if none found.
"""
if self.em is None:
return ""
em_results = await self.em.asearch(query)
formatted_results = "\n".join(
[f"- {result['content']}" for result in em_results]
)
return f"Entities:\n{formatted_results}" if em_results else ""
async def _afetch_external_context(self, query: str) -> str:
"""Fetch relevant information from External Memory asynchronously.
Args:
query: The search query.
Returns:
Formatted information as bullet points, or empty string if none found.
"""
if self.exm is None:
return ""
external_memories = await self.exm.asearch(query)
if not external_memories:
return ""
formatted_memories = "\n".join(
f"- {result['content']}" for result in external_memories
)
return f"External memories:\n{formatted_memories}"

View File

@@ -26,7 +26,13 @@ class EntityMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(
self,
crew: Any = None,
embedder_config: Any = None,
storage: Any = None,
path: str | None = None,
) -> None:
memory_provider = None
if embedder_config and isinstance(embedder_config, dict):
memory_provider = embedder_config.get("provider")
@@ -43,7 +49,7 @@ class EntityMemory(Memory):
if embedder_config and isinstance(embedder_config, dict)
else None
)
storage = Mem0Storage(type="short_term", crew=crew, config=config)
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
else:
storage = (
storage
@@ -170,7 +176,17 @@ class EntityMemory(Memory):
query: str,
limit: int = 5,
score_threshold: float = 0.6,
):
) -> list[Any]:
"""Search entity memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -217,6 +233,168 @@ class EntityMemory(Memory):
)
raise
async def asave(
self,
value: EntityMemoryItem | list[EntityMemoryItem],
metadata: dict[str, Any] | None = None,
) -> None:
"""Save entity items asynchronously.
Args:
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
metadata: Optional metadata dict (not used, for signature compatibility).
"""
if not value:
return
items = value if isinstance(value, list) else [value]
is_batch = len(items) > 1
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
metadata=metadata,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
saved_count = 0
errors: list[str | None] = []
async def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
"""Save a single item asynchronously."""
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
await super(EntityMemory, self).asave(data, item.metadata)
return True, None
except Exception as e:
return False, f"{item.name}: {e!s}"
try:
for item in items:
success, error = await save_single_item(item)
if success:
saved_count += 1
else:
errors.append(error)
if is_batch:
emit_value = f"Saved {saved_count} entities"
metadata = {"entity_count": saved_count, "errors": errors}
else:
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
metadata = items[0].metadata
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=emit_value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
if errors:
raise Exception(
f"Partial save: {len(errors)} failed out of {len(items)}"
)
except Exception as e:
fail_metadata = (
{"entity_count": len(items), "saved": saved_count}
if is_batch
else items[0].metadata
)
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
metadata=fail_metadata,
error=str(e),
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search entity memory asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await super().asearch(
query=query, limit=limit, score_threshold=score_threshold
)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
results=results,
limit=limit,
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return results
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
error=str(e),
source_type="entity_memory",
),
)
raise
def reset(self) -> None:
try:
self.storage.reset()

View File

@@ -30,7 +30,7 @@ class ExternalMemory(Memory):
def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
from crewai.memory.storage.mem0_storage import Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config)
return Mem0Storage(type="external", crew=crew, config=config) # type: ignore[no-untyped-call]
@staticmethod
def external_supported_storages() -> dict[str, Any]:
@@ -53,7 +53,10 @@ class ExternalMemory(Memory):
if provider not in supported_storages:
raise ValueError(f"Provider {provider} not supported")
return supported_storages[provider](crew, embedder_config.get("config", {}))
storage: Storage = supported_storages[provider](
crew, embedder_config.get("config", {})
)
return storage
def save(
self,
@@ -111,7 +114,17 @@ class ExternalMemory(Memory):
query: str,
limit: int = 5,
score_threshold: float = 0.6,
):
) -> list[Any]:
"""Search external memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -158,6 +171,124 @@ class ExternalMemory(Memory):
)
raise
async def asave(
self,
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
"""Save a value to external memory asynchronously.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=value,
metadata=metadata,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
item = ExternalMemoryItem(
value=value,
metadata=metadata,
agent=self.agent.role if self.agent else None,
)
await super().asave(value=item.value, metadata=item.metadata)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=value,
metadata=metadata,
error=str(e),
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search external memory asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await super().asearch(
query=query, limit=limit, score_threshold=score_threshold
)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
results=results,
limit=limit,
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return results
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
error=str(e),
source_type="external_memory",
),
)
raise
def reset(self) -> None:
self.storage.reset()

View File

@@ -24,7 +24,11 @@ class LongTermMemory(Memory):
LongTermMemoryItem instances.
"""
def __init__(self, storage=None, path=None):
def __init__(
self,
storage: LTMSQLiteStorage | None = None,
path: str | None = None,
) -> None:
if not storage:
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage=storage)
@@ -48,7 +52,7 @@ class LongTermMemory(Memory):
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
self.storage.save(
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
@@ -80,11 +84,20 @@ class LongTermMemory(Memory):
)
raise
def search( # type: ignore # signature of "search" incompatible with supertype "Memory"
def search( # type: ignore[override]
self,
task: str,
latest_n: int = 3,
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
) -> list[dict[str, Any]]:
"""Search long-term memory for relevant entries.
Args:
task: The task description to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -98,7 +111,7 @@ class LongTermMemory(Memory):
start_time = time.time()
try:
results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
results = self.storage.load(task, latest_n)
crewai_event_bus.emit(
self,
@@ -113,7 +126,118 @@ class LongTermMemory(Memory):
),
)
return results
return results or []
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=task,
limit=latest_n,
error=str(e),
source_type="long_term_memory",
),
)
raise
async def asave(self, item: LongTermMemoryItem) -> None: # type: ignore[override]
"""Save an item to long-term memory asynchronously.
Args:
item: The LongTermMemoryItem to save.
"""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
metadata = item.metadata
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
await self.storage.asave(
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
datetime=item.datetime,
)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
save_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
error=str(e),
source_type="long_term_memory",
),
)
raise
async def asearch( # type: ignore[override]
self,
task: str,
latest_n: int = 3,
) -> list[dict[str, Any]]:
"""Search long-term memory asynchronously.
Args:
task: The task description to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=task,
limit=latest_n,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await self.storage.aload(task, latest_n)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=task,
results=results,
limit=latest_n,
query_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return results or []
except Exception as e:
crewai_event_bus.emit(
self,
@@ -127,4 +251,5 @@ class LongTermMemory(Memory):
raise
def reset(self) -> None:
"""Reset long-term memory."""
self.storage.reset()

View File

@@ -13,9 +13,7 @@ if TYPE_CHECKING:
class Memory(BaseModel):
"""
Base class for memory, now supporting agent tags and generic metadata.
"""
"""Base class for memory, supporting agent tags and generic metadata."""
embedder_config: EmbedderConfig | dict[str, Any] | None = None
crew: Any | None = None
@@ -52,20 +50,72 @@ class Memory(BaseModel):
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
metadata = metadata or {}
"""Save a value to memory.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
metadata = metadata or {}
self.storage.save(value, metadata)
async def asave(
self,
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
"""Save a value to memory asynchronously.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
metadata = metadata or {}
await self.storage.asave(value, metadata)
def search(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
return self.storage.search(
"""Search memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
results: list[Any] = self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
)
return results
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search memory for relevant entries asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
results: list[Any] = await self.storage.asearch(
query=query, limit=limit, score_threshold=score_threshold
)
return results
def set_crew(self, crew: Any) -> Memory:
"""Set the crew for this memory instance."""
self.crew = crew
return self

View File

@@ -30,7 +30,13 @@ class ShortTermMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(
self,
crew: Any = None,
embedder_config: Any = None,
storage: Any = None,
path: str | None = None,
) -> None:
memory_provider = None
if embedder_config and isinstance(embedder_config, dict):
memory_provider = embedder_config.get("provider")
@@ -47,7 +53,7 @@ class ShortTermMemory(Memory):
if embedder_config and isinstance(embedder_config, dict)
else None
)
storage = Mem0Storage(type="short_term", crew=crew, config=config)
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
else:
storage = (
storage
@@ -123,7 +129,17 @@ class ShortTermMemory(Memory):
query: str,
limit: int = 5,
score_threshold: float = 0.6,
):
) -> list[Any]:
"""Search short-term memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -140,7 +156,7 @@ class ShortTermMemory(Memory):
try:
results = self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
)
crewai_event_bus.emit(
self,
@@ -156,7 +172,130 @@ class ShortTermMemory(Memory):
),
)
return results
return list(results)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
error=str(e),
source_type="short_term_memory",
),
)
raise
async def asave(
self,
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
"""Save a value to short-term memory asynchronously.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=value,
metadata=metadata,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
item = ShortTermMemoryItem(
data=value,
metadata=metadata,
agent=self.agent.role if self.agent else None,
)
if self._memory_provider == "mem0":
item.data = (
f"Remember the following insights from Agent run: {item.data}"
)
await super().asave(value=item.data, metadata=item.metadata)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=value,
metadata=metadata,
error=str(e),
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search short-term memory asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await self.storage.asearch(
query=query, limit=limit, score_threshold=score_threshold
)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
results=results,
limit=limit,
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return list(results)
except Exception as e:
crewai_event_bus.emit(
self,

View File

@@ -3,29 +3,30 @@ from pathlib import Path
import sqlite3
from typing import Any
import aiosqlite
from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path
class LTMSQLiteStorage:
"""
An updated SQLite storage class for LTM data storage.
"""
"""SQLite storage class for long-term memory data."""
def __init__(self, db_path: str | None = None) -> None:
"""Initialize the SQLite storage.
Args:
db_path: Optional path to the database file.
"""
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
self.db_path = db_path
self._printer: Printer = Printer()
# Ensure parent directory exists
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
self._initialize_db()
def _initialize_db(self):
"""
Initializes the SQLite database and creates LTM table
"""
def _initialize_db(self) -> None:
"""Initialize the SQLite database and create LTM table."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
@@ -106,9 +107,7 @@ class LTMSQLiteStorage:
)
return None
def reset(
self,
) -> None:
def reset(self) -> None:
"""Resets the LTM table with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -121,4 +120,87 @@ class LTMSQLiteStorage:
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
return
async def asave(
self,
task_description: str,
metadata: dict[str, Any],
datetime: str,
score: int | float,
) -> None:
"""Save data to the LTM table asynchronously.
Args:
task_description: Description of the task.
metadata: Metadata associated with the memory.
datetime: Timestamp of the memory.
score: Quality score of the memory.
"""
try:
async with aiosqlite.connect(self.db_path) as conn:
await conn.execute(
"""
INSERT INTO long_term_memories (task_description, metadata, datetime, score)
VALUES (?, ?, ?, ?)
""",
(task_description, json.dumps(metadata), datetime, score),
)
await conn.commit()
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
color="red",
)
async def aload(
self, task_description: str, latest_n: int
) -> list[dict[str, Any]] | None:
"""Query the LTM table by task description asynchronously.
Args:
task_description: Description of the task to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries or None if error occurs.
"""
try:
async with aiosqlite.connect(self.db_path) as conn:
cursor = await conn.execute(
f"""
SELECT metadata, datetime, score
FROM long_term_memories
WHERE task_description = ?
ORDER BY datetime DESC, score ASC
LIMIT {latest_n}
""", # nosec # noqa: S608
(task_description,),
)
rows = await cursor.fetchall()
if rows:
return [
{
"metadata": json.loads(row[0]),
"datetime": row[1],
"score": row[2],
}
for row in rows
]
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
color="red",
)
return None
async def areset(self) -> None:
"""Reset the LTM table asynchronously."""
try:
async with aiosqlite.connect(self.db_path) as conn:
await conn.execute("DELETE FROM long_term_memories")
await conn.commit()
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)

View File

@@ -129,6 +129,12 @@ class RAGStorage(BaseRAGStorage):
return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value to storage.
Args:
value: The value to save.
metadata: Metadata to associate with the value.
"""
try:
client = self._get_client()
collection_name = (
@@ -167,6 +173,51 @@ class RAGStorage(BaseRAGStorage):
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
)
async def asave(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value to storage asynchronously.
Args:
value: The value to save.
metadata: Metadata to associate with the value.
"""
try:
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
await client.aget_or_create_collection(collection_name=collection_name)
document: BaseRecord = {"content": value}
if metadata:
document["metadata"] = metadata
batch_size = None
if (
self.embedder_config
and isinstance(self.embedder_config, dict)
and "config" in self.embedder_config
):
nested_config = self.embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
await client.aadd_documents(
collection_name=collection_name,
documents=[document],
batch_size=cast(int, batch_size),
)
else:
await client.aadd_documents(
collection_name=collection_name, documents=[document]
)
except Exception as e:
logging.error(
f"Error during {self.type} async save: {e!s}\n{traceback.format_exc()}"
)
def search(
self,
query: str,
@@ -174,6 +225,17 @@ class RAGStorage(BaseRAGStorage):
filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search for matching entries in storage.
Args:
query: The search query.
limit: Maximum number of results to return.
filter: Optional metadata filter.
score_threshold: Minimum similarity score for results.
Returns:
List of matching entries.
"""
try:
client = self._get_client()
collection_name = (
@@ -194,6 +256,44 @@ class RAGStorage(BaseRAGStorage):
)
return []
async def asearch(
self,
query: str,
limit: int = 5,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search for matching entries in storage asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
filter: Optional metadata filter.
score_threshold: Minimum similarity score for results.
Returns:
List of matching entries.
"""
try:
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
return await client.asearch(
collection_name=collection_name,
query=query,
limit=limit,
metadata_filter=filter,
score_threshold=score_threshold,
)
except Exception as e:
logging.error(
f"Error during {self.type} async search: {e!s}\n{traceback.format_exc()}"
)
return []
def reset(self) -> None:
try:
client = self._get_client()

View File

@@ -0,0 +1,496 @@
"""Tests for async memory operations."""
import threading
from collections import defaultdict
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.memory_events import (
MemoryQueryCompletedEvent,
MemoryQueryStartedEvent,
MemorySaveCompletedEvent,
MemorySaveStartedEvent,
)
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.external.external_memory import ExternalMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.task import Task
@pytest.fixture
def mock_agent():
"""Fixture to create a mock agent."""
return Agent(
role="Researcher",
goal="Search relevant data and provide results",
backstory="You are a researcher at a leading tech think tank.",
tools=[],
verbose=True,
)
@pytest.fixture
def mock_task(mock_agent):
"""Fixture to create a mock task."""
return Task(
description="Perform a search on specific topics.",
expected_output="A list of relevant URLs based on the search query.",
agent=mock_agent,
)
@pytest.fixture
def short_term_memory(mock_agent, mock_task):
"""Fixture to create a ShortTermMemory instance."""
return ShortTermMemory(crew=Crew(agents=[mock_agent], tasks=[mock_task]))
@pytest.fixture
def long_term_memory(tmp_path):
"""Fixture to create a LongTermMemory instance."""
db_path = str(tmp_path / "test_ltm.db")
return LongTermMemory(path=db_path)
@pytest.fixture
def entity_memory(tmp_path, mock_agent, mock_task):
"""Fixture to create an EntityMemory instance."""
return EntityMemory(
crew=Crew(agents=[mock_agent], tasks=[mock_task]),
path=str(tmp_path / "test_entities"),
)
class TestAsyncShortTermMemory:
"""Tests for async ShortTermMemory operations."""
@pytest.mark.asyncio
async def test_asave_emits_events(self, short_term_memory):
"""Test that asave emits the correct events."""
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
await short_term_memory.asave(
value="async test value",
metadata={"task": "async_test_task"},
)
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveStartedEvent"]) >= 1
and len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=5,
)
assert success, "Timeout waiting for async save events"
assert len(events["MemorySaveStartedEvent"]) >= 1
assert len(events["MemorySaveCompletedEvent"]) >= 1
assert events["MemorySaveStartedEvent"][-1].value == "async test value"
assert events["MemorySaveStartedEvent"][-1].source_type == "short_term_memory"
@pytest.mark.asyncio
async def test_asearch_emits_events(self, short_term_memory):
"""Test that asearch emits the correct events."""
events: dict[str, list] = defaultdict(list)
search_started = threading.Event()
search_completed = threading.Event()
with patch.object(short_term_memory.storage, "asearch", new_callable=AsyncMock, return_value=[]):
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
search_started.set()
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
search_completed.set()
await short_term_memory.asearch(
query="async test query",
limit=3,
score_threshold=0.35,
)
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
assert len(events["MemoryQueryStartedEvent"]) >= 1
assert len(events["MemoryQueryCompletedEvent"]) >= 1
assert events["MemoryQueryStartedEvent"][-1].query == "async test query"
assert events["MemoryQueryStartedEvent"][-1].source_type == "short_term_memory"
class TestAsyncLongTermMemory:
"""Tests for async LongTermMemory operations."""
@pytest.mark.asyncio
async def test_asave_emits_events(self, long_term_memory):
"""Test that asave emits the correct events."""
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
item = LongTermMemoryItem(
task="async test task",
agent="test_agent",
expected_output="test output",
datetime="2024-01-01T00:00:00",
quality=0.9,
metadata={"task": "async test task", "quality": 0.9},
)
await long_term_memory.asave(item)
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveStartedEvent"]) >= 1
and len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=5,
)
assert success, "Timeout waiting for async save events"
assert len(events["MemorySaveStartedEvent"]) >= 1
assert len(events["MemorySaveCompletedEvent"]) >= 1
assert events["MemorySaveStartedEvent"][-1].source_type == "long_term_memory"
@pytest.mark.asyncio
async def test_asearch_emits_events(self, long_term_memory):
"""Test that asearch emits the correct events."""
events: dict[str, list] = defaultdict(list)
search_started = threading.Event()
search_completed = threading.Event()
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
search_started.set()
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
search_completed.set()
await long_term_memory.asearch(task="async test task", latest_n=3)
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
assert len(events["MemoryQueryStartedEvent"]) >= 1
assert len(events["MemoryQueryCompletedEvent"]) >= 1
assert events["MemoryQueryStartedEvent"][-1].source_type == "long_term_memory"
@pytest.mark.asyncio
async def test_asave_and_asearch_integration(self, long_term_memory):
"""Test that asave followed by asearch works correctly."""
item = LongTermMemoryItem(
task="integration test task",
agent="test_agent",
expected_output="test output",
datetime="2024-01-01T00:00:00",
quality=0.9,
metadata={"task": "integration test task", "quality": 0.9},
)
await long_term_memory.asave(item)
results = await long_term_memory.asearch(task="integration test task", latest_n=1)
assert results is not None
assert len(results) == 1
assert results[0]["metadata"]["agent"] == "test_agent"
class TestAsyncEntityMemory:
"""Tests for async EntityMemory operations."""
@pytest.mark.asyncio
async def test_asave_single_item_emits_events(self, entity_memory):
"""Test that asave with a single item emits the correct events."""
events: dict[str, list] = defaultdict(list)
condition = threading.Condition()
@crewai_event_bus.on(MemorySaveStartedEvent)
def on_save_started(source, event):
with condition:
events["MemorySaveStartedEvent"].append(event)
condition.notify()
@crewai_event_bus.on(MemorySaveCompletedEvent)
def on_save_completed(source, event):
with condition:
events["MemorySaveCompletedEvent"].append(event)
condition.notify()
item = EntityMemoryItem(
name="TestEntity",
type="Person",
description="A test entity for async operations",
relationships="Related to other test entities",
)
await entity_memory.asave(item)
with condition:
success = condition.wait_for(
lambda: len(events["MemorySaveStartedEvent"]) >= 1
and len(events["MemorySaveCompletedEvent"]) >= 1,
timeout=5,
)
assert success, "Timeout waiting for async save events"
assert len(events["MemorySaveStartedEvent"]) >= 1
assert len(events["MemorySaveCompletedEvent"]) >= 1
assert events["MemorySaveStartedEvent"][-1].source_type == "entity_memory"
@pytest.mark.asyncio
async def test_asearch_emits_events(self, entity_memory):
"""Test that asearch emits the correct events."""
events: dict[str, list] = defaultdict(list)
search_started = threading.Event()
search_completed = threading.Event()
@crewai_event_bus.on(MemoryQueryStartedEvent)
def on_search_started(source, event):
events["MemoryQueryStartedEvent"].append(event)
search_started.set()
@crewai_event_bus.on(MemoryQueryCompletedEvent)
def on_search_completed(source, event):
events["MemoryQueryCompletedEvent"].append(event)
search_completed.set()
await entity_memory.asearch(query="TestEntity", limit=5, score_threshold=0.6)
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
assert len(events["MemoryQueryStartedEvent"]) >= 1
assert len(events["MemoryQueryCompletedEvent"]) >= 1
assert events["MemoryQueryStartedEvent"][-1].source_type == "entity_memory"
class TestAsyncContextualMemory:
"""Tests for async ContextualMemory operations."""
@pytest.mark.asyncio
async def test_abuild_context_for_task_with_empty_query(self, mock_task):
"""Test that abuild_context_for_task returns empty string for empty query."""
mock_task.description = ""
contextual_memory = ContextualMemory(
stm=None,
ltm=None,
em=None,
exm=None,
)
result = await contextual_memory.abuild_context_for_task(mock_task, "")
assert result == ""
@pytest.mark.asyncio
async def test_abuild_context_for_task_with_none_memories(self, mock_task):
"""Test that abuild_context_for_task handles None memory sources."""
contextual_memory = ContextualMemory(
stm=None,
ltm=None,
em=None,
exm=None,
)
result = await contextual_memory.abuild_context_for_task(mock_task, "some context")
assert result == ""
@pytest.mark.asyncio
async def test_abuild_context_for_task_aggregates_results(self, mock_agent, mock_task):
"""Test that abuild_context_for_task aggregates results from all memory sources."""
mock_stm = MagicMock(spec=ShortTermMemory)
mock_stm.asearch = AsyncMock(return_value=[{"content": "STM insight"}])
mock_ltm = MagicMock(spec=LongTermMemory)
mock_ltm.asearch = AsyncMock(
return_value=[{"metadata": {"suggestions": ["LTM suggestion"]}}]
)
mock_em = MagicMock(spec=EntityMemory)
mock_em.asearch = AsyncMock(return_value=[{"content": "Entity info"}])
mock_exm = MagicMock(spec=ExternalMemory)
mock_exm.asearch = AsyncMock(return_value=[{"content": "External memory"}])
contextual_memory = ContextualMemory(
stm=mock_stm,
ltm=mock_ltm,
em=mock_em,
exm=mock_exm,
agent=mock_agent,
task=mock_task,
)
result = await contextual_memory.abuild_context_for_task(mock_task, "additional context")
assert "Recent Insights:" in result
assert "STM insight" in result
assert "Historical Data:" in result
assert "LTM suggestion" in result
assert "Entities:" in result
assert "Entity info" in result
assert "External memories:" in result
assert "External memory" in result
@pytest.mark.asyncio
async def test_afetch_stm_context_returns_formatted_results(self, mock_agent, mock_task):
"""Test that _afetch_stm_context returns properly formatted results."""
mock_stm = MagicMock(spec=ShortTermMemory)
mock_stm.asearch = AsyncMock(
return_value=[
{"content": "First insight"},
{"content": "Second insight"},
]
)
contextual_memory = ContextualMemory(
stm=mock_stm,
ltm=None,
em=None,
exm=None,
)
result = await contextual_memory._afetch_stm_context("test query")
assert "Recent Insights:" in result
assert "- First insight" in result
assert "- Second insight" in result
@pytest.mark.asyncio
async def test_afetch_ltm_context_returns_formatted_results(self, mock_agent, mock_task):
"""Test that _afetch_ltm_context returns properly formatted results."""
mock_ltm = MagicMock(spec=LongTermMemory)
mock_ltm.asearch = AsyncMock(
return_value=[
{"metadata": {"suggestions": ["Suggestion 1", "Suggestion 2"]}},
]
)
contextual_memory = ContextualMemory(
stm=None,
ltm=mock_ltm,
em=None,
exm=None,
)
result = await contextual_memory._afetch_ltm_context("test task")
assert "Historical Data:" in result
assert "- Suggestion 1" in result
assert "- Suggestion 2" in result
@pytest.mark.asyncio
async def test_afetch_entity_context_returns_formatted_results(self, mock_agent, mock_task):
"""Test that _afetch_entity_context returns properly formatted results."""
mock_em = MagicMock(spec=EntityMemory)
mock_em.asearch = AsyncMock(
return_value=[
{"content": "Entity A details"},
{"content": "Entity B details"},
]
)
contextual_memory = ContextualMemory(
stm=None,
ltm=None,
em=mock_em,
exm=None,
)
result = await contextual_memory._afetch_entity_context("test query")
assert "Entities:" in result
assert "- Entity A details" in result
assert "- Entity B details" in result
@pytest.mark.asyncio
async def test_afetch_external_context_returns_formatted_results(self):
"""Test that _afetch_external_context returns properly formatted results."""
mock_exm = MagicMock(spec=ExternalMemory)
mock_exm.asearch = AsyncMock(
return_value=[
{"content": "External data 1"},
{"content": "External data 2"},
]
)
contextual_memory = ContextualMemory(
stm=None,
ltm=None,
em=None,
exm=mock_exm,
)
result = await contextual_memory._afetch_external_context("test query")
assert "External memories:" in result
assert "- External data 1" in result
assert "- External data 2" in result
@pytest.mark.asyncio
async def test_afetch_methods_return_empty_for_empty_results(self):
"""Test that async fetch methods return empty string for no results."""
mock_stm = MagicMock(spec=ShortTermMemory)
mock_stm.asearch = AsyncMock(return_value=[])
mock_ltm = MagicMock(spec=LongTermMemory)
mock_ltm.asearch = AsyncMock(return_value=[])
mock_em = MagicMock(spec=EntityMemory)
mock_em.asearch = AsyncMock(return_value=[])
mock_exm = MagicMock(spec=ExternalMemory)
mock_exm.asearch = AsyncMock(return_value=[])
contextual_memory = ContextualMemory(
stm=mock_stm,
ltm=mock_ltm,
em=mock_em,
exm=mock_exm,
)
stm_result = await contextual_memory._afetch_stm_context("query")
ltm_result = await contextual_memory._afetch_ltm_context("task")
em_result = await contextual_memory._afetch_entity_context("query")
exm_result = await contextual_memory._afetch_external_context("query")
assert stm_result == ""
assert ltm_result is None
assert em_result == ""
assert exm_result == ""