diff --git a/src/crewai/memory/external/external_memory.py b/src/crewai/memory/external/external_memory.py index a76855d2a..2e9bc5070 100644 --- a/src/crewai/memory/external/external_memory.py +++ b/src/crewai/memory/external/external_memory.py @@ -2,17 +2,12 @@ from typing import Any, Dict, Optional, Self from crewai.memory.external.external_memory_item import ExternalMemoryItem from crewai.memory.memory import Memory +from crewai.memory.storage.interface import Storage class ExternalMemory(Memory): - def __init__(self, crew=None, embedder_config=None, storage=None): - storage = ( - storage - if storage - else self.create_storage(crew=crew, embedder_config=embedder_config) - ) - - super().__init__(storage=storage, embedder_config=embedder_config, crew=crew) + def __init__(self, storage: Optional[Storage] = None, **data: Any): + super().__init__(storage=storage, **data) @staticmethod def _configure_mem0(crew, config) -> "Mem0Storage": @@ -56,6 +51,8 @@ class ExternalMemory(Memory): def set_crew(self, crew: Any) -> Self: super().set_crew(crew) - self.storage = self.create_storage(crew, self.embedder_config) + + if not self.storage: + self.storage = self.create_storage(crew, self.embedder_config) return self diff --git a/tests/memory/external/test_external_memory.py b/tests/memory/external/test_external_memory.py index 3ee706335..8cbc34ef9 100644 --- a/tests/memory/external/test_external_memory.py +++ b/tests/memory/external/test_external_memory.py @@ -7,6 +7,7 @@ from crewai.agent import Agent from crewai.crew import Crew, Process from crewai.memory.external.external_memory import ExternalMemory from crewai.memory.external.external_memory_item import ExternalMemoryItem +from crewai.memory.storage.interface import Storage from crewai.task import Task @@ -142,3 +143,38 @@ def test_crew_external_memory_save(mem_method, crew_with_external_memory): ) as mock_method: crew_with_external_memory.kickoff() assert mock_method.call_count > 0 + + +def test_external_memory_custom_storage(crew_with_external_memory): + class CustomStorage(Storage): + def __init__(self): + self.memories = [] + + def save(self, value, metadata=None, agent=None): + self.memories.append({"value": value, "metadata": metadata, "agent": agent}) + + def search(self, query, limit=10, score_threshold=0.5): + return self.memories + + def reset(self): + self.memories = [] + + custom_storage = CustomStorage() + external_memory = ExternalMemory(storage=custom_storage) + + # by ensuring the crew is set, we can test that the storage is used + external_memory.set_crew(crew_with_external_memory) + + test_value = "test value" + test_metadata = {"source": "test"} + test_agent = "test_agent" + external_memory.save(value=test_value, metadata=test_metadata, agent=test_agent) + + results = external_memory.search("test") + assert len(results) == 1 + assert results[0]["value"] == test_value + assert results[0]["metadata"] == test_metadata | {"agent": test_agent} + + external_memory.reset() + results = external_memory.search("test") + assert len(results) == 0