mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
feat: async memory support
Adds async support for tools with tests, async execution in the agent executor, and async operations for memory (with aiosqlite). Improves tool decorator typing, ensures _run backward compatibility, updates docs and docstrings, adds tests, and regenerates lockfiles.
This commit is contained in:
@@ -38,6 +38,7 @@ dependencies = [
|
||||
"pydantic-settings~=2.10.1",
|
||||
"mcp~=1.16.0",
|
||||
"uv~=0.9.13",
|
||||
"aiosqlite~=0.21.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.memory import (
|
||||
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ContextualMemory:
|
||||
"""Aggregates and retrieves context from multiple memory sources."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stm: ShortTermMemory,
|
||||
@@ -46,9 +49,14 @@ class ContextualMemory:
|
||||
self.exm.task = self.task
|
||||
|
||||
def build_context_for_task(self, task: Task, context: str) -> str:
|
||||
"""
|
||||
Automatically builds a minimal, highly relevant set of contextual information
|
||||
for a given task.
|
||||
"""Build contextual information for a task synchronously.
|
||||
|
||||
Args:
|
||||
task: The task to build context for.
|
||||
context: Additional context string.
|
||||
|
||||
Returns:
|
||||
Formatted context string from all memory sources.
|
||||
"""
|
||||
query = f"{task.description} {context}".strip()
|
||||
|
||||
@@ -63,6 +71,31 @@ class ContextualMemory:
|
||||
]
|
||||
return "\n".join(filter(None, context_parts))
|
||||
|
||||
async def abuild_context_for_task(self, task: Task, context: str) -> str:
|
||||
"""Build contextual information for a task asynchronously.
|
||||
|
||||
Args:
|
||||
task: The task to build context for.
|
||||
context: Additional context string.
|
||||
|
||||
Returns:
|
||||
Formatted context string from all memory sources.
|
||||
"""
|
||||
query = f"{task.description} {context}".strip()
|
||||
|
||||
if query == "":
|
||||
return ""
|
||||
|
||||
# Fetch all contexts concurrently
|
||||
results = await asyncio.gather(
|
||||
self._afetch_ltm_context(task.description),
|
||||
self._afetch_stm_context(query),
|
||||
self._afetch_entity_context(query),
|
||||
self._afetch_external_context(query),
|
||||
)
|
||||
|
||||
return "\n".join(filter(None, results))
|
||||
|
||||
def _fetch_stm_context(self, query: str) -> str:
|
||||
"""
|
||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||
@@ -135,3 +168,87 @@ class ContextualMemory:
|
||||
f"- {result['content']}" for result in external_memories
|
||||
)
|
||||
return f"External memories:\n{formatted_memories}"
|
||||
|
||||
async def _afetch_stm_context(self, query: str) -> str:
|
||||
"""Fetch recent relevant insights from STM asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
|
||||
Returns:
|
||||
Formatted insights as bullet points, or empty string if none found.
|
||||
"""
|
||||
if self.stm is None:
|
||||
return ""
|
||||
|
||||
stm_results = await self.stm.asearch(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['content']}" for result in stm_results]
|
||||
)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
async def _afetch_ltm_context(self, task: str) -> str | None:
|
||||
"""Fetch historical data from LTM asynchronously.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
|
||||
Returns:
|
||||
Formatted historical data as bullet points, or None if none found.
|
||||
"""
|
||||
if self.ltm is None:
|
||||
return ""
|
||||
|
||||
ltm_results = await self.ltm.asearch(task, latest_n=2)
|
||||
if not ltm_results:
|
||||
return None
|
||||
|
||||
formatted_results = [
|
||||
suggestion
|
||||
for result in ltm_results
|
||||
for suggestion in result["metadata"]["suggestions"]
|
||||
]
|
||||
formatted_results = list(dict.fromkeys(formatted_results))
|
||||
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
|
||||
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
|
||||
async def _afetch_entity_context(self, query: str) -> str:
|
||||
"""Fetch relevant entity information asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
|
||||
Returns:
|
||||
Formatted entity information as bullet points, or empty string if none found.
|
||||
"""
|
||||
if self.em is None:
|
||||
return ""
|
||||
|
||||
em_results = await self.em.asearch(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['content']}" for result in em_results]
|
||||
)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
async def _afetch_external_context(self, query: str) -> str:
|
||||
"""Fetch relevant information from External Memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
|
||||
Returns:
|
||||
Formatted information as bullet points, or empty string if none found.
|
||||
"""
|
||||
if self.exm is None:
|
||||
return ""
|
||||
|
||||
external_memories = await self.exm.asearch(query)
|
||||
|
||||
if not external_memories:
|
||||
return ""
|
||||
|
||||
formatted_memories = "\n".join(
|
||||
f"- {result['content']}" for result in external_memories
|
||||
)
|
||||
return f"External memories:\n{formatted_memories}"
|
||||
|
||||
@@ -26,7 +26,13 @@ class EntityMemory(Memory):
|
||||
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(
|
||||
self,
|
||||
crew: Any = None,
|
||||
embedder_config: Any = None,
|
||||
storage: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
memory_provider = None
|
||||
if embedder_config and isinstance(embedder_config, dict):
|
||||
memory_provider = embedder_config.get("provider")
|
||||
@@ -43,7 +49,7 @@ class EntityMemory(Memory):
|
||||
if embedder_config and isinstance(embedder_config, dict)
|
||||
else None
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
@@ -170,7 +176,17 @@ class EntityMemory(Memory):
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
) -> list[Any]:
|
||||
"""Search entity memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -217,6 +233,168 @@ class EntityMemory(Memory):
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: EntityMemoryItem | list[EntityMemoryItem],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save entity items asynchronously.
|
||||
|
||||
Args:
|
||||
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
|
||||
metadata: Optional metadata dict (not used, for signature compatibility).
|
||||
"""
|
||||
if not value:
|
||||
return
|
||||
|
||||
items = value if isinstance(value, list) else [value]
|
||||
is_batch = len(items) > 1
|
||||
|
||||
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
metadata=metadata,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
saved_count = 0
|
||||
errors: list[str | None] = []
|
||||
|
||||
async def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
|
||||
"""Save a single item asynchronously."""
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
|
||||
await super(EntityMemory, self).asave(data, item.metadata)
|
||||
return True, None
|
||||
except Exception as e:
|
||||
return False, f"{item.name}: {e!s}"
|
||||
|
||||
try:
|
||||
for item in items:
|
||||
success, error = await save_single_item(item)
|
||||
if success:
|
||||
saved_count += 1
|
||||
else:
|
||||
errors.append(error)
|
||||
|
||||
if is_batch:
|
||||
emit_value = f"Saved {saved_count} entities"
|
||||
metadata = {"entity_count": saved_count, "errors": errors}
|
||||
else:
|
||||
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
|
||||
metadata = items[0].metadata
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=emit_value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
if errors:
|
||||
raise Exception(
|
||||
f"Partial save: {len(errors)} failed out of {len(items)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
fail_metadata = (
|
||||
{"entity_count": len(items), "saved": saved_count}
|
||||
if is_batch
|
||||
else items[0].metadata
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
metadata=fail_metadata,
|
||||
error=str(e),
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search entity memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await super().asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="entity_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
|
||||
@@ -30,7 +30,7 @@ class ExternalMemory(Memory):
|
||||
def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
return Mem0Storage(type="external", crew=crew, config=config)
|
||||
return Mem0Storage(type="external", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||
|
||||
@staticmethod
|
||||
def external_supported_storages() -> dict[str, Any]:
|
||||
@@ -53,7 +53,10 @@ class ExternalMemory(Memory):
|
||||
if provider not in supported_storages:
|
||||
raise ValueError(f"Provider {provider} not supported")
|
||||
|
||||
return supported_storages[provider](crew, embedder_config.get("config", {}))
|
||||
storage: Storage = supported_storages[provider](
|
||||
crew, embedder_config.get("config", {})
|
||||
)
|
||||
return storage
|
||||
|
||||
def save(
|
||||
self,
|
||||
@@ -111,7 +114,17 @@ class ExternalMemory(Memory):
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
) -> list[Any]:
|
||||
"""Search external memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -158,6 +171,124 @@ class ExternalMemory(Memory):
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save a value to external memory asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ExternalMemoryItem(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
await super().asave(value=item.value, metadata=item.metadata)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
error=str(e),
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search external memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await super().asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="external_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
self.storage.reset()
|
||||
|
||||
|
||||
@@ -24,7 +24,11 @@ class LongTermMemory(Memory):
|
||||
LongTermMemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, storage=None, path=None):
|
||||
def __init__(
|
||||
self,
|
||||
storage: LTMSQLiteStorage | None = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
if not storage:
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage=storage)
|
||||
@@ -48,7 +52,7 @@ class LongTermMemory(Memory):
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
|
||||
self.storage.save(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
@@ -80,11 +84,20 @@ class LongTermMemory(Memory):
|
||||
)
|
||||
raise
|
||||
|
||||
def search( # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
def search( # type: ignore[override]
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search long-term memory for relevant entries.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -98,7 +111,7 @@ class LongTermMemory(Memory):
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
||||
results = self.storage.load(task, latest_n)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -113,7 +126,118 @@ class LongTermMemory(Memory):
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
return results or []
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(self, item: LongTermMemoryItem) -> None: # type: ignore[override]
|
||||
"""Save an item to long-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
item: The LongTermMemoryItem to save.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
metadata = item.metadata
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
await self.storage.asave(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch( # type: ignore[override]
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search long-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await self.storage.aload(task, latest_n)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=task,
|
||||
results=results,
|
||||
limit=latest_n,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results or []
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -127,4 +251,5 @@ class LongTermMemory(Memory):
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset long-term memory."""
|
||||
self.storage.reset()
|
||||
|
||||
@@ -13,9 +13,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
"""
|
||||
Base class for memory, now supporting agent tags and generic metadata.
|
||||
"""
|
||||
"""Base class for memory, supporting agent tags and generic metadata."""
|
||||
|
||||
embedder_config: EmbedderConfig | dict[str, Any] | None = None
|
||||
crew: Any | None = None
|
||||
@@ -52,20 +50,72 @@ class Memory(BaseModel):
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
metadata = metadata or {}
|
||||
"""Save a value to memory.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
metadata = metadata or {}
|
||||
self.storage.save(value, metadata)
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save a value to memory asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
metadata = metadata or {}
|
||||
await self.storage.asave(value, metadata)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
return self.storage.search(
|
||||
"""Search memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
results: list[Any] = self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
return results
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search memory for relevant entries asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
results: list[Any] = await self.storage.asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
return results
|
||||
|
||||
def set_crew(self, crew: Any) -> Memory:
|
||||
"""Set the crew for this memory instance."""
|
||||
self.crew = crew
|
||||
return self
|
||||
|
||||
@@ -30,7 +30,13 @@ class ShortTermMemory(Memory):
|
||||
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(
|
||||
self,
|
||||
crew: Any = None,
|
||||
embedder_config: Any = None,
|
||||
storage: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
memory_provider = None
|
||||
if embedder_config and isinstance(embedder_config, dict):
|
||||
memory_provider = embedder_config.get("provider")
|
||||
@@ -47,7 +53,7 @@ class ShortTermMemory(Memory):
|
||||
if embedder_config and isinstance(embedder_config, dict)
|
||||
else None
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
@@ -123,7 +129,17 @@ class ShortTermMemory(Memory):
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
) -> list[Any]:
|
||||
"""Search short-term memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -140,7 +156,7 @@ class ShortTermMemory(Memory):
|
||||
try:
|
||||
results = self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -156,7 +172,130 @@ class ShortTermMemory(Memory):
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
return list(results)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="short_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save a value to short-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ShortTermMemoryItem(
|
||||
data=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
if self._memory_provider == "mem0":
|
||||
item.data = (
|
||||
f"Remember the following insights from Agent run: {item.data}"
|
||||
)
|
||||
|
||||
await super().asave(value=item.data, metadata=item.metadata)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
error=str(e),
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search short-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await self.storage.asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return list(results)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -3,29 +3,30 @@ from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from crewai.utilities import Printer
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
class LTMSQLiteStorage:
|
||||
"""
|
||||
An updated SQLite storage class for LTM data storage.
|
||||
"""
|
||||
"""SQLite storage class for long-term memory data."""
|
||||
|
||||
def __init__(self, db_path: str | None = None) -> None:
|
||||
"""Initialize the SQLite storage.
|
||||
|
||||
Args:
|
||||
db_path: Optional path to the database file.
|
||||
"""
|
||||
if db_path is None:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
||||
self.db_path = db_path
|
||||
self._printer: Printer = Printer()
|
||||
# Ensure parent directory exists
|
||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
self._initialize_db()
|
||||
|
||||
def _initialize_db(self):
|
||||
"""
|
||||
Initializes the SQLite database and creates LTM table
|
||||
"""
|
||||
def _initialize_db(self) -> None:
|
||||
"""Initialize the SQLite database and create LTM table."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
@@ -106,9 +107,7 @@ class LTMSQLiteStorage:
|
||||
)
|
||||
return None
|
||||
|
||||
def reset(
|
||||
self,
|
||||
) -> None:
|
||||
def reset(self) -> None:
|
||||
"""Resets the LTM table with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -121,4 +120,87 @@ class LTMSQLiteStorage:
|
||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
task_description: str,
|
||||
metadata: dict[str, Any],
|
||||
datetime: str,
|
||||
score: int | float,
|
||||
) -> None:
|
||||
"""Save data to the LTM table asynchronously.
|
||||
|
||||
Args:
|
||||
task_description: Description of the task.
|
||||
metadata: Metadata associated with the memory.
|
||||
datetime: Timestamp of the memory.
|
||||
score: Quality score of the memory.
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO long_term_memories (task_description, metadata, datetime, score)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(task_description, json.dumps(metadata), datetime, score),
|
||||
)
|
||||
await conn.commit()
|
||||
except aiosqlite.Error as e:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
async def aload(
|
||||
self, task_description: str, latest_n: int
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Query the LTM table by task description asynchronously.
|
||||
|
||||
Args:
|
||||
task_description: Description of the task to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries or None if error occurs.
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as conn:
|
||||
cursor = await conn.execute(
|
||||
f"""
|
||||
SELECT metadata, datetime, score
|
||||
FROM long_term_memories
|
||||
WHERE task_description = ?
|
||||
ORDER BY datetime DESC, score ASC
|
||||
LIMIT {latest_n}
|
||||
""", # nosec # noqa: S608
|
||||
(task_description,),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
if rows:
|
||||
return [
|
||||
{
|
||||
"metadata": json.loads(row[0]),
|
||||
"datetime": row[1],
|
||||
"score": row[2],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
except aiosqlite.Error as e:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
return None
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the LTM table asynchronously."""
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as conn:
|
||||
await conn.execute("DELETE FROM long_term_memories")
|
||||
await conn.commit()
|
||||
except aiosqlite.Error as e:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
@@ -129,6 +129,12 @@ class RAGStorage(BaseRAGStorage):
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
"""Save a value to storage.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Metadata to associate with the value.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
@@ -167,6 +173,51 @@ class RAGStorage(BaseRAGStorage):
|
||||
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
async def asave(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
"""Save a value to storage asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Metadata to associate with the value.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
await client.aget_or_create_collection(collection_name=collection_name)
|
||||
|
||||
document: BaseRecord = {"content": value}
|
||||
if metadata:
|
||||
document["metadata"] = metadata
|
||||
|
||||
batch_size = None
|
||||
if (
|
||||
self.embedder_config
|
||||
and isinstance(self.embedder_config, dict)
|
||||
and "config" in self.embedder_config
|
||||
):
|
||||
nested_config = self.embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
await client.aadd_documents(
|
||||
collection_name=collection_name,
|
||||
documents=[document],
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
await client.aadd_documents(
|
||||
collection_name=collection_name, documents=[document]
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} async save: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
@@ -174,6 +225,17 @@ class RAGStorage(BaseRAGStorage):
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search for matching entries in storage.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
filter: Optional metadata filter.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching entries.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
@@ -194,6 +256,44 @@ class RAGStorage(BaseRAGStorage):
|
||||
)
|
||||
return []
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search for matching entries in storage asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
filter: Optional metadata filter.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching entries.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
return await client.asearch(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
limit=limit,
|
||||
metadata_filter=filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} async search: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
return []
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
client = self._get_client()
|
||||
|
||||
496
lib/crewai/tests/memory/test_async_memory.py
Normal file
496
lib/crewai/tests/memory/test_async_memory.py
Normal file
@@ -0,0 +1,496 @@
|
||||
"""Tests for async memory operations."""
|
||||
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryQueryCompletedEvent,
|
||||
MemoryQueryStartedEvent,
|
||||
MemorySaveCompletedEvent,
|
||||
MemorySaveStartedEvent,
|
||||
)
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.external.external_memory import ExternalMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
"""Fixture to create a mock agent."""
|
||||
return Agent(
|
||||
role="Researcher",
|
||||
goal="Search relevant data and provide results",
|
||||
backstory="You are a researcher at a leading tech think tank.",
|
||||
tools=[],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task(mock_agent):
|
||||
"""Fixture to create a mock task."""
|
||||
return Task(
|
||||
description="Perform a search on specific topics.",
|
||||
expected_output="A list of relevant URLs based on the search query.",
|
||||
agent=mock_agent,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_term_memory(mock_agent, mock_task):
|
||||
"""Fixture to create a ShortTermMemory instance."""
|
||||
return ShortTermMemory(crew=Crew(agents=[mock_agent], tasks=[mock_task]))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def long_term_memory(tmp_path):
|
||||
"""Fixture to create a LongTermMemory instance."""
|
||||
db_path = str(tmp_path / "test_ltm.db")
|
||||
return LongTermMemory(path=db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity_memory(tmp_path, mock_agent, mock_task):
|
||||
"""Fixture to create an EntityMemory instance."""
|
||||
return EntityMemory(
|
||||
crew=Crew(agents=[mock_agent], tasks=[mock_task]),
|
||||
path=str(tmp_path / "test_entities"),
|
||||
)
|
||||
|
||||
|
||||
class TestAsyncShortTermMemory:
|
||||
"""Tests for async ShortTermMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_emits_events(self, short_term_memory):
|
||||
"""Test that asave emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
await short_term_memory.asave(
|
||||
value="async test value",
|
||||
metadata={"task": "async_test_task"},
|
||||
)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for async save events"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||
assert events["MemorySaveStartedEvent"][-1].value == "async test value"
|
||||
assert events["MemorySaveStartedEvent"][-1].source_type == "short_term_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_emits_events(self, short_term_memory):
|
||||
"""Test that asearch emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
with patch.object(short_term_memory.storage, "asearch", new_callable=AsyncMock, return_value=[]):
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
await short_term_memory.asearch(
|
||||
query="async test query",
|
||||
limit=3,
|
||||
score_threshold=0.35,
|
||||
)
|
||||
|
||||
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
assert events["MemoryQueryStartedEvent"][-1].query == "async test query"
|
||||
assert events["MemoryQueryStartedEvent"][-1].source_type == "short_term_memory"
|
||||
|
||||
|
||||
class TestAsyncLongTermMemory:
|
||||
"""Tests for async LongTermMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_emits_events(self, long_term_memory):
|
||||
"""Test that asave emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
item = LongTermMemoryItem(
|
||||
task="async test task",
|
||||
agent="test_agent",
|
||||
expected_output="test output",
|
||||
datetime="2024-01-01T00:00:00",
|
||||
quality=0.9,
|
||||
metadata={"task": "async test task", "quality": 0.9},
|
||||
)
|
||||
|
||||
await long_term_memory.asave(item)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for async save events"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||
assert events["MemorySaveStartedEvent"][-1].source_type == "long_term_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_emits_events(self, long_term_memory):
|
||||
"""Test that asearch emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
await long_term_memory.asearch(task="async test task", latest_n=3)
|
||||
|
||||
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
assert events["MemoryQueryStartedEvent"][-1].source_type == "long_term_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_and_asearch_integration(self, long_term_memory):
|
||||
"""Test that asave followed by asearch works correctly."""
|
||||
item = LongTermMemoryItem(
|
||||
task="integration test task",
|
||||
agent="test_agent",
|
||||
expected_output="test output",
|
||||
datetime="2024-01-01T00:00:00",
|
||||
quality=0.9,
|
||||
metadata={"task": "integration test task", "quality": 0.9},
|
||||
)
|
||||
|
||||
await long_term_memory.asave(item)
|
||||
results = await long_term_memory.asearch(task="integration test task", latest_n=1)
|
||||
|
||||
assert results is not None
|
||||
assert len(results) == 1
|
||||
assert results[0]["metadata"]["agent"] == "test_agent"
|
||||
|
||||
|
||||
class TestAsyncEntityMemory:
|
||||
"""Tests for async EntityMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asave_single_item_emits_events(self, entity_memory):
|
||||
"""Test that asave with a single item emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
condition = threading.Condition()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||
def on_save_started(source, event):
|
||||
with condition:
|
||||
events["MemorySaveStartedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_save_completed(source, event):
|
||||
with condition:
|
||||
events["MemorySaveCompletedEvent"].append(event)
|
||||
condition.notify()
|
||||
|
||||
item = EntityMemoryItem(
|
||||
name="TestEntity",
|
||||
type="Person",
|
||||
description="A test entity for async operations",
|
||||
relationships="Related to other test entities",
|
||||
)
|
||||
|
||||
await entity_memory.asave(item)
|
||||
|
||||
with condition:
|
||||
success = condition.wait_for(
|
||||
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||
timeout=5,
|
||||
)
|
||||
assert success, "Timeout waiting for async save events"
|
||||
|
||||
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||
assert events["MemorySaveStartedEvent"][-1].source_type == "entity_memory"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_emits_events(self, entity_memory):
|
||||
"""Test that asearch emits the correct events."""
|
||||
events: dict[str, list] = defaultdict(list)
|
||||
search_started = threading.Event()
|
||||
search_completed = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||
def on_search_started(source, event):
|
||||
events["MemoryQueryStartedEvent"].append(event)
|
||||
search_started.set()
|
||||
|
||||
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||
def on_search_completed(source, event):
|
||||
events["MemoryQueryCompletedEvent"].append(event)
|
||||
search_completed.set()
|
||||
|
||||
await entity_memory.asearch(query="TestEntity", limit=5, score_threshold=0.6)
|
||||
|
||||
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||
|
||||
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||
assert events["MemoryQueryStartedEvent"][-1].source_type == "entity_memory"
|
||||
|
||||
|
||||
class TestAsyncContextualMemory:
|
||||
"""Tests for async ContextualMemory operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abuild_context_for_task_with_empty_query(self, mock_task):
|
||||
"""Test that abuild_context_for_task returns empty string for empty query."""
|
||||
mock_task.description = ""
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=None,
|
||||
em=None,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory.abuild_context_for_task(mock_task, "")
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abuild_context_for_task_with_none_memories(self, mock_task):
|
||||
"""Test that abuild_context_for_task handles None memory sources."""
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=None,
|
||||
em=None,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory.abuild_context_for_task(mock_task, "some context")
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abuild_context_for_task_aggregates_results(self, mock_agent, mock_task):
|
||||
"""Test that abuild_context_for_task aggregates results from all memory sources."""
|
||||
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||
mock_stm.asearch = AsyncMock(return_value=[{"content": "STM insight"}])
|
||||
|
||||
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||
mock_ltm.asearch = AsyncMock(
|
||||
return_value=[{"metadata": {"suggestions": ["LTM suggestion"]}}]
|
||||
)
|
||||
|
||||
mock_em = MagicMock(spec=EntityMemory)
|
||||
mock_em.asearch = AsyncMock(return_value=[{"content": "Entity info"}])
|
||||
|
||||
mock_exm = MagicMock(spec=ExternalMemory)
|
||||
mock_exm.asearch = AsyncMock(return_value=[{"content": "External memory"}])
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=mock_stm,
|
||||
ltm=mock_ltm,
|
||||
em=mock_em,
|
||||
exm=mock_exm,
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
)
|
||||
|
||||
result = await contextual_memory.abuild_context_for_task(mock_task, "additional context")
|
||||
|
||||
assert "Recent Insights:" in result
|
||||
assert "STM insight" in result
|
||||
assert "Historical Data:" in result
|
||||
assert "LTM suggestion" in result
|
||||
assert "Entities:" in result
|
||||
assert "Entity info" in result
|
||||
assert "External memories:" in result
|
||||
assert "External memory" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_stm_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||
"""Test that _afetch_stm_context returns properly formatted results."""
|
||||
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||
mock_stm.asearch = AsyncMock(
|
||||
return_value=[
|
||||
{"content": "First insight"},
|
||||
{"content": "Second insight"},
|
||||
]
|
||||
)
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=mock_stm,
|
||||
ltm=None,
|
||||
em=None,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory._afetch_stm_context("test query")
|
||||
|
||||
assert "Recent Insights:" in result
|
||||
assert "- First insight" in result
|
||||
assert "- Second insight" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_ltm_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||
"""Test that _afetch_ltm_context returns properly formatted results."""
|
||||
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||
mock_ltm.asearch = AsyncMock(
|
||||
return_value=[
|
||||
{"metadata": {"suggestions": ["Suggestion 1", "Suggestion 2"]}},
|
||||
]
|
||||
)
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=mock_ltm,
|
||||
em=None,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory._afetch_ltm_context("test task")
|
||||
|
||||
assert "Historical Data:" in result
|
||||
assert "- Suggestion 1" in result
|
||||
assert "- Suggestion 2" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_entity_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||
"""Test that _afetch_entity_context returns properly formatted results."""
|
||||
mock_em = MagicMock(spec=EntityMemory)
|
||||
mock_em.asearch = AsyncMock(
|
||||
return_value=[
|
||||
{"content": "Entity A details"},
|
||||
{"content": "Entity B details"},
|
||||
]
|
||||
)
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=None,
|
||||
em=mock_em,
|
||||
exm=None,
|
||||
)
|
||||
|
||||
result = await contextual_memory._afetch_entity_context("test query")
|
||||
|
||||
assert "Entities:" in result
|
||||
assert "- Entity A details" in result
|
||||
assert "- Entity B details" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_external_context_returns_formatted_results(self):
|
||||
"""Test that _afetch_external_context returns properly formatted results."""
|
||||
mock_exm = MagicMock(spec=ExternalMemory)
|
||||
mock_exm.asearch = AsyncMock(
|
||||
return_value=[
|
||||
{"content": "External data 1"},
|
||||
{"content": "External data 2"},
|
||||
]
|
||||
)
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=None,
|
||||
ltm=None,
|
||||
em=None,
|
||||
exm=mock_exm,
|
||||
)
|
||||
|
||||
result = await contextual_memory._afetch_external_context("test query")
|
||||
|
||||
assert "External memories:" in result
|
||||
assert "- External data 1" in result
|
||||
assert "- External data 2" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_afetch_methods_return_empty_for_empty_results(self):
|
||||
"""Test that async fetch methods return empty string for no results."""
|
||||
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||
mock_stm.asearch = AsyncMock(return_value=[])
|
||||
|
||||
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||
mock_ltm.asearch = AsyncMock(return_value=[])
|
||||
|
||||
mock_em = MagicMock(spec=EntityMemory)
|
||||
mock_em.asearch = AsyncMock(return_value=[])
|
||||
|
||||
mock_exm = MagicMock(spec=ExternalMemory)
|
||||
mock_exm.asearch = AsyncMock(return_value=[])
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
stm=mock_stm,
|
||||
ltm=mock_ltm,
|
||||
em=mock_em,
|
||||
exm=mock_exm,
|
||||
)
|
||||
|
||||
stm_result = await contextual_memory._afetch_stm_context("query")
|
||||
ltm_result = await contextual_memory._afetch_ltm_context("task")
|
||||
em_result = await contextual_memory._afetch_entity_context("query")
|
||||
exm_result = await contextual_memory._afetch_external_context("query")
|
||||
|
||||
assert stm_result == ""
|
||||
assert ltm_result is None
|
||||
assert em_result == ""
|
||||
assert exm_result == ""
|
||||
Reference in New Issue
Block a user