mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
feat: add async ops to memory feat; create tests
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from crewai.memory import (
|
from crewai.memory import (
|
||||||
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class ContextualMemory:
|
class ContextualMemory:
|
||||||
|
"""Aggregates and retrieves context from multiple memory sources."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
stm: ShortTermMemory,
|
stm: ShortTermMemory,
|
||||||
@@ -46,9 +49,14 @@ class ContextualMemory:
|
|||||||
self.exm.task = self.task
|
self.exm.task = self.task
|
||||||
|
|
||||||
def build_context_for_task(self, task: Task, context: str) -> str:
|
def build_context_for_task(self, task: Task, context: str) -> str:
|
||||||
"""
|
"""Build contextual information for a task synchronously.
|
||||||
Automatically builds a minimal, highly relevant set of contextual information
|
|
||||||
for a given task.
|
Args:
|
||||||
|
task: The task to build context for.
|
||||||
|
context: Additional context string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string from all memory sources.
|
||||||
"""
|
"""
|
||||||
query = f"{task.description} {context}".strip()
|
query = f"{task.description} {context}".strip()
|
||||||
|
|
||||||
@@ -63,6 +71,31 @@ class ContextualMemory:
|
|||||||
]
|
]
|
||||||
return "\n".join(filter(None, context_parts))
|
return "\n".join(filter(None, context_parts))
|
||||||
|
|
||||||
|
async def abuild_context_for_task(self, task: Task, context: str) -> str:
|
||||||
|
"""Build contextual information for a task asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task to build context for.
|
||||||
|
context: Additional context string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string from all memory sources.
|
||||||
|
"""
|
||||||
|
query = f"{task.description} {context}".strip()
|
||||||
|
|
||||||
|
if query == "":
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Fetch all contexts concurrently
|
||||||
|
results = await asyncio.gather(
|
||||||
|
self._afetch_ltm_context(task.description),
|
||||||
|
self._afetch_stm_context(query),
|
||||||
|
self._afetch_entity_context(query),
|
||||||
|
self._afetch_external_context(query),
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n".join(filter(None, results))
|
||||||
|
|
||||||
def _fetch_stm_context(self, query: str) -> str:
|
def _fetch_stm_context(self, query: str) -> str:
|
||||||
"""
|
"""
|
||||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||||
@@ -135,3 +168,87 @@ class ContextualMemory:
|
|||||||
f"- {result['content']}" for result in external_memories
|
f"- {result['content']}" for result in external_memories
|
||||||
)
|
)
|
||||||
return f"External memories:\n{formatted_memories}"
|
return f"External memories:\n{formatted_memories}"
|
||||||
|
|
||||||
|
async def _afetch_stm_context(self, query: str) -> str:
|
||||||
|
"""Fetch recent relevant insights from STM asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted insights as bullet points, or empty string if none found.
|
||||||
|
"""
|
||||||
|
if self.stm is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
stm_results = await self.stm.asearch(query)
|
||||||
|
formatted_results = "\n".join(
|
||||||
|
[f"- {result['content']}" for result in stm_results]
|
||||||
|
)
|
||||||
|
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||||
|
|
||||||
|
async def _afetch_ltm_context(self, task: str) -> str | None:
|
||||||
|
"""Fetch historical data from LTM asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task description to search for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted historical data as bullet points, or None if none found.
|
||||||
|
"""
|
||||||
|
if self.ltm is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
ltm_results = await self.ltm.asearch(task, latest_n=2)
|
||||||
|
if not ltm_results:
|
||||||
|
return None
|
||||||
|
|
||||||
|
formatted_results = [
|
||||||
|
suggestion
|
||||||
|
for result in ltm_results
|
||||||
|
for suggestion in result["metadata"]["suggestions"]
|
||||||
|
]
|
||||||
|
formatted_results = list(dict.fromkeys(formatted_results))
|
||||||
|
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
|
||||||
|
|
||||||
|
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||||
|
|
||||||
|
async def _afetch_entity_context(self, query: str) -> str:
|
||||||
|
"""Fetch relevant entity information asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted entity information as bullet points, or empty string if none found.
|
||||||
|
"""
|
||||||
|
if self.em is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
em_results = await self.em.asearch(query)
|
||||||
|
formatted_results = "\n".join(
|
||||||
|
[f"- {result['content']}" for result in em_results]
|
||||||
|
)
|
||||||
|
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||||
|
|
||||||
|
async def _afetch_external_context(self, query: str) -> str:
|
||||||
|
"""Fetch relevant information from External Memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted information as bullet points, or empty string if none found.
|
||||||
|
"""
|
||||||
|
if self.exm is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
external_memories = await self.exm.asearch(query)
|
||||||
|
|
||||||
|
if not external_memories:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
formatted_memories = "\n".join(
|
||||||
|
f"- {result['content']}" for result in external_memories
|
||||||
|
)
|
||||||
|
return f"External memories:\n{formatted_memories}"
|
||||||
|
|||||||
@@ -26,7 +26,13 @@ class EntityMemory(Memory):
|
|||||||
|
|
||||||
_memory_provider: str | None = PrivateAttr()
|
_memory_provider: str | None = PrivateAttr()
|
||||||
|
|
||||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
crew: Any = None,
|
||||||
|
embedder_config: Any = None,
|
||||||
|
storage: Any = None,
|
||||||
|
path: str | None = None,
|
||||||
|
) -> None:
|
||||||
memory_provider = None
|
memory_provider = None
|
||||||
if embedder_config and isinstance(embedder_config, dict):
|
if embedder_config and isinstance(embedder_config, dict):
|
||||||
memory_provider = embedder_config.get("provider")
|
memory_provider = embedder_config.get("provider")
|
||||||
@@ -43,7 +49,7 @@ class EntityMemory(Memory):
|
|||||||
if embedder_config and isinstance(embedder_config, dict)
|
if embedder_config and isinstance(embedder_config, dict)
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||||
else:
|
else:
|
||||||
storage = (
|
storage = (
|
||||||
storage
|
storage
|
||||||
@@ -170,7 +176,17 @@ class EntityMemory(Memory):
|
|||||||
query: str,
|
query: str,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
score_threshold: float = 0.6,
|
score_threshold: float = 0.6,
|
||||||
):
|
) -> list[Any]:
|
||||||
|
"""Search entity memory for relevant entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=MemoryQueryStartedEvent(
|
event=MemoryQueryStartedEvent(
|
||||||
@@ -217,6 +233,168 @@ class EntityMemory(Memory):
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def asave(
|
||||||
|
self,
|
||||||
|
value: EntityMemoryItem | list[EntityMemoryItem],
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Save entity items asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
|
||||||
|
metadata: Optional metadata dict (not used, for signature compatibility).
|
||||||
|
"""
|
||||||
|
if not value:
|
||||||
|
return
|
||||||
|
|
||||||
|
items = value if isinstance(value, list) else [value]
|
||||||
|
is_batch = len(items) > 1
|
||||||
|
|
||||||
|
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveStartedEvent(
|
||||||
|
metadata=metadata,
|
||||||
|
source_type="entity_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
saved_count = 0
|
||||||
|
errors: list[str | None] = []
|
||||||
|
|
||||||
|
async def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
|
||||||
|
"""Save a single item asynchronously."""
|
||||||
|
try:
|
||||||
|
if self._memory_provider == "mem0":
|
||||||
|
data = f"""
|
||||||
|
Remember details about the following entity:
|
||||||
|
Name: {item.name}
|
||||||
|
Type: {item.type}
|
||||||
|
Entity Description: {item.description}
|
||||||
|
"""
|
||||||
|
else:
|
||||||
|
data = f"{item.name}({item.type}): {item.description}"
|
||||||
|
|
||||||
|
await super(EntityMemory, self).asave(data, item.metadata)
|
||||||
|
return True, None
|
||||||
|
except Exception as e:
|
||||||
|
return False, f"{item.name}: {e!s}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
for item in items:
|
||||||
|
success, error = await save_single_item(item)
|
||||||
|
if success:
|
||||||
|
saved_count += 1
|
||||||
|
else:
|
||||||
|
errors.append(error)
|
||||||
|
|
||||||
|
if is_batch:
|
||||||
|
emit_value = f"Saved {saved_count} entities"
|
||||||
|
metadata = {"entity_count": saved_count, "errors": errors}
|
||||||
|
else:
|
||||||
|
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
|
||||||
|
metadata = items[0].metadata
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveCompletedEvent(
|
||||||
|
value=emit_value,
|
||||||
|
metadata=metadata,
|
||||||
|
save_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="entity_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
raise Exception(
|
||||||
|
f"Partial save: {len(errors)} failed out of {len(items)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
fail_metadata = (
|
||||||
|
{"entity_count": len(items), "saved": saved_count}
|
||||||
|
if is_batch
|
||||||
|
else items[0].metadata
|
||||||
|
)
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveFailedEvent(
|
||||||
|
metadata=fail_metadata,
|
||||||
|
error=str(e),
|
||||||
|
source_type="entity_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Search entity memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryStartedEvent(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
source_type="entity_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
results = await super().asearch(
|
||||||
|
query=query, limit=limit, score_threshold=score_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryCompletedEvent(
|
||||||
|
query=query,
|
||||||
|
results=results,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
query_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="entity_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryFailedEvent(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
error=str(e),
|
||||||
|
source_type="entity_memory",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
try:
|
try:
|
||||||
self.storage.reset()
|
self.storage.reset()
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ class ExternalMemory(Memory):
|
|||||||
def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
|
def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
|
||||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||||
|
|
||||||
return Mem0Storage(type="external", crew=crew, config=config)
|
return Mem0Storage(type="external", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def external_supported_storages() -> dict[str, Any]:
|
def external_supported_storages() -> dict[str, Any]:
|
||||||
@@ -53,7 +53,10 @@ class ExternalMemory(Memory):
|
|||||||
if provider not in supported_storages:
|
if provider not in supported_storages:
|
||||||
raise ValueError(f"Provider {provider} not supported")
|
raise ValueError(f"Provider {provider} not supported")
|
||||||
|
|
||||||
return supported_storages[provider](crew, embedder_config.get("config", {}))
|
storage: Storage = supported_storages[provider](
|
||||||
|
crew, embedder_config.get("config", {})
|
||||||
|
)
|
||||||
|
return storage
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
@@ -111,7 +114,17 @@ class ExternalMemory(Memory):
|
|||||||
query: str,
|
query: str,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
score_threshold: float = 0.6,
|
score_threshold: float = 0.6,
|
||||||
):
|
) -> list[Any]:
|
||||||
|
"""Search external memory for relevant entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=MemoryQueryStartedEvent(
|
event=MemoryQueryStartedEvent(
|
||||||
@@ -158,6 +171,124 @@ class ExternalMemory(Memory):
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def asave(
|
||||||
|
self,
|
||||||
|
value: Any,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Save a value to external memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to save.
|
||||||
|
metadata: Optional metadata to associate with the value.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveStartedEvent(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
source_type="external_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
item = ExternalMemoryItem(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
agent=self.agent.role if self.agent else None,
|
||||||
|
)
|
||||||
|
await super().asave(value=item.value, metadata=item.metadata)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveCompletedEvent(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
save_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="external_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveFailedEvent(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
error=str(e),
|
||||||
|
source_type="external_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Search external memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryStartedEvent(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
source_type="external_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
results = await super().asearch(
|
||||||
|
query=query, limit=limit, score_threshold=score_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryCompletedEvent(
|
||||||
|
query=query,
|
||||||
|
results=results,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
query_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="external_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryFailedEvent(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
error=str(e),
|
||||||
|
source_type="external_memory",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.storage.reset()
|
self.storage.reset()
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,11 @@ class LongTermMemory(Memory):
|
|||||||
LongTermMemoryItem instances.
|
LongTermMemoryItem instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, storage=None, path=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage: LTMSQLiteStorage | None = None,
|
||||||
|
path: str | None = None,
|
||||||
|
) -> None:
|
||||||
if not storage:
|
if not storage:
|
||||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||||
super().__init__(storage=storage)
|
super().__init__(storage=storage)
|
||||||
@@ -48,7 +52,7 @@ class LongTermMemory(Memory):
|
|||||||
metadata.update(
|
metadata.update(
|
||||||
{"agent": item.agent, "expected_output": item.expected_output}
|
{"agent": item.agent, "expected_output": item.expected_output}
|
||||||
)
|
)
|
||||||
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
|
self.storage.save(
|
||||||
task_description=item.task,
|
task_description=item.task,
|
||||||
score=metadata["quality"],
|
score=metadata["quality"],
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
@@ -80,11 +84,20 @@ class LongTermMemory(Memory):
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def search( # type: ignore # signature of "search" incompatible with supertype "Memory"
|
def search( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
task: str,
|
task: str,
|
||||||
latest_n: int = 3,
|
latest_n: int = 3,
|
||||||
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Search long-term memory for relevant entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task description to search for.
|
||||||
|
latest_n: Maximum number of results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=MemoryQueryStartedEvent(
|
event=MemoryQueryStartedEvent(
|
||||||
@@ -98,7 +111,7 @@ class LongTermMemory(Memory):
|
|||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
results = self.storage.load(task, latest_n)
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -113,7 +126,118 @@ class LongTermMemory(Memory):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results or []
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryFailedEvent(
|
||||||
|
query=task,
|
||||||
|
limit=latest_n,
|
||||||
|
error=str(e),
|
||||||
|
source_type="long_term_memory",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def asave(self, item: LongTermMemoryItem) -> None: # type: ignore[override]
|
||||||
|
"""Save an item to long-term memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
item: The LongTermMemoryItem to save.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveStartedEvent(
|
||||||
|
value=item.task,
|
||||||
|
metadata=item.metadata,
|
||||||
|
agent_role=item.agent,
|
||||||
|
source_type="long_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
metadata = item.metadata
|
||||||
|
metadata.update(
|
||||||
|
{"agent": item.agent, "expected_output": item.expected_output}
|
||||||
|
)
|
||||||
|
await self.storage.asave(
|
||||||
|
task_description=item.task,
|
||||||
|
score=metadata["quality"],
|
||||||
|
metadata=metadata,
|
||||||
|
datetime=item.datetime,
|
||||||
|
)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveCompletedEvent(
|
||||||
|
value=item.task,
|
||||||
|
metadata=item.metadata,
|
||||||
|
agent_role=item.agent,
|
||||||
|
save_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="long_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveFailedEvent(
|
||||||
|
value=item.task,
|
||||||
|
metadata=item.metadata,
|
||||||
|
agent_role=item.agent,
|
||||||
|
error=str(e),
|
||||||
|
source_type="long_term_memory",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def asearch( # type: ignore[override]
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
latest_n: int = 3,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Search long-term memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: The task description to search for.
|
||||||
|
latest_n: Maximum number of results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryStartedEvent(
|
||||||
|
query=task,
|
||||||
|
limit=latest_n,
|
||||||
|
source_type="long_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
results = await self.storage.aload(task, latest_n)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryCompletedEvent(
|
||||||
|
query=task,
|
||||||
|
results=results,
|
||||||
|
limit=latest_n,
|
||||||
|
query_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="long_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return results or []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -127,4 +251,5 @@ class LongTermMemory(Memory):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
|
"""Reset long-term memory."""
|
||||||
self.storage.reset()
|
self.storage.reset()
|
||||||
|
|||||||
@@ -13,9 +13,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class Memory(BaseModel):
|
class Memory(BaseModel):
|
||||||
"""
|
"""Base class for memory, supporting agent tags and generic metadata."""
|
||||||
Base class for memory, now supporting agent tags and generic metadata.
|
|
||||||
"""
|
|
||||||
|
|
||||||
embedder_config: EmbedderConfig | dict[str, Any] | None = None
|
embedder_config: EmbedderConfig | dict[str, Any] | None = None
|
||||||
crew: Any | None = None
|
crew: Any | None = None
|
||||||
@@ -52,20 +50,72 @@ class Memory(BaseModel):
|
|||||||
value: Any,
|
value: Any,
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
metadata = metadata or {}
|
"""Save a value to memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to save.
|
||||||
|
metadata: Optional metadata to associate with the value.
|
||||||
|
"""
|
||||||
|
metadata = metadata or {}
|
||||||
self.storage.save(value, metadata)
|
self.storage.save(value, metadata)
|
||||||
|
|
||||||
|
async def asave(
|
||||||
|
self,
|
||||||
|
value: Any,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Save a value to memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to save.
|
||||||
|
metadata: Optional metadata to associate with the value.
|
||||||
|
"""
|
||||||
|
metadata = metadata or {}
|
||||||
|
await self.storage.asave(value, metadata)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
score_threshold: float = 0.6,
|
score_threshold: float = 0.6,
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
return self.storage.search(
|
"""Search memory for relevant entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
|
results: list[Any] = self.storage.search(
|
||||||
query=query, limit=limit, score_threshold=score_threshold
|
query=query, limit=limit, score_threshold=score_threshold
|
||||||
)
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Search memory for relevant entries asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
|
results: list[Any] = await self.storage.asearch(
|
||||||
|
query=query, limit=limit, score_threshold=score_threshold
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
def set_crew(self, crew: Any) -> Memory:
|
def set_crew(self, crew: Any) -> Memory:
|
||||||
|
"""Set the crew for this memory instance."""
|
||||||
self.crew = crew
|
self.crew = crew
|
||||||
return self
|
return self
|
||||||
|
|||||||
@@ -30,7 +30,13 @@ class ShortTermMemory(Memory):
|
|||||||
|
|
||||||
_memory_provider: str | None = PrivateAttr()
|
_memory_provider: str | None = PrivateAttr()
|
||||||
|
|
||||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
crew: Any = None,
|
||||||
|
embedder_config: Any = None,
|
||||||
|
storage: Any = None,
|
||||||
|
path: str | None = None,
|
||||||
|
) -> None:
|
||||||
memory_provider = None
|
memory_provider = None
|
||||||
if embedder_config and isinstance(embedder_config, dict):
|
if embedder_config and isinstance(embedder_config, dict):
|
||||||
memory_provider = embedder_config.get("provider")
|
memory_provider = embedder_config.get("provider")
|
||||||
@@ -47,7 +53,7 @@ class ShortTermMemory(Memory):
|
|||||||
if embedder_config and isinstance(embedder_config, dict)
|
if embedder_config and isinstance(embedder_config, dict)
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||||
else:
|
else:
|
||||||
storage = (
|
storage = (
|
||||||
storage
|
storage
|
||||||
@@ -123,7 +129,17 @@ class ShortTermMemory(Memory):
|
|||||||
query: str,
|
query: str,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
score_threshold: float = 0.6,
|
score_threshold: float = 0.6,
|
||||||
):
|
) -> list[Any]:
|
||||||
|
"""Search short-term memory for relevant entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
event=MemoryQueryStartedEvent(
|
event=MemoryQueryStartedEvent(
|
||||||
@@ -140,7 +156,7 @@ class ShortTermMemory(Memory):
|
|||||||
try:
|
try:
|
||||||
results = self.storage.search(
|
results = self.storage.search(
|
||||||
query=query, limit=limit, score_threshold=score_threshold
|
query=query, limit=limit, score_threshold=score_threshold
|
||||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
)
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -156,7 +172,130 @@ class ShortTermMemory(Memory):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return list(results)
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryFailedEvent(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
error=str(e),
|
||||||
|
source_type="short_term_memory",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def asave(
|
||||||
|
self,
|
||||||
|
value: Any,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Save a value to short-term memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to save.
|
||||||
|
metadata: Optional metadata to associate with the value.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveStartedEvent(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
source_type="short_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
item = ShortTermMemoryItem(
|
||||||
|
data=value,
|
||||||
|
metadata=metadata,
|
||||||
|
agent=self.agent.role if self.agent else None,
|
||||||
|
)
|
||||||
|
if self._memory_provider == "mem0":
|
||||||
|
item.data = (
|
||||||
|
f"Remember the following insights from Agent run: {item.data}"
|
||||||
|
)
|
||||||
|
|
||||||
|
await super().asave(value=item.data, metadata=item.metadata)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveCompletedEvent(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
save_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="short_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemorySaveFailedEvent(
|
||||||
|
value=value,
|
||||||
|
metadata=metadata,
|
||||||
|
error=str(e),
|
||||||
|
source_type="short_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Search short-term memory asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries.
|
||||||
|
"""
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryStartedEvent(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
source_type="short_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
results = await self.storage.asearch(
|
||||||
|
query=query, limit=limit, score_threshold=score_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
crewai_event_bus.emit(
|
||||||
|
self,
|
||||||
|
event=MemoryQueryCompletedEvent(
|
||||||
|
query=query,
|
||||||
|
results=results,
|
||||||
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
query_time_ms=(time.time() - start_time) * 1000,
|
||||||
|
source_type="short_term_memory",
|
||||||
|
from_agent=self.agent,
|
||||||
|
from_task=self.task,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return list(results)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -3,29 +3,30 @@ from pathlib import Path
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
from crewai.utilities import Printer
|
from crewai.utilities import Printer
|
||||||
from crewai.utilities.paths import db_storage_path
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
|
||||||
|
|
||||||
class LTMSQLiteStorage:
|
class LTMSQLiteStorage:
|
||||||
"""
|
"""SQLite storage class for long-term memory data."""
|
||||||
An updated SQLite storage class for LTM data storage.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, db_path: str | None = None) -> None:
|
def __init__(self, db_path: str | None = None) -> None:
|
||||||
|
"""Initialize the SQLite storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Optional path to the database file.
|
||||||
|
"""
|
||||||
if db_path is None:
|
if db_path is None:
|
||||||
# Get the parent directory of the default db path and create our db file there
|
|
||||||
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
self._printer: Printer = Printer()
|
self._printer: Printer = Printer()
|
||||||
# Ensure parent directory exists
|
|
||||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
self._initialize_db()
|
self._initialize_db()
|
||||||
|
|
||||||
def _initialize_db(self):
|
def _initialize_db(self) -> None:
|
||||||
"""
|
"""Initialize the SQLite database and create LTM table."""
|
||||||
Initializes the SQLite database and creates LTM table
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -106,9 +107,7 @@ class LTMSQLiteStorage:
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def reset(
|
def reset(self) -> None:
|
||||||
self,
|
|
||||||
) -> None:
|
|
||||||
"""Resets the LTM table with error handling."""
|
"""Resets the LTM table with error handling."""
|
||||||
try:
|
try:
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
@@ -121,4 +120,87 @@ class LTMSQLiteStorage:
|
|||||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
return
|
|
||||||
|
async def asave(
|
||||||
|
self,
|
||||||
|
task_description: str,
|
||||||
|
metadata: dict[str, Any],
|
||||||
|
datetime: str,
|
||||||
|
score: int | float,
|
||||||
|
) -> None:
|
||||||
|
"""Save data to the LTM table asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_description: Description of the task.
|
||||||
|
metadata: Metadata associated with the memory.
|
||||||
|
datetime: Timestamp of the memory.
|
||||||
|
score: Quality score of the memory.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with aiosqlite.connect(self.db_path) as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO long_term_memories (task_description, metadata, datetime, score)
|
||||||
|
VALUES (?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(task_description, json.dumps(metadata), datetime, score),
|
||||||
|
)
|
||||||
|
await conn.commit()
|
||||||
|
except aiosqlite.Error as e:
|
||||||
|
self._printer.print(
|
||||||
|
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def aload(
|
||||||
|
self, task_description: str, latest_n: int
|
||||||
|
) -> list[dict[str, Any]] | None:
|
||||||
|
"""Query the LTM table by task description asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_description: Description of the task to search for.
|
||||||
|
latest_n: Maximum number of results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory entries or None if error occurs.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with aiosqlite.connect(self.db_path) as conn:
|
||||||
|
cursor = await conn.execute(
|
||||||
|
f"""
|
||||||
|
SELECT metadata, datetime, score
|
||||||
|
FROM long_term_memories
|
||||||
|
WHERE task_description = ?
|
||||||
|
ORDER BY datetime DESC, score ASC
|
||||||
|
LIMIT {latest_n}
|
||||||
|
""", # nosec # noqa: S608
|
||||||
|
(task_description,),
|
||||||
|
)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
if rows:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"metadata": json.loads(row[0]),
|
||||||
|
"datetime": row[1],
|
||||||
|
"score": row[2],
|
||||||
|
}
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
except aiosqlite.Error as e:
|
||||||
|
self._printer.print(
|
||||||
|
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def areset(self) -> None:
|
||||||
|
"""Reset the LTM table asynchronously."""
|
||||||
|
try:
|
||||||
|
async with aiosqlite.connect(self.db_path) as conn:
|
||||||
|
await conn.execute("DELETE FROM long_term_memories")
|
||||||
|
await conn.commit()
|
||||||
|
except aiosqlite.Error as e:
|
||||||
|
self._printer.print(
|
||||||
|
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
|||||||
@@ -129,6 +129,12 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
return f"{base_path}/{file_name}"
|
return f"{base_path}/{file_name}"
|
||||||
|
|
||||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||||
|
"""Save a value to storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to save.
|
||||||
|
metadata: Metadata to associate with the value.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
collection_name = (
|
collection_name = (
|
||||||
@@ -167,6 +173,51 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def asave(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||||
|
"""Save a value to storage asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: The value to save.
|
||||||
|
metadata: Metadata to associate with the value.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = self._get_client()
|
||||||
|
collection_name = (
|
||||||
|
f"memory_{self.type}_{self.agents}"
|
||||||
|
if self.agents
|
||||||
|
else f"memory_{self.type}"
|
||||||
|
)
|
||||||
|
await client.aget_or_create_collection(collection_name=collection_name)
|
||||||
|
|
||||||
|
document: BaseRecord = {"content": value}
|
||||||
|
if metadata:
|
||||||
|
document["metadata"] = metadata
|
||||||
|
|
||||||
|
batch_size = None
|
||||||
|
if (
|
||||||
|
self.embedder_config
|
||||||
|
and isinstance(self.embedder_config, dict)
|
||||||
|
and "config" in self.embedder_config
|
||||||
|
):
|
||||||
|
nested_config = self.embedder_config["config"]
|
||||||
|
if isinstance(nested_config, dict):
|
||||||
|
batch_size = nested_config.get("batch_size")
|
||||||
|
|
||||||
|
if batch_size is not None:
|
||||||
|
await client.aadd_documents(
|
||||||
|
collection_name=collection_name,
|
||||||
|
documents=[document],
|
||||||
|
batch_size=cast(int, batch_size),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await client.aadd_documents(
|
||||||
|
collection_name=collection_name, documents=[document]
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
f"Error during {self.type} async save: {e!s}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@@ -174,6 +225,17 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
filter: dict[str, Any] | None = None,
|
filter: dict[str, Any] | None = None,
|
||||||
score_threshold: float = 0.6,
|
score_threshold: float = 0.6,
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
|
"""Search for matching entries in storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
filter: Optional metadata filter.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching entries.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
collection_name = (
|
collection_name = (
|
||||||
@@ -194,6 +256,44 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
)
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def asearch(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 5,
|
||||||
|
filter: dict[str, Any] | None = None,
|
||||||
|
score_threshold: float = 0.6,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Search for matching entries in storage asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
filter: Optional metadata filter.
|
||||||
|
score_threshold: Minimum similarity score for results.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching entries.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
client = self._get_client()
|
||||||
|
collection_name = (
|
||||||
|
f"memory_{self.type}_{self.agents}"
|
||||||
|
if self.agents
|
||||||
|
else f"memory_{self.type}"
|
||||||
|
)
|
||||||
|
return await client.asearch(
|
||||||
|
collection_name=collection_name,
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
metadata_filter=filter,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
f"Error during {self.type} async search: {e!s}\n{traceback.format_exc()}"
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
try:
|
try:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
|
|||||||
496
lib/crewai/tests/memory/test_async_memory.py
Normal file
496
lib/crewai/tests/memory/test_async_memory.py
Normal file
@@ -0,0 +1,496 @@
|
|||||||
|
"""Tests for async memory operations."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from collections import defaultdict
|
||||||
|
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.agent import Agent
|
||||||
|
from crewai.crew import Crew
|
||||||
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
|
from crewai.events.types.memory_events import (
|
||||||
|
MemoryQueryCompletedEvent,
|
||||||
|
MemoryQueryStartedEvent,
|
||||||
|
MemorySaveCompletedEvent,
|
||||||
|
MemorySaveStartedEvent,
|
||||||
|
)
|
||||||
|
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||||
|
from crewai.memory.entity.entity_memory import EntityMemory
|
||||||
|
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||||
|
from crewai.memory.external.external_memory import ExternalMemory
|
||||||
|
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||||
|
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||||
|
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||||
|
from crewai.task import Task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_agent():
|
||||||
|
"""Fixture to create a mock agent."""
|
||||||
|
return Agent(
|
||||||
|
role="Researcher",
|
||||||
|
goal="Search relevant data and provide results",
|
||||||
|
backstory="You are a researcher at a leading tech think tank.",
|
||||||
|
tools=[],
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_task(mock_agent):
|
||||||
|
"""Fixture to create a mock task."""
|
||||||
|
return Task(
|
||||||
|
description="Perform a search on specific topics.",
|
||||||
|
expected_output="A list of relevant URLs based on the search query.",
|
||||||
|
agent=mock_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def short_term_memory(mock_agent, mock_task):
|
||||||
|
"""Fixture to create a ShortTermMemory instance."""
|
||||||
|
return ShortTermMemory(crew=Crew(agents=[mock_agent], tasks=[mock_task]))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def long_term_memory(tmp_path):
|
||||||
|
"""Fixture to create a LongTermMemory instance."""
|
||||||
|
db_path = str(tmp_path / "test_ltm.db")
|
||||||
|
return LongTermMemory(path=db_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def entity_memory(tmp_path, mock_agent, mock_task):
|
||||||
|
"""Fixture to create an EntityMemory instance."""
|
||||||
|
return EntityMemory(
|
||||||
|
crew=Crew(agents=[mock_agent], tasks=[mock_task]),
|
||||||
|
path=str(tmp_path / "test_entities"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncShortTermMemory:
|
||||||
|
"""Tests for async ShortTermMemory operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asave_emits_events(self, short_term_memory):
|
||||||
|
"""Test that asave emits the correct events."""
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
condition = threading.Condition()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||||
|
def on_save_started(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveStartedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||||
|
def on_save_completed(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveCompletedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
await short_term_memory.asave(
|
||||||
|
value="async test value",
|
||||||
|
metadata={"task": "async_test_task"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with condition:
|
||||||
|
success = condition.wait_for(
|
||||||
|
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||||
|
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
assert success, "Timeout waiting for async save events"
|
||||||
|
|
||||||
|
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||||
|
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||||
|
assert events["MemorySaveStartedEvent"][-1].value == "async test value"
|
||||||
|
assert events["MemorySaveStartedEvent"][-1].source_type == "short_term_memory"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asearch_emits_events(self, short_term_memory):
|
||||||
|
"""Test that asearch emits the correct events."""
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
search_started = threading.Event()
|
||||||
|
search_completed = threading.Event()
|
||||||
|
|
||||||
|
with patch.object(short_term_memory.storage, "asearch", new_callable=AsyncMock, return_value=[]):
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||||
|
def on_search_started(source, event):
|
||||||
|
events["MemoryQueryStartedEvent"].append(event)
|
||||||
|
search_started.set()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||||
|
def on_search_completed(source, event):
|
||||||
|
events["MemoryQueryCompletedEvent"].append(event)
|
||||||
|
search_completed.set()
|
||||||
|
|
||||||
|
await short_term_memory.asearch(
|
||||||
|
query="async test query",
|
||||||
|
limit=3,
|
||||||
|
score_threshold=0.35,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||||
|
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||||
|
|
||||||
|
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||||
|
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||||
|
assert events["MemoryQueryStartedEvent"][-1].query == "async test query"
|
||||||
|
assert events["MemoryQueryStartedEvent"][-1].source_type == "short_term_memory"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncLongTermMemory:
|
||||||
|
"""Tests for async LongTermMemory operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asave_emits_events(self, long_term_memory):
|
||||||
|
"""Test that asave emits the correct events."""
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
condition = threading.Condition()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||||
|
def on_save_started(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveStartedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||||
|
def on_save_completed(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveCompletedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
item = LongTermMemoryItem(
|
||||||
|
task="async test task",
|
||||||
|
agent="test_agent",
|
||||||
|
expected_output="test output",
|
||||||
|
datetime="2024-01-01T00:00:00",
|
||||||
|
quality=0.9,
|
||||||
|
metadata={"task": "async test task", "quality": 0.9},
|
||||||
|
)
|
||||||
|
|
||||||
|
await long_term_memory.asave(item)
|
||||||
|
|
||||||
|
with condition:
|
||||||
|
success = condition.wait_for(
|
||||||
|
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||||
|
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
assert success, "Timeout waiting for async save events"
|
||||||
|
|
||||||
|
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||||
|
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||||
|
assert events["MemorySaveStartedEvent"][-1].source_type == "long_term_memory"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asearch_emits_events(self, long_term_memory):
|
||||||
|
"""Test that asearch emits the correct events."""
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
search_started = threading.Event()
|
||||||
|
search_completed = threading.Event()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||||
|
def on_search_started(source, event):
|
||||||
|
events["MemoryQueryStartedEvent"].append(event)
|
||||||
|
search_started.set()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||||
|
def on_search_completed(source, event):
|
||||||
|
events["MemoryQueryCompletedEvent"].append(event)
|
||||||
|
search_completed.set()
|
||||||
|
|
||||||
|
await long_term_memory.asearch(task="async test task", latest_n=3)
|
||||||
|
|
||||||
|
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||||
|
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||||
|
|
||||||
|
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||||
|
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||||
|
assert events["MemoryQueryStartedEvent"][-1].source_type == "long_term_memory"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asave_and_asearch_integration(self, long_term_memory):
|
||||||
|
"""Test that asave followed by asearch works correctly."""
|
||||||
|
item = LongTermMemoryItem(
|
||||||
|
task="integration test task",
|
||||||
|
agent="test_agent",
|
||||||
|
expected_output="test output",
|
||||||
|
datetime="2024-01-01T00:00:00",
|
||||||
|
quality=0.9,
|
||||||
|
metadata={"task": "integration test task", "quality": 0.9},
|
||||||
|
)
|
||||||
|
|
||||||
|
await long_term_memory.asave(item)
|
||||||
|
results = await long_term_memory.asearch(task="integration test task", latest_n=1)
|
||||||
|
|
||||||
|
assert results is not None
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["metadata"]["agent"] == "test_agent"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncEntityMemory:
|
||||||
|
"""Tests for async EntityMemory operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asave_single_item_emits_events(self, entity_memory):
|
||||||
|
"""Test that asave with a single item emits the correct events."""
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
condition = threading.Condition()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveStartedEvent)
|
||||||
|
def on_save_started(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveStartedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||||
|
def on_save_completed(source, event):
|
||||||
|
with condition:
|
||||||
|
events["MemorySaveCompletedEvent"].append(event)
|
||||||
|
condition.notify()
|
||||||
|
|
||||||
|
item = EntityMemoryItem(
|
||||||
|
name="TestEntity",
|
||||||
|
type="Person",
|
||||||
|
description="A test entity for async operations",
|
||||||
|
relationships="Related to other test entities",
|
||||||
|
)
|
||||||
|
|
||||||
|
await entity_memory.asave(item)
|
||||||
|
|
||||||
|
with condition:
|
||||||
|
success = condition.wait_for(
|
||||||
|
lambda: len(events["MemorySaveStartedEvent"]) >= 1
|
||||||
|
and len(events["MemorySaveCompletedEvent"]) >= 1,
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
assert success, "Timeout waiting for async save events"
|
||||||
|
|
||||||
|
assert len(events["MemorySaveStartedEvent"]) >= 1
|
||||||
|
assert len(events["MemorySaveCompletedEvent"]) >= 1
|
||||||
|
assert events["MemorySaveStartedEvent"][-1].source_type == "entity_memory"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_asearch_emits_events(self, entity_memory):
|
||||||
|
"""Test that asearch emits the correct events."""
|
||||||
|
events: dict[str, list] = defaultdict(list)
|
||||||
|
search_started = threading.Event()
|
||||||
|
search_completed = threading.Event()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemoryQueryStartedEvent)
|
||||||
|
def on_search_started(source, event):
|
||||||
|
events["MemoryQueryStartedEvent"].append(event)
|
||||||
|
search_started.set()
|
||||||
|
|
||||||
|
@crewai_event_bus.on(MemoryQueryCompletedEvent)
|
||||||
|
def on_search_completed(source, event):
|
||||||
|
events["MemoryQueryCompletedEvent"].append(event)
|
||||||
|
search_completed.set()
|
||||||
|
|
||||||
|
await entity_memory.asearch(query="TestEntity", limit=5, score_threshold=0.6)
|
||||||
|
|
||||||
|
assert search_started.wait(timeout=2), "Timeout waiting for search started event"
|
||||||
|
assert search_completed.wait(timeout=2), "Timeout waiting for search completed event"
|
||||||
|
|
||||||
|
assert len(events["MemoryQueryStartedEvent"]) >= 1
|
||||||
|
assert len(events["MemoryQueryCompletedEvent"]) >= 1
|
||||||
|
assert events["MemoryQueryStartedEvent"][-1].source_type == "entity_memory"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncContextualMemory:
|
||||||
|
"""Tests for async ContextualMemory operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_abuild_context_for_task_with_empty_query(self, mock_task):
|
||||||
|
"""Test that abuild_context_for_task returns empty string for empty query."""
|
||||||
|
mock_task.description = ""
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=None,
|
||||||
|
ltm=None,
|
||||||
|
em=None,
|
||||||
|
exm=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory.abuild_context_for_task(mock_task, "")
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_abuild_context_for_task_with_none_memories(self, mock_task):
|
||||||
|
"""Test that abuild_context_for_task handles None memory sources."""
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=None,
|
||||||
|
ltm=None,
|
||||||
|
em=None,
|
||||||
|
exm=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory.abuild_context_for_task(mock_task, "some context")
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_abuild_context_for_task_aggregates_results(self, mock_agent, mock_task):
|
||||||
|
"""Test that abuild_context_for_task aggregates results from all memory sources."""
|
||||||
|
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||||
|
mock_stm.asearch = AsyncMock(return_value=[{"content": "STM insight"}])
|
||||||
|
|
||||||
|
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||||
|
mock_ltm.asearch = AsyncMock(
|
||||||
|
return_value=[{"metadata": {"suggestions": ["LTM suggestion"]}}]
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_em = MagicMock(spec=EntityMemory)
|
||||||
|
mock_em.asearch = AsyncMock(return_value=[{"content": "Entity info"}])
|
||||||
|
|
||||||
|
mock_exm = MagicMock(spec=ExternalMemory)
|
||||||
|
mock_exm.asearch = AsyncMock(return_value=[{"content": "External memory"}])
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=mock_stm,
|
||||||
|
ltm=mock_ltm,
|
||||||
|
em=mock_em,
|
||||||
|
exm=mock_exm,
|
||||||
|
agent=mock_agent,
|
||||||
|
task=mock_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory.abuild_context_for_task(mock_task, "additional context")
|
||||||
|
|
||||||
|
assert "Recent Insights:" in result
|
||||||
|
assert "STM insight" in result
|
||||||
|
assert "Historical Data:" in result
|
||||||
|
assert "LTM suggestion" in result
|
||||||
|
assert "Entities:" in result
|
||||||
|
assert "Entity info" in result
|
||||||
|
assert "External memories:" in result
|
||||||
|
assert "External memory" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_afetch_stm_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||||
|
"""Test that _afetch_stm_context returns properly formatted results."""
|
||||||
|
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||||
|
mock_stm.asearch = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
{"content": "First insight"},
|
||||||
|
{"content": "Second insight"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=mock_stm,
|
||||||
|
ltm=None,
|
||||||
|
em=None,
|
||||||
|
exm=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory._afetch_stm_context("test query")
|
||||||
|
|
||||||
|
assert "Recent Insights:" in result
|
||||||
|
assert "- First insight" in result
|
||||||
|
assert "- Second insight" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_afetch_ltm_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||||
|
"""Test that _afetch_ltm_context returns properly formatted results."""
|
||||||
|
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||||
|
mock_ltm.asearch = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
{"metadata": {"suggestions": ["Suggestion 1", "Suggestion 2"]}},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=None,
|
||||||
|
ltm=mock_ltm,
|
||||||
|
em=None,
|
||||||
|
exm=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory._afetch_ltm_context("test task")
|
||||||
|
|
||||||
|
assert "Historical Data:" in result
|
||||||
|
assert "- Suggestion 1" in result
|
||||||
|
assert "- Suggestion 2" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_afetch_entity_context_returns_formatted_results(self, mock_agent, mock_task):
|
||||||
|
"""Test that _afetch_entity_context returns properly formatted results."""
|
||||||
|
mock_em = MagicMock(spec=EntityMemory)
|
||||||
|
mock_em.asearch = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
{"content": "Entity A details"},
|
||||||
|
{"content": "Entity B details"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=None,
|
||||||
|
ltm=None,
|
||||||
|
em=mock_em,
|
||||||
|
exm=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory._afetch_entity_context("test query")
|
||||||
|
|
||||||
|
assert "Entities:" in result
|
||||||
|
assert "- Entity A details" in result
|
||||||
|
assert "- Entity B details" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_afetch_external_context_returns_formatted_results(self):
|
||||||
|
"""Test that _afetch_external_context returns properly formatted results."""
|
||||||
|
mock_exm = MagicMock(spec=ExternalMemory)
|
||||||
|
mock_exm.asearch = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
{"content": "External data 1"},
|
||||||
|
{"content": "External data 2"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=None,
|
||||||
|
ltm=None,
|
||||||
|
em=None,
|
||||||
|
exm=mock_exm,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await contextual_memory._afetch_external_context("test query")
|
||||||
|
|
||||||
|
assert "External memories:" in result
|
||||||
|
assert "- External data 1" in result
|
||||||
|
assert "- External data 2" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_afetch_methods_return_empty_for_empty_results(self):
|
||||||
|
"""Test that async fetch methods return empty string for no results."""
|
||||||
|
mock_stm = MagicMock(spec=ShortTermMemory)
|
||||||
|
mock_stm.asearch = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
mock_ltm = MagicMock(spec=LongTermMemory)
|
||||||
|
mock_ltm.asearch = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
mock_em = MagicMock(spec=EntityMemory)
|
||||||
|
mock_em.asearch = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
mock_exm = MagicMock(spec=ExternalMemory)
|
||||||
|
mock_exm.asearch = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
stm=mock_stm,
|
||||||
|
ltm=mock_ltm,
|
||||||
|
em=mock_em,
|
||||||
|
exm=mock_exm,
|
||||||
|
)
|
||||||
|
|
||||||
|
stm_result = await contextual_memory._afetch_stm_context("query")
|
||||||
|
ltm_result = await contextual_memory._afetch_ltm_context("task")
|
||||||
|
em_result = await contextual_memory._afetch_entity_context("query")
|
||||||
|
exm_result = await contextual_memory._afetch_external_context("query")
|
||||||
|
|
||||||
|
assert stm_result == ""
|
||||||
|
assert ltm_result is None
|
||||||
|
assert em_result == ""
|
||||||
|
assert exm_result == ""
|
||||||
Reference in New Issue
Block a user