From 441591d592b3ec5450cb3883ce50221d9ac263d7 Mon Sep 17 00:00:00 2001 From: Greyson Lalonde Date: Tue, 2 Dec 2025 13:09:52 -0500 Subject: [PATCH] feat: add async ops to memory feat; create tests --- .../memory/contextual/contextual_memory.py | 123 ++++- .../src/crewai/memory/entity/entity_memory.py | 184 ++++++- .../crewai/memory/external/external_memory.py | 137 ++++- .../memory/long_term/long_term_memory.py | 137 ++++- lib/crewai/src/crewai/memory/memory.py | 60 ++- .../memory/short_term/short_term_memory.py | 149 +++++- .../memory/storage/ltm_sqlite_storage.py | 108 +++- .../src/crewai/memory/storage/rag_storage.py | 100 ++++ lib/crewai/tests/memory/test_async_memory.py | 496 ++++++++++++++++++ 9 files changed, 1456 insertions(+), 38 deletions(-) create mode 100644 lib/crewai/tests/memory/test_async_memory.py diff --git a/lib/crewai/src/crewai/memory/contextual/contextual_memory.py b/lib/crewai/src/crewai/memory/contextual/contextual_memory.py index b65850c3c..5e35d4f2f 100644 --- a/lib/crewai/src/crewai/memory/contextual/contextual_memory.py +++ b/lib/crewai/src/crewai/memory/contextual/contextual_memory.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from typing import TYPE_CHECKING from crewai.memory import ( @@ -16,6 +17,8 @@ if TYPE_CHECKING: class ContextualMemory: + """Aggregates and retrieves context from multiple memory sources.""" + def __init__( self, stm: ShortTermMemory, @@ -46,9 +49,14 @@ class ContextualMemory: self.exm.task = self.task def build_context_for_task(self, task: Task, context: str) -> str: - """ - Automatically builds a minimal, highly relevant set of contextual information - for a given task. + """Build contextual information for a task synchronously. + + Args: + task: The task to build context for. + context: Additional context string. + + Returns: + Formatted context string from all memory sources. """ query = f"{task.description} {context}".strip() @@ -63,6 +71,31 @@ class ContextualMemory: ] return "\n".join(filter(None, context_parts)) + async def abuild_context_for_task(self, task: Task, context: str) -> str: + """Build contextual information for a task asynchronously. + + Args: + task: The task to build context for. + context: Additional context string. + + Returns: + Formatted context string from all memory sources. + """ + query = f"{task.description} {context}".strip() + + if query == "": + return "" + + # Fetch all contexts concurrently + results = await asyncio.gather( + self._afetch_ltm_context(task.description), + self._afetch_stm_context(query), + self._afetch_entity_context(query), + self._afetch_external_context(query), + ) + + return "\n".join(filter(None, results)) + def _fetch_stm_context(self, query: str) -> str: """ Fetches recent relevant insights from STM related to the task's description and expected_output, @@ -135,3 +168,87 @@ class ContextualMemory: f"- {result['content']}" for result in external_memories ) return f"External memories:\n{formatted_memories}" + + async def _afetch_stm_context(self, query: str) -> str: + """Fetch recent relevant insights from STM asynchronously. + + Args: + query: The search query. + + Returns: + Formatted insights as bullet points, or empty string if none found. + """ + if self.stm is None: + return "" + + stm_results = await self.stm.asearch(query) + formatted_results = "\n".join( + [f"- {result['content']}" for result in stm_results] + ) + return f"Recent Insights:\n{formatted_results}" if stm_results else "" + + async def _afetch_ltm_context(self, task: str) -> str | None: + """Fetch historical data from LTM asynchronously. + + Args: + task: The task description to search for. + + Returns: + Formatted historical data as bullet points, or None if none found. + """ + if self.ltm is None: + return "" + + ltm_results = await self.ltm.asearch(task, latest_n=2) + if not ltm_results: + return None + + formatted_results = [ + suggestion + for result in ltm_results + for suggestion in result["metadata"]["suggestions"] + ] + formatted_results = list(dict.fromkeys(formatted_results)) + formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]") + + return f"Historical Data:\n{formatted_results}" if ltm_results else "" + + async def _afetch_entity_context(self, query: str) -> str: + """Fetch relevant entity information asynchronously. + + Args: + query: The search query. + + Returns: + Formatted entity information as bullet points, or empty string if none found. + """ + if self.em is None: + return "" + + em_results = await self.em.asearch(query) + formatted_results = "\n".join( + [f"- {result['content']}" for result in em_results] + ) + return f"Entities:\n{formatted_results}" if em_results else "" + + async def _afetch_external_context(self, query: str) -> str: + """Fetch relevant information from External Memory asynchronously. + + Args: + query: The search query. + + Returns: + Formatted information as bullet points, or empty string if none found. + """ + if self.exm is None: + return "" + + external_memories = await self.exm.asearch(query) + + if not external_memories: + return "" + + formatted_memories = "\n".join( + f"- {result['content']}" for result in external_memories + ) + return f"External memories:\n{formatted_memories}" diff --git a/lib/crewai/src/crewai/memory/entity/entity_memory.py b/lib/crewai/src/crewai/memory/entity/entity_memory.py index 18a08809e..b3e3a568b 100644 --- a/lib/crewai/src/crewai/memory/entity/entity_memory.py +++ b/lib/crewai/src/crewai/memory/entity/entity_memory.py @@ -26,7 +26,13 @@ class EntityMemory(Memory): _memory_provider: str | None = PrivateAttr() - def __init__(self, crew=None, embedder_config=None, storage=None, path=None): + def __init__( + self, + crew: Any = None, + embedder_config: Any = None, + storage: Any = None, + path: str | None = None, + ) -> None: memory_provider = None if embedder_config and isinstance(embedder_config, dict): memory_provider = embedder_config.get("provider") @@ -43,7 +49,7 @@ class EntityMemory(Memory): if embedder_config and isinstance(embedder_config, dict) else None ) - storage = Mem0Storage(type="short_term", crew=crew, config=config) + storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call] else: storage = ( storage @@ -170,7 +176,17 @@ class EntityMemory(Memory): query: str, limit: int = 5, score_threshold: float = 0.6, - ): + ) -> list[Any]: + """Search entity memory for relevant entries. + + Args: + query: The search query. + limit: Maximum number of results to return. + score_threshold: Minimum similarity score for results. + + Returns: + List of matching memory entries. + """ crewai_event_bus.emit( self, event=MemoryQueryStartedEvent( @@ -217,6 +233,168 @@ class EntityMemory(Memory): ) raise + async def asave( + self, + value: EntityMemoryItem | list[EntityMemoryItem], + metadata: dict[str, Any] | None = None, + ) -> None: + """Save entity items asynchronously. + + Args: + value: Single EntityMemoryItem or list of EntityMemoryItems to save. + metadata: Optional metadata dict (not used, for signature compatibility). + """ + if not value: + return + + items = value if isinstance(value, list) else [value] + is_batch = len(items) > 1 + + metadata = {"entity_count": len(items)} if is_batch else items[0].metadata + crewai_event_bus.emit( + self, + event=MemorySaveStartedEvent( + metadata=metadata, + source_type="entity_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + + start_time = time.time() + saved_count = 0 + errors: list[str | None] = [] + + async def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]: + """Save a single item asynchronously.""" + try: + if self._memory_provider == "mem0": + data = f""" + Remember details about the following entity: + Name: {item.name} + Type: {item.type} + Entity Description: {item.description} + """ + else: + data = f"{item.name}({item.type}): {item.description}" + + await super(EntityMemory, self).asave(data, item.metadata) + return True, None + except Exception as e: + return False, f"{item.name}: {e!s}" + + try: + for item in items: + success, error = await save_single_item(item) + if success: + saved_count += 1 + else: + errors.append(error) + + if is_batch: + emit_value = f"Saved {saved_count} entities" + metadata = {"entity_count": saved_count, "errors": errors} + else: + emit_value = f"{items[0].name}({items[0].type}): {items[0].description}" + metadata = items[0].metadata + + crewai_event_bus.emit( + self, + event=MemorySaveCompletedEvent( + value=emit_value, + metadata=metadata, + save_time_ms=(time.time() - start_time) * 1000, + source_type="entity_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + + if errors: + raise Exception( + f"Partial save: {len(errors)} failed out of {len(items)}" + ) + + except Exception as e: + fail_metadata = ( + {"entity_count": len(items), "saved": saved_count} + if is_batch + else items[0].metadata + ) + crewai_event_bus.emit( + self, + event=MemorySaveFailedEvent( + metadata=fail_metadata, + error=str(e), + source_type="entity_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + raise + + async def asearch( + self, + query: str, + limit: int = 5, + score_threshold: float = 0.6, + ) -> list[Any]: + """Search entity memory asynchronously. + + Args: + query: The search query. + limit: Maximum number of results to return. + score_threshold: Minimum similarity score for results. + + Returns: + List of matching memory entries. + """ + crewai_event_bus.emit( + self, + event=MemoryQueryStartedEvent( + query=query, + limit=limit, + score_threshold=score_threshold, + source_type="entity_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + + start_time = time.time() + try: + results = await super().asearch( + query=query, limit=limit, score_threshold=score_threshold + ) + + crewai_event_bus.emit( + self, + event=MemoryQueryCompletedEvent( + query=query, + results=results, + limit=limit, + score_threshold=score_threshold, + query_time_ms=(time.time() - start_time) * 1000, + source_type="entity_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + + return results + except Exception as e: + crewai_event_bus.emit( + self, + event=MemoryQueryFailedEvent( + query=query, + limit=limit, + score_threshold=score_threshold, + error=str(e), + source_type="entity_memory", + ), + ) + raise + def reset(self) -> None: try: self.storage.reset() diff --git a/lib/crewai/src/crewai/memory/external/external_memory.py b/lib/crewai/src/crewai/memory/external/external_memory.py index c48ffd1e3..6aedf0084 100644 --- a/lib/crewai/src/crewai/memory/external/external_memory.py +++ b/lib/crewai/src/crewai/memory/external/external_memory.py @@ -30,7 +30,7 @@ class ExternalMemory(Memory): def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage: from crewai.memory.storage.mem0_storage import Mem0Storage - return Mem0Storage(type="external", crew=crew, config=config) + return Mem0Storage(type="external", crew=crew, config=config) # type: ignore[no-untyped-call] @staticmethod def external_supported_storages() -> dict[str, Any]: @@ -53,7 +53,10 @@ class ExternalMemory(Memory): if provider not in supported_storages: raise ValueError(f"Provider {provider} not supported") - return supported_storages[provider](crew, embedder_config.get("config", {})) + storage: Storage = supported_storages[provider]( + crew, embedder_config.get("config", {}) + ) + return storage def save( self, @@ -111,7 +114,17 @@ class ExternalMemory(Memory): query: str, limit: int = 5, score_threshold: float = 0.6, - ): + ) -> list[Any]: + """Search external memory for relevant entries. + + Args: + query: The search query. + limit: Maximum number of results to return. + score_threshold: Minimum similarity score for results. + + Returns: + List of matching memory entries. + """ crewai_event_bus.emit( self, event=MemoryQueryStartedEvent( @@ -158,6 +171,124 @@ class ExternalMemory(Memory): ) raise + async def asave( + self, + value: Any, + metadata: dict[str, Any] | None = None, + ) -> None: + """Save a value to external memory asynchronously. + + Args: + value: The value to save. + metadata: Optional metadata to associate with the value. + """ + crewai_event_bus.emit( + self, + event=MemorySaveStartedEvent( + value=value, + metadata=metadata, + source_type="external_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + + start_time = time.time() + try: + item = ExternalMemoryItem( + value=value, + metadata=metadata, + agent=self.agent.role if self.agent else None, + ) + await super().asave(value=item.value, metadata=item.metadata) + + crewai_event_bus.emit( + self, + event=MemorySaveCompletedEvent( + value=value, + metadata=metadata, + save_time_ms=(time.time() - start_time) * 1000, + source_type="external_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + except Exception as e: + crewai_event_bus.emit( + self, + event=MemorySaveFailedEvent( + value=value, + metadata=metadata, + error=str(e), + source_type="external_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + raise + + async def asearch( + self, + query: str, + limit: int = 5, + score_threshold: float = 0.6, + ) -> list[Any]: + """Search external memory asynchronously. + + Args: + query: The search query. + limit: Maximum number of results to return. + score_threshold: Minimum similarity score for results. + + Returns: + List of matching memory entries. + """ + crewai_event_bus.emit( + self, + event=MemoryQueryStartedEvent( + query=query, + limit=limit, + score_threshold=score_threshold, + source_type="external_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + + start_time = time.time() + try: + results = await super().asearch( + query=query, limit=limit, score_threshold=score_threshold + ) + + crewai_event_bus.emit( + self, + event=MemoryQueryCompletedEvent( + query=query, + results=results, + limit=limit, + score_threshold=score_threshold, + query_time_ms=(time.time() - start_time) * 1000, + source_type="external_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + + return results + except Exception as e: + crewai_event_bus.emit( + self, + event=MemoryQueryFailedEvent( + query=query, + limit=limit, + score_threshold=score_threshold, + error=str(e), + source_type="external_memory", + ), + ) + raise + def reset(self) -> None: self.storage.reset() diff --git a/lib/crewai/src/crewai/memory/long_term/long_term_memory.py b/lib/crewai/src/crewai/memory/long_term/long_term_memory.py index 038d07e83..35ab12870 100644 --- a/lib/crewai/src/crewai/memory/long_term/long_term_memory.py +++ b/lib/crewai/src/crewai/memory/long_term/long_term_memory.py @@ -24,7 +24,11 @@ class LongTermMemory(Memory): LongTermMemoryItem instances. """ - def __init__(self, storage=None, path=None): + def __init__( + self, + storage: LTMSQLiteStorage | None = None, + path: str | None = None, + ) -> None: if not storage: storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage() super().__init__(storage=storage) @@ -48,7 +52,7 @@ class LongTermMemory(Memory): metadata.update( {"agent": item.agent, "expected_output": item.expected_output} ) - self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage" + self.storage.save( task_description=item.task, score=metadata["quality"], metadata=metadata, @@ -80,11 +84,20 @@ class LongTermMemory(Memory): ) raise - def search( # type: ignore # signature of "search" incompatible with supertype "Memory" + def search( # type: ignore[override] self, task: str, latest_n: int = 3, - ) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory" + ) -> list[dict[str, Any]]: + """Search long-term memory for relevant entries. + + Args: + task: The task description to search for. + latest_n: Maximum number of results to return. + + Returns: + List of matching memory entries. + """ crewai_event_bus.emit( self, event=MemoryQueryStartedEvent( @@ -98,7 +111,7 @@ class LongTermMemory(Memory): start_time = time.time() try: - results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load" + results = self.storage.load(task, latest_n) crewai_event_bus.emit( self, @@ -113,7 +126,118 @@ class LongTermMemory(Memory): ), ) - return results + return results or [] + except Exception as e: + crewai_event_bus.emit( + self, + event=MemoryQueryFailedEvent( + query=task, + limit=latest_n, + error=str(e), + source_type="long_term_memory", + ), + ) + raise + + async def asave(self, item: LongTermMemoryItem) -> None: # type: ignore[override] + """Save an item to long-term memory asynchronously. + + Args: + item: The LongTermMemoryItem to save. + """ + crewai_event_bus.emit( + self, + event=MemorySaveStartedEvent( + value=item.task, + metadata=item.metadata, + agent_role=item.agent, + source_type="long_term_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + + start_time = time.time() + try: + metadata = item.metadata + metadata.update( + {"agent": item.agent, "expected_output": item.expected_output} + ) + await self.storage.asave( + task_description=item.task, + score=metadata["quality"], + metadata=metadata, + datetime=item.datetime, + ) + + crewai_event_bus.emit( + self, + event=MemorySaveCompletedEvent( + value=item.task, + metadata=item.metadata, + agent_role=item.agent, + save_time_ms=(time.time() - start_time) * 1000, + source_type="long_term_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + except Exception as e: + crewai_event_bus.emit( + self, + event=MemorySaveFailedEvent( + value=item.task, + metadata=item.metadata, + agent_role=item.agent, + error=str(e), + source_type="long_term_memory", + ), + ) + raise + + async def asearch( # type: ignore[override] + self, + task: str, + latest_n: int = 3, + ) -> list[dict[str, Any]]: + """Search long-term memory asynchronously. + + Args: + task: The task description to search for. + latest_n: Maximum number of results to return. + + Returns: + List of matching memory entries. + """ + crewai_event_bus.emit( + self, + event=MemoryQueryStartedEvent( + query=task, + limit=latest_n, + source_type="long_term_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + + start_time = time.time() + try: + results = await self.storage.aload(task, latest_n) + + crewai_event_bus.emit( + self, + event=MemoryQueryCompletedEvent( + query=task, + results=results, + limit=latest_n, + query_time_ms=(time.time() - start_time) * 1000, + source_type="long_term_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + + return results or [] except Exception as e: crewai_event_bus.emit( self, @@ -127,4 +251,5 @@ class LongTermMemory(Memory): raise def reset(self) -> None: + """Reset long-term memory.""" self.storage.reset() diff --git a/lib/crewai/src/crewai/memory/memory.py b/lib/crewai/src/crewai/memory/memory.py index 74297f9e4..fe90b8e3e 100644 --- a/lib/crewai/src/crewai/memory/memory.py +++ b/lib/crewai/src/crewai/memory/memory.py @@ -13,9 +13,7 @@ if TYPE_CHECKING: class Memory(BaseModel): - """ - Base class for memory, now supporting agent tags and generic metadata. - """ + """Base class for memory, supporting agent tags and generic metadata.""" embedder_config: EmbedderConfig | dict[str, Any] | None = None crew: Any | None = None @@ -52,20 +50,72 @@ class Memory(BaseModel): value: Any, metadata: dict[str, Any] | None = None, ) -> None: - metadata = metadata or {} + """Save a value to memory. + Args: + value: The value to save. + metadata: Optional metadata to associate with the value. + """ + metadata = metadata or {} self.storage.save(value, metadata) + async def asave( + self, + value: Any, + metadata: dict[str, Any] | None = None, + ) -> None: + """Save a value to memory asynchronously. + + Args: + value: The value to save. + metadata: Optional metadata to associate with the value. + """ + metadata = metadata or {} + await self.storage.asave(value, metadata) + def search( self, query: str, limit: int = 5, score_threshold: float = 0.6, ) -> list[Any]: - return self.storage.search( + """Search memory for relevant entries. + + Args: + query: The search query. + limit: Maximum number of results to return. + score_threshold: Minimum similarity score for results. + + Returns: + List of matching memory entries. + """ + results: list[Any] = self.storage.search( query=query, limit=limit, score_threshold=score_threshold ) + return results + + async def asearch( + self, + query: str, + limit: int = 5, + score_threshold: float = 0.6, + ) -> list[Any]: + """Search memory for relevant entries asynchronously. + + Args: + query: The search query. + limit: Maximum number of results to return. + score_threshold: Minimum similarity score for results. + + Returns: + List of matching memory entries. + """ + results: list[Any] = await self.storage.asearch( + query=query, limit=limit, score_threshold=score_threshold + ) + return results def set_crew(self, crew: Any) -> Memory: + """Set the crew for this memory instance.""" self.crew = crew return self diff --git a/lib/crewai/src/crewai/memory/short_term/short_term_memory.py b/lib/crewai/src/crewai/memory/short_term/short_term_memory.py index 5bc9ec604..c1663b4f5 100644 --- a/lib/crewai/src/crewai/memory/short_term/short_term_memory.py +++ b/lib/crewai/src/crewai/memory/short_term/short_term_memory.py @@ -30,7 +30,13 @@ class ShortTermMemory(Memory): _memory_provider: str | None = PrivateAttr() - def __init__(self, crew=None, embedder_config=None, storage=None, path=None): + def __init__( + self, + crew: Any = None, + embedder_config: Any = None, + storage: Any = None, + path: str | None = None, + ) -> None: memory_provider = None if embedder_config and isinstance(embedder_config, dict): memory_provider = embedder_config.get("provider") @@ -47,7 +53,7 @@ class ShortTermMemory(Memory): if embedder_config and isinstance(embedder_config, dict) else None ) - storage = Mem0Storage(type="short_term", crew=crew, config=config) + storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call] else: storage = ( storage @@ -123,7 +129,17 @@ class ShortTermMemory(Memory): query: str, limit: int = 5, score_threshold: float = 0.6, - ): + ) -> list[Any]: + """Search short-term memory for relevant entries. + + Args: + query: The search query. + limit: Maximum number of results to return. + score_threshold: Minimum similarity score for results. + + Returns: + List of matching memory entries. + """ crewai_event_bus.emit( self, event=MemoryQueryStartedEvent( @@ -140,7 +156,7 @@ class ShortTermMemory(Memory): try: results = self.storage.search( query=query, limit=limit, score_threshold=score_threshold - ) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters + ) crewai_event_bus.emit( self, @@ -156,7 +172,130 @@ class ShortTermMemory(Memory): ), ) - return results + return list(results) + except Exception as e: + crewai_event_bus.emit( + self, + event=MemoryQueryFailedEvent( + query=query, + limit=limit, + score_threshold=score_threshold, + error=str(e), + source_type="short_term_memory", + ), + ) + raise + + async def asave( + self, + value: Any, + metadata: dict[str, Any] | None = None, + ) -> None: + """Save a value to short-term memory asynchronously. + + Args: + value: The value to save. + metadata: Optional metadata to associate with the value. + """ + crewai_event_bus.emit( + self, + event=MemorySaveStartedEvent( + value=value, + metadata=metadata, + source_type="short_term_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + + start_time = time.time() + try: + item = ShortTermMemoryItem( + data=value, + metadata=metadata, + agent=self.agent.role if self.agent else None, + ) + if self._memory_provider == "mem0": + item.data = ( + f"Remember the following insights from Agent run: {item.data}" + ) + + await super().asave(value=item.data, metadata=item.metadata) + + crewai_event_bus.emit( + self, + event=MemorySaveCompletedEvent( + value=value, + metadata=metadata, + save_time_ms=(time.time() - start_time) * 1000, + source_type="short_term_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + except Exception as e: + crewai_event_bus.emit( + self, + event=MemorySaveFailedEvent( + value=value, + metadata=metadata, + error=str(e), + source_type="short_term_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + raise + + async def asearch( + self, + query: str, + limit: int = 5, + score_threshold: float = 0.6, + ) -> list[Any]: + """Search short-term memory asynchronously. + + Args: + query: The search query. + limit: Maximum number of results to return. + score_threshold: Minimum similarity score for results. + + Returns: + List of matching memory entries. + """ + crewai_event_bus.emit( + self, + event=MemoryQueryStartedEvent( + query=query, + limit=limit, + score_threshold=score_threshold, + source_type="short_term_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + + start_time = time.time() + try: + results = await self.storage.asearch( + query=query, limit=limit, score_threshold=score_threshold + ) + + crewai_event_bus.emit( + self, + event=MemoryQueryCompletedEvent( + query=query, + results=results, + limit=limit, + score_threshold=score_threshold, + query_time_ms=(time.time() - start_time) * 1000, + source_type="short_term_memory", + from_agent=self.agent, + from_task=self.task, + ), + ) + + return list(results) except Exception as e: crewai_event_bus.emit( self, diff --git a/lib/crewai/src/crewai/memory/storage/ltm_sqlite_storage.py b/lib/crewai/src/crewai/memory/storage/ltm_sqlite_storage.py index 99895db38..bf4f6c738 100644 --- a/lib/crewai/src/crewai/memory/storage/ltm_sqlite_storage.py +++ b/lib/crewai/src/crewai/memory/storage/ltm_sqlite_storage.py @@ -3,29 +3,30 @@ from pathlib import Path import sqlite3 from typing import Any +import aiosqlite + from crewai.utilities import Printer from crewai.utilities.paths import db_storage_path class LTMSQLiteStorage: - """ - An updated SQLite storage class for LTM data storage. - """ + """SQLite storage class for long-term memory data.""" def __init__(self, db_path: str | None = None) -> None: + """Initialize the SQLite storage. + + Args: + db_path: Optional path to the database file. + """ if db_path is None: - # Get the parent directory of the default db path and create our db file there db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db") self.db_path = db_path self._printer: Printer = Printer() - # Ensure parent directory exists Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) self._initialize_db() - def _initialize_db(self): - """ - Initializes the SQLite database and creates LTM table - """ + def _initialize_db(self) -> None: + """Initialize the SQLite database and create LTM table.""" try: with sqlite3.connect(self.db_path) as conn: cursor = conn.cursor() @@ -106,9 +107,7 @@ class LTMSQLiteStorage: ) return None - def reset( - self, - ) -> None: + def reset(self) -> None: """Resets the LTM table with error handling.""" try: with sqlite3.connect(self.db_path) as conn: @@ -121,4 +120,87 @@ class LTMSQLiteStorage: content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}", color="red", ) - return + + async def asave( + self, + task_description: str, + metadata: dict[str, Any], + datetime: str, + score: int | float, + ) -> None: + """Save data to the LTM table asynchronously. + + Args: + task_description: Description of the task. + metadata: Metadata associated with the memory. + datetime: Timestamp of the memory. + score: Quality score of the memory. + """ + try: + async with aiosqlite.connect(self.db_path) as conn: + await conn.execute( + """ + INSERT INTO long_term_memories (task_description, metadata, datetime, score) + VALUES (?, ?, ?, ?) + """, + (task_description, json.dumps(metadata), datetime, score), + ) + await conn.commit() + except aiosqlite.Error as e: + self._printer.print( + content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}", + color="red", + ) + + async def aload( + self, task_description: str, latest_n: int + ) -> list[dict[str, Any]] | None: + """Query the LTM table by task description asynchronously. + + Args: + task_description: Description of the task to search for. + latest_n: Maximum number of results to return. + + Returns: + List of matching memory entries or None if error occurs. + """ + try: + async with aiosqlite.connect(self.db_path) as conn: + cursor = await conn.execute( + f""" + SELECT metadata, datetime, score + FROM long_term_memories + WHERE task_description = ? + ORDER BY datetime DESC, score ASC + LIMIT {latest_n} + """, # nosec # noqa: S608 + (task_description,), + ) + rows = await cursor.fetchall() + if rows: + return [ + { + "metadata": json.loads(row[0]), + "datetime": row[1], + "score": row[2], + } + for row in rows + ] + except aiosqlite.Error as e: + self._printer.print( + content=f"MEMORY ERROR: An error occurred while querying LTM: {e}", + color="red", + ) + return None + + async def areset(self) -> None: + """Reset the LTM table asynchronously.""" + try: + async with aiosqlite.connect(self.db_path) as conn: + await conn.execute("DELETE FROM long_term_memories") + await conn.commit() + except aiosqlite.Error as e: + self._printer.print( + content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}", + color="red", + ) diff --git a/lib/crewai/src/crewai/memory/storage/rag_storage.py b/lib/crewai/src/crewai/memory/storage/rag_storage.py index 2dabc9bca..b45cde55a 100644 --- a/lib/crewai/src/crewai/memory/storage/rag_storage.py +++ b/lib/crewai/src/crewai/memory/storage/rag_storage.py @@ -129,6 +129,12 @@ class RAGStorage(BaseRAGStorage): return f"{base_path}/{file_name}" def save(self, value: Any, metadata: dict[str, Any]) -> None: + """Save a value to storage. + + Args: + value: The value to save. + metadata: Metadata to associate with the value. + """ try: client = self._get_client() collection_name = ( @@ -167,6 +173,51 @@ class RAGStorage(BaseRAGStorage): f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}" ) + async def asave(self, value: Any, metadata: dict[str, Any]) -> None: + """Save a value to storage asynchronously. + + Args: + value: The value to save. + metadata: Metadata to associate with the value. + """ + try: + client = self._get_client() + collection_name = ( + f"memory_{self.type}_{self.agents}" + if self.agents + else f"memory_{self.type}" + ) + await client.aget_or_create_collection(collection_name=collection_name) + + document: BaseRecord = {"content": value} + if metadata: + document["metadata"] = metadata + + batch_size = None + if ( + self.embedder_config + and isinstance(self.embedder_config, dict) + and "config" in self.embedder_config + ): + nested_config = self.embedder_config["config"] + if isinstance(nested_config, dict): + batch_size = nested_config.get("batch_size") + + if batch_size is not None: + await client.aadd_documents( + collection_name=collection_name, + documents=[document], + batch_size=cast(int, batch_size), + ) + else: + await client.aadd_documents( + collection_name=collection_name, documents=[document] + ) + except Exception as e: + logging.error( + f"Error during {self.type} async save: {e!s}\n{traceback.format_exc()}" + ) + def search( self, query: str, @@ -174,6 +225,17 @@ class RAGStorage(BaseRAGStorage): filter: dict[str, Any] | None = None, score_threshold: float = 0.6, ) -> list[Any]: + """Search for matching entries in storage. + + Args: + query: The search query. + limit: Maximum number of results to return. + filter: Optional metadata filter. + score_threshold: Minimum similarity score for results. + + Returns: + List of matching entries. + """ try: client = self._get_client() collection_name = ( @@ -194,6 +256,44 @@ class RAGStorage(BaseRAGStorage): ) return [] + async def asearch( + self, + query: str, + limit: int = 5, + filter: dict[str, Any] | None = None, + score_threshold: float = 0.6, + ) -> list[Any]: + """Search for matching entries in storage asynchronously. + + Args: + query: The search query. + limit: Maximum number of results to return. + filter: Optional metadata filter. + score_threshold: Minimum similarity score for results. + + Returns: + List of matching entries. + """ + try: + client = self._get_client() + collection_name = ( + f"memory_{self.type}_{self.agents}" + if self.agents + else f"memory_{self.type}" + ) + return await client.asearch( + collection_name=collection_name, + query=query, + limit=limit, + metadata_filter=filter, + score_threshold=score_threshold, + ) + except Exception as e: + logging.error( + f"Error during {self.type} async search: {e!s}\n{traceback.format_exc()}" + ) + return [] + def reset(self) -> None: try: client = self._get_client() diff --git a/lib/crewai/tests/memory/test_async_memory.py b/lib/crewai/tests/memory/test_async_memory.py new file mode 100644 index 000000000..15c4c33eb --- /dev/null +++ b/lib/crewai/tests/memory/test_async_memory.py @@ -0,0 +1,496 @@ +"""Tests for async memory operations.""" + +import threading +from collections import defaultdict +from unittest.mock import ANY, AsyncMock, MagicMock, patch + +import pytest + +from crewai.agent import Agent +from crewai.crew import Crew +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.memory_events import ( + MemoryQueryCompletedEvent, + MemoryQueryStartedEvent, + MemorySaveCompletedEvent, + MemorySaveStartedEvent, +) +from crewai.memory.contextual.contextual_memory import ContextualMemory +from crewai.memory.entity.entity_memory import EntityMemory +from crewai.memory.entity.entity_memory_item import EntityMemoryItem +from crewai.memory.external.external_memory import ExternalMemory +from crewai.memory.long_term.long_term_memory import LongTermMemory +from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem +from crewai.memory.short_term.short_term_memory import ShortTermMemory +from crewai.task import Task + + +@pytest.fixture +def mock_agent(): + """Fixture to create a mock agent.""" + return Agent( + role="Researcher", + goal="Search relevant data and provide results", + backstory="You are a researcher at a leading tech think tank.", + tools=[], + verbose=True, + ) + + +@pytest.fixture +def mock_task(mock_agent): + """Fixture to create a mock task.""" + return Task( + description="Perform a search on specific topics.", + expected_output="A list of relevant URLs based on the search query.", + agent=mock_agent, + ) + + +@pytest.fixture +def short_term_memory(mock_agent, mock_task): + """Fixture to create a ShortTermMemory instance.""" + return ShortTermMemory(crew=Crew(agents=[mock_agent], tasks=[mock_task])) + + +@pytest.fixture +def long_term_memory(tmp_path): + """Fixture to create a LongTermMemory instance.""" + db_path = str(tmp_path / "test_ltm.db") + return LongTermMemory(path=db_path) + + +@pytest.fixture +def entity_memory(tmp_path, mock_agent, mock_task): + """Fixture to create an EntityMemory instance.""" + return EntityMemory( + crew=Crew(agents=[mock_agent], tasks=[mock_task]), + path=str(tmp_path / "test_entities"), + ) + + +class TestAsyncShortTermMemory: + """Tests for async ShortTermMemory operations.""" + + @pytest.mark.asyncio + async def test_asave_emits_events(self, short_term_memory): + """Test that asave emits the correct events.""" + events: dict[str, list] = defaultdict(list) + condition = threading.Condition() + + @crewai_event_bus.on(MemorySaveStartedEvent) + def on_save_started(source, event): + with condition: + events["MemorySaveStartedEvent"].append(event) + condition.notify() + + @crewai_event_bus.on(MemorySaveCompletedEvent) + def on_save_completed(source, event): + with condition: + events["MemorySaveCompletedEvent"].append(event) + condition.notify() + + await short_term_memory.asave( + value="async test value", + metadata={"task": "async_test_task"}, + ) + + with condition: + success = condition.wait_for( + lambda: len(events["MemorySaveStartedEvent"]) >= 1 + and len(events["MemorySaveCompletedEvent"]) >= 1, + timeout=5, + ) + assert success, "Timeout waiting for async save events" + + assert len(events["MemorySaveStartedEvent"]) >= 1 + assert len(events["MemorySaveCompletedEvent"]) >= 1 + assert events["MemorySaveStartedEvent"][-1].value == "async test value" + assert events["MemorySaveStartedEvent"][-1].source_type == "short_term_memory" + + @pytest.mark.asyncio + async def test_asearch_emits_events(self, short_term_memory): + """Test that asearch emits the correct events.""" + events: dict[str, list] = defaultdict(list) + search_started = threading.Event() + search_completed = threading.Event() + + with patch.object(short_term_memory.storage, "asearch", new_callable=AsyncMock, return_value=[]): + + @crewai_event_bus.on(MemoryQueryStartedEvent) + def on_search_started(source, event): + events["MemoryQueryStartedEvent"].append(event) + search_started.set() + + @crewai_event_bus.on(MemoryQueryCompletedEvent) + def on_search_completed(source, event): + events["MemoryQueryCompletedEvent"].append(event) + search_completed.set() + + await short_term_memory.asearch( + query="async test query", + limit=3, + score_threshold=0.35, + ) + + assert search_started.wait(timeout=2), "Timeout waiting for search started event" + assert search_completed.wait(timeout=2), "Timeout waiting for search completed event" + + assert len(events["MemoryQueryStartedEvent"]) >= 1 + assert len(events["MemoryQueryCompletedEvent"]) >= 1 + assert events["MemoryQueryStartedEvent"][-1].query == "async test query" + assert events["MemoryQueryStartedEvent"][-1].source_type == "short_term_memory" + + +class TestAsyncLongTermMemory: + """Tests for async LongTermMemory operations.""" + + @pytest.mark.asyncio + async def test_asave_emits_events(self, long_term_memory): + """Test that asave emits the correct events.""" + events: dict[str, list] = defaultdict(list) + condition = threading.Condition() + + @crewai_event_bus.on(MemorySaveStartedEvent) + def on_save_started(source, event): + with condition: + events["MemorySaveStartedEvent"].append(event) + condition.notify() + + @crewai_event_bus.on(MemorySaveCompletedEvent) + def on_save_completed(source, event): + with condition: + events["MemorySaveCompletedEvent"].append(event) + condition.notify() + + item = LongTermMemoryItem( + task="async test task", + agent="test_agent", + expected_output="test output", + datetime="2024-01-01T00:00:00", + quality=0.9, + metadata={"task": "async test task", "quality": 0.9}, + ) + + await long_term_memory.asave(item) + + with condition: + success = condition.wait_for( + lambda: len(events["MemorySaveStartedEvent"]) >= 1 + and len(events["MemorySaveCompletedEvent"]) >= 1, + timeout=5, + ) + assert success, "Timeout waiting for async save events" + + assert len(events["MemorySaveStartedEvent"]) >= 1 + assert len(events["MemorySaveCompletedEvent"]) >= 1 + assert events["MemorySaveStartedEvent"][-1].source_type == "long_term_memory" + + @pytest.mark.asyncio + async def test_asearch_emits_events(self, long_term_memory): + """Test that asearch emits the correct events.""" + events: dict[str, list] = defaultdict(list) + search_started = threading.Event() + search_completed = threading.Event() + + @crewai_event_bus.on(MemoryQueryStartedEvent) + def on_search_started(source, event): + events["MemoryQueryStartedEvent"].append(event) + search_started.set() + + @crewai_event_bus.on(MemoryQueryCompletedEvent) + def on_search_completed(source, event): + events["MemoryQueryCompletedEvent"].append(event) + search_completed.set() + + await long_term_memory.asearch(task="async test task", latest_n=3) + + assert search_started.wait(timeout=2), "Timeout waiting for search started event" + assert search_completed.wait(timeout=2), "Timeout waiting for search completed event" + + assert len(events["MemoryQueryStartedEvent"]) >= 1 + assert len(events["MemoryQueryCompletedEvent"]) >= 1 + assert events["MemoryQueryStartedEvent"][-1].source_type == "long_term_memory" + + @pytest.mark.asyncio + async def test_asave_and_asearch_integration(self, long_term_memory): + """Test that asave followed by asearch works correctly.""" + item = LongTermMemoryItem( + task="integration test task", + agent="test_agent", + expected_output="test output", + datetime="2024-01-01T00:00:00", + quality=0.9, + metadata={"task": "integration test task", "quality": 0.9}, + ) + + await long_term_memory.asave(item) + results = await long_term_memory.asearch(task="integration test task", latest_n=1) + + assert results is not None + assert len(results) == 1 + assert results[0]["metadata"]["agent"] == "test_agent" + + +class TestAsyncEntityMemory: + """Tests for async EntityMemory operations.""" + + @pytest.mark.asyncio + async def test_asave_single_item_emits_events(self, entity_memory): + """Test that asave with a single item emits the correct events.""" + events: dict[str, list] = defaultdict(list) + condition = threading.Condition() + + @crewai_event_bus.on(MemorySaveStartedEvent) + def on_save_started(source, event): + with condition: + events["MemorySaveStartedEvent"].append(event) + condition.notify() + + @crewai_event_bus.on(MemorySaveCompletedEvent) + def on_save_completed(source, event): + with condition: + events["MemorySaveCompletedEvent"].append(event) + condition.notify() + + item = EntityMemoryItem( + name="TestEntity", + type="Person", + description="A test entity for async operations", + relationships="Related to other test entities", + ) + + await entity_memory.asave(item) + + with condition: + success = condition.wait_for( + lambda: len(events["MemorySaveStartedEvent"]) >= 1 + and len(events["MemorySaveCompletedEvent"]) >= 1, + timeout=5, + ) + assert success, "Timeout waiting for async save events" + + assert len(events["MemorySaveStartedEvent"]) >= 1 + assert len(events["MemorySaveCompletedEvent"]) >= 1 + assert events["MemorySaveStartedEvent"][-1].source_type == "entity_memory" + + @pytest.mark.asyncio + async def test_asearch_emits_events(self, entity_memory): + """Test that asearch emits the correct events.""" + events: dict[str, list] = defaultdict(list) + search_started = threading.Event() + search_completed = threading.Event() + + @crewai_event_bus.on(MemoryQueryStartedEvent) + def on_search_started(source, event): + events["MemoryQueryStartedEvent"].append(event) + search_started.set() + + @crewai_event_bus.on(MemoryQueryCompletedEvent) + def on_search_completed(source, event): + events["MemoryQueryCompletedEvent"].append(event) + search_completed.set() + + await entity_memory.asearch(query="TestEntity", limit=5, score_threshold=0.6) + + assert search_started.wait(timeout=2), "Timeout waiting for search started event" + assert search_completed.wait(timeout=2), "Timeout waiting for search completed event" + + assert len(events["MemoryQueryStartedEvent"]) >= 1 + assert len(events["MemoryQueryCompletedEvent"]) >= 1 + assert events["MemoryQueryStartedEvent"][-1].source_type == "entity_memory" + + +class TestAsyncContextualMemory: + """Tests for async ContextualMemory operations.""" + + @pytest.mark.asyncio + async def test_abuild_context_for_task_with_empty_query(self, mock_task): + """Test that abuild_context_for_task returns empty string for empty query.""" + mock_task.description = "" + contextual_memory = ContextualMemory( + stm=None, + ltm=None, + em=None, + exm=None, + ) + + result = await contextual_memory.abuild_context_for_task(mock_task, "") + assert result == "" + + @pytest.mark.asyncio + async def test_abuild_context_for_task_with_none_memories(self, mock_task): + """Test that abuild_context_for_task handles None memory sources.""" + contextual_memory = ContextualMemory( + stm=None, + ltm=None, + em=None, + exm=None, + ) + + result = await contextual_memory.abuild_context_for_task(mock_task, "some context") + assert result == "" + + @pytest.mark.asyncio + async def test_abuild_context_for_task_aggregates_results(self, mock_agent, mock_task): + """Test that abuild_context_for_task aggregates results from all memory sources.""" + mock_stm = MagicMock(spec=ShortTermMemory) + mock_stm.asearch = AsyncMock(return_value=[{"content": "STM insight"}]) + + mock_ltm = MagicMock(spec=LongTermMemory) + mock_ltm.asearch = AsyncMock( + return_value=[{"metadata": {"suggestions": ["LTM suggestion"]}}] + ) + + mock_em = MagicMock(spec=EntityMemory) + mock_em.asearch = AsyncMock(return_value=[{"content": "Entity info"}]) + + mock_exm = MagicMock(spec=ExternalMemory) + mock_exm.asearch = AsyncMock(return_value=[{"content": "External memory"}]) + + contextual_memory = ContextualMemory( + stm=mock_stm, + ltm=mock_ltm, + em=mock_em, + exm=mock_exm, + agent=mock_agent, + task=mock_task, + ) + + result = await contextual_memory.abuild_context_for_task(mock_task, "additional context") + + assert "Recent Insights:" in result + assert "STM insight" in result + assert "Historical Data:" in result + assert "LTM suggestion" in result + assert "Entities:" in result + assert "Entity info" in result + assert "External memories:" in result + assert "External memory" in result + + @pytest.mark.asyncio + async def test_afetch_stm_context_returns_formatted_results(self, mock_agent, mock_task): + """Test that _afetch_stm_context returns properly formatted results.""" + mock_stm = MagicMock(spec=ShortTermMemory) + mock_stm.asearch = AsyncMock( + return_value=[ + {"content": "First insight"}, + {"content": "Second insight"}, + ] + ) + + contextual_memory = ContextualMemory( + stm=mock_stm, + ltm=None, + em=None, + exm=None, + ) + + result = await contextual_memory._afetch_stm_context("test query") + + assert "Recent Insights:" in result + assert "- First insight" in result + assert "- Second insight" in result + + @pytest.mark.asyncio + async def test_afetch_ltm_context_returns_formatted_results(self, mock_agent, mock_task): + """Test that _afetch_ltm_context returns properly formatted results.""" + mock_ltm = MagicMock(spec=LongTermMemory) + mock_ltm.asearch = AsyncMock( + return_value=[ + {"metadata": {"suggestions": ["Suggestion 1", "Suggestion 2"]}}, + ] + ) + + contextual_memory = ContextualMemory( + stm=None, + ltm=mock_ltm, + em=None, + exm=None, + ) + + result = await contextual_memory._afetch_ltm_context("test task") + + assert "Historical Data:" in result + assert "- Suggestion 1" in result + assert "- Suggestion 2" in result + + @pytest.mark.asyncio + async def test_afetch_entity_context_returns_formatted_results(self, mock_agent, mock_task): + """Test that _afetch_entity_context returns properly formatted results.""" + mock_em = MagicMock(spec=EntityMemory) + mock_em.asearch = AsyncMock( + return_value=[ + {"content": "Entity A details"}, + {"content": "Entity B details"}, + ] + ) + + contextual_memory = ContextualMemory( + stm=None, + ltm=None, + em=mock_em, + exm=None, + ) + + result = await contextual_memory._afetch_entity_context("test query") + + assert "Entities:" in result + assert "- Entity A details" in result + assert "- Entity B details" in result + + @pytest.mark.asyncio + async def test_afetch_external_context_returns_formatted_results(self): + """Test that _afetch_external_context returns properly formatted results.""" + mock_exm = MagicMock(spec=ExternalMemory) + mock_exm.asearch = AsyncMock( + return_value=[ + {"content": "External data 1"}, + {"content": "External data 2"}, + ] + ) + + contextual_memory = ContextualMemory( + stm=None, + ltm=None, + em=None, + exm=mock_exm, + ) + + result = await contextual_memory._afetch_external_context("test query") + + assert "External memories:" in result + assert "- External data 1" in result + assert "- External data 2" in result + + @pytest.mark.asyncio + async def test_afetch_methods_return_empty_for_empty_results(self): + """Test that async fetch methods return empty string for no results.""" + mock_stm = MagicMock(spec=ShortTermMemory) + mock_stm.asearch = AsyncMock(return_value=[]) + + mock_ltm = MagicMock(spec=LongTermMemory) + mock_ltm.asearch = AsyncMock(return_value=[]) + + mock_em = MagicMock(spec=EntityMemory) + mock_em.asearch = AsyncMock(return_value=[]) + + mock_exm = MagicMock(spec=ExternalMemory) + mock_exm.asearch = AsyncMock(return_value=[]) + + contextual_memory = ContextualMemory( + stm=mock_stm, + ltm=mock_ltm, + em=mock_em, + exm=mock_exm, + ) + + stm_result = await contextual_memory._afetch_stm_context("query") + ltm_result = await contextual_memory._afetch_ltm_context("task") + em_result = await contextual_memory._afetch_entity_context("query") + exm_result = await contextual_memory._afetch_external_context("query") + + assert stm_result == "" + assert ltm_result is None + assert em_result == "" + assert exm_result == "" \ No newline at end of file