From 0354ad378be60171b25db22cddafdfb79b6cd5e9 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Wed, 9 Oct 2024 09:52:56 -0400 Subject: [PATCH] Add addtiional validation and tests. Also, if memory was set to True on a Crew, you were required to have a mem0 key. Fixed this direct dependency. --- src/crewai/agent.py | 17 +- src/crewai/crew.py | 25 ++- .../memory/contextual/contextual_memory.py | 4 +- .../memory/long_term/long_term_memory.py | 21 ++- src/crewai/memory/storage/mem0_storage.py | 7 +- src/crewai/memory/storage/rag_storage.py | 5 +- src/crewai/memory/user/user_memory.py | 10 +- tests/crew_test.py | 54 +++++++ tests/memory/contextual_memory_test.py | 147 ++++++++++++++++++ tests/memory/entity_memory_test.py | 119 ++++++++++++++ tests/memory/long_term_memory_test.py | 124 +++++++++++++-- 11 files changed, 490 insertions(+), 43 deletions(-) create mode 100644 tests/memory/contextual_memory_test.py create mode 100644 tests/memory/entity_memory_test.py diff --git a/src/crewai/agent.py b/src/crewai/agent.py index d15cb1629..89e92498f 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -1,18 +1,19 @@ import os from inspect import signature from typing import Any, List, Optional, Union + from pydantic import Field, InstanceOf, PrivateAttr, model_validator from crewai.agents import CacheHandler -from crewai.utilities import Converter, Prompts -from crewai.tools.agent_tools import AgentTools -from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.agents.agent_builder.base_agent import BaseAgent -from crewai.memory.contextual.contextual_memory import ContextualMemory -from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE -from crewai.utilities.training_handler import CrewTrainingHandler -from crewai.utilities.token_counter_callback import TokenCalcHandler +from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.llm import LLM +from crewai.memory.contextual.contextual_memory import ContextualMemory +from crewai.tools.agent_tools import AgentTools +from crewai.utilities import Converter, Prompts +from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE +from crewai.utilities.token_counter_callback import TokenCalcHandler +from crewai.utilities.training_handler import CrewTrainingHandler def mock_agent_ops_provider(): @@ -207,11 +208,11 @@ class Agent(BaseAgent): if self.crew and self.crew.memory: contextual_memory = ContextualMemory( - self.crew.memory_provider, self.crew._short_term_memory, self.crew._long_term_memory, self.crew._entity_memory, self.crew._user_memory, + self.crew.memory_provider, ) memory = contextual_memory.build_context_for_task(task, context) if memory.strip() != "": diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 2724f909d..105207b91 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -210,6 +210,14 @@ class Crew(BaseModel): # TODO: Improve typing return json.loads(v) if isinstance(v, Json) else v # type: ignore + @field_validator("memory_provider", mode="before") + @classmethod + def validate_memory_provider(cls, v: Optional[str]) -> Optional[str]: + """Ensure memory provider is either None or 'mem0'.""" + if v not in (None, "mem0"): + raise ValueError("Memory provider must be either None or 'mem0'.") + return v + @model_validator(mode="after") def set_private_attrs(self) -> "Crew": """Set private attributes.""" @@ -247,12 +255,18 @@ class Crew(BaseModel): embedder_config=self.embedder, ) ) - self._entity_memory = EntityMemory( - memory_provider=self.memory_provider, - crew=self, - embedder_config=self.embedder, + self._entity_memory = ( + self.entity_memory + if self.entity_memory + else EntityMemory( + memory_provider=self.memory_provider, + crew=self, + embedder_config=self.embedder, + ) + ) + self._user_memory = ( + UserMemory(crew=self) if self.memory_provider == "mem0" else None ) - self._user_memory = UserMemory(crew=self) return self @model_validator(mode="after") @@ -905,6 +919,7 @@ class Crew(BaseModel): "_short_term_memory", "_long_term_memory", "_entity_memory", + "_user_memory", "_telemetry", "agents", "tasks", diff --git a/src/crewai/memory/contextual/contextual_memory.py b/src/crewai/memory/contextual/contextual_memory.py index c359bbe97..d43938780 100644 --- a/src/crewai/memory/contextual/contextual_memory.py +++ b/src/crewai/memory/contextual/contextual_memory.py @@ -6,17 +6,17 @@ from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory, UserMem class ContextualMemory: def __init__( self, - memory_provider: str, stm: ShortTermMemory, ltm: LongTermMemory, em: EntityMemory, um: UserMemory, + memory_provider: Optional[str] = None, # Default value added ): - self.memory_provider = memory_provider self.stm = stm self.ltm = ltm self.em = em self.um = um + self.memory_provider = memory_provider def build_context_for_task(self, task, context) -> str: """ diff --git a/src/crewai/memory/long_term/long_term_memory.py b/src/crewai/memory/long_term/long_term_memory.py index ab225e406..0c77d50b3 100644 --- a/src/crewai/memory/long_term/long_term_memory.py +++ b/src/crewai/memory/long_term/long_term_memory.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, List from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem from crewai.memory.memory import Memory @@ -18,18 +18,25 @@ class LongTermMemory(Memory): storage = storage if storage else LTMSQLiteStorage() super().__init__(storage) - def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" - metadata = item.metadata - metadata.update({"agent": item.agent, "expected_output": item.expected_output}) + def save(self, item: LongTermMemoryItem) -> None: + metadata = item.metadata.copy() # Create a copy to avoid modifying the original + metadata.update( + { + "agent": item.agent, + "expected_output": item.expected_output, + "quality": item.quality, # Add quality to metadata + } + ) self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage" task_description=item.task, - score=metadata["quality"], + score=item.quality, metadata=metadata, datetime=item.datetime, ) - def search(self, task: str, latest_n: int = 3) -> Dict[str, Any]: - return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load" + def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: + results = self.storage.load(task, latest_n) + return results def reset(self) -> None: self.storage.reset() diff --git a/src/crewai/memory/storage/mem0_storage.py b/src/crewai/memory/storage/mem0_storage.py index bc7aca892..a64da3733 100644 --- a/src/crewai/memory/storage/mem0_storage.py +++ b/src/crewai/memory/storage/mem0_storage.py @@ -1,8 +1,8 @@ import os from typing import Any, Dict, List, Optional -from mem0 import MemoryClient from crewai.memory.storage.interface import Storage +from mem0 import MemoryClient class Mem0Storage(Storage): @@ -18,6 +18,9 @@ class Mem0Storage(Storage): ): os.environ["OPENAI_API_KEY"] = "fake" + if not os.getenv("MEM0_API_KEY"): + raise EnvironmentError("MEM0_API_KEY is not set.") + agents = crew.agents if crew else [] agents = [agent.role for agent in agents] agents = "_".join(agents) @@ -39,4 +42,4 @@ class Mem0Storage(Storage): if filters: params["filters"] = filters results = self.memory.search(**params) - return [r for r in results if r["score"] >= score_threshold] + return [r for r in results if float(r["score"]) >= score_threshold] diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index d6d31582d..23bd1a634 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -5,14 +5,13 @@ import os import shutil from typing import Any, Dict, List, Optional +from crewai.memory.storage.interface import Storage +from crewai.utilities.paths import db_storage_path from embedchain import App from embedchain.llm.base import BaseLlm from embedchain.models.data_type import DataType from embedchain.vectordb.chroma import InvalidDimensionException -from crewai.memory.storage.interface import Storage -from crewai.utilities.paths import db_storage_path - @contextlib.contextmanager def suppress_logging( diff --git a/src/crewai/memory/user/user_memory.py b/src/crewai/memory/user/user_memory.py index 60f4ddf06..9393c58b8 100644 --- a/src/crewai/memory/user/user_memory.py +++ b/src/crewai/memory/user/user_memory.py @@ -32,6 +32,12 @@ class UserMemory(Memory): filters: dict = {}, score_threshold: float = 0.35, ): - return super().search( - query=query, limit=limit, filters=filters, score_threshold=score_threshold + print("SEARCHING USER MEMORY", query, limit, filters, score_threshold) + result = super().search( + query=query, + limit=limit, + filters=filters, + score_threshold=score_threshold, ) + print("USER MEMORY SEARCH RESULT:", result) + return result diff --git a/tests/crew_test.py b/tests/crew_test.py index c01e84f80..43389d540 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -23,6 +23,7 @@ from crewai.types.usage_metrics import UsageMetrics from crewai.utilities import Logger from crewai.utilities.rpm_controller import RPMController from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandler +from pydantic_core import ValidationError ceo = Agent( role="CEO", @@ -173,6 +174,57 @@ def test_context_no_future_tasks(): Crew(tasks=[task1, task2, task3, task4], agents=[researcher, writer]) +def test_memory_provider_validation(): + # Create mock agents + agent1 = Agent( + role="Researcher", + goal="Conduct research on AI", + backstory="An experienced AI researcher", + allow_delegation=False, + ) + agent2 = Agent( + role="Writer", + goal="Write articles on AI", + backstory="A seasoned writer with a focus on technology", + allow_delegation=False, + ) + + # Create mock tasks + task1 = Task( + description="Research the latest trends in AI", + expected_output="A report on AI trends", + agent=agent1, + ) + task2 = Task( + description="Write an article based on the research", + expected_output="An article on AI trends", + agent=agent2, + ) + + # Test with valid memory provider values + try: + crew_with_none = Crew( + agents=[agent1, agent2], tasks=[task1, task2], memory_provider=None + ) + crew_with_mem0 = Crew( + agents=[agent1, agent2], tasks=[task1, task2], memory_provider="mem0" + ) + except ValidationError: + pytest.fail( + "Unexpected ValidationError raised for valid memory provider values" + ) + + # Test with an invalid memory provider value + with pytest.raises(ValidationError) as excinfo: + Crew( + agents=[agent1, agent2], + tasks=[task1, task2], + memory_provider="invalid_provider", + ) + + assert "Memory provider must be either None or 'mem0'." in str(excinfo.value) + + def test_crew_config_with_wrong_keys(): no_tasks_config = json.dumps( { @@ -497,6 +549,7 @@ def test_cache_hitting_between_agents(): @pytest.mark.vcr(filter_headers=["authorization"]) def test_api_calls_throttling(capsys): from unittest.mock import patch + from crewai_tools import tool @tool @@ -1105,6 +1158,7 @@ def test_dont_set_agents_step_callback_if_already_set(): @pytest.mark.vcr(filter_headers=["authorization"]) def test_crew_function_calling_llm(): from unittest.mock import patch + from crewai_tools import tool llm = "gpt-4o" diff --git a/tests/memory/contextual_memory_test.py b/tests/memory/contextual_memory_test.py new file mode 100644 index 000000000..6651a64fe --- /dev/null +++ b/tests/memory/contextual_memory_test.py @@ -0,0 +1,147 @@ +from unittest.mock import MagicMock, patch + +import pytest +from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory, UserMemory +from crewai.memory.contextual.contextual_memory import ContextualMemory + + +@pytest.fixture +def mock_memories(): + return { + "stm": MagicMock(spec=ShortTermMemory), + "ltm": MagicMock(spec=LongTermMemory), + "em": MagicMock(spec=EntityMemory), + "um": MagicMock(spec=UserMemory), + } + + +@pytest.fixture +def contextual_memory_mem0(mock_memories): + return ContextualMemory( + memory_provider="mem0", + stm=mock_memories["stm"], + ltm=mock_memories["ltm"], + em=mock_memories["em"], + um=mock_memories["um"], + ) + + +@pytest.fixture +def contextual_memory_other(mock_memories): + return ContextualMemory( + memory_provider="other", + stm=mock_memories["stm"], + ltm=mock_memories["ltm"], + em=mock_memories["em"], + um=mock_memories["um"], + ) + + +@pytest.fixture +def contextual_memory_none(mock_memories): + return ContextualMemory( + memory_provider=None, + stm=mock_memories["stm"], + ltm=mock_memories["ltm"], + em=mock_memories["em"], + um=mock_memories["um"], + ) + + +def test_build_context_for_task_mem0(contextual_memory_mem0, mock_memories): + task = MagicMock(description="Test task") + context = "Additional context" + + mock_memories["stm"].search.return_value = ["Recent insight"] + mock_memories["ltm"].search.return_value = [ + {"metadata": {"suggestions": ["Historical data"]}} + ] + mock_memories["em"].search.return_value = [{"memory": "Entity memory"}] + mock_memories["um"].search.return_value = [{"memory": "User memory"}] + + result = contextual_memory_mem0.build_context_for_task(task, context) + + assert "Recent Insights:" in result + assert "Historical Data:" in result + assert "Entities:" in result + assert "User memories/preferences:" in result + + +def test_build_context_for_task_other_provider(contextual_memory_other, mock_memories): + task = MagicMock(description="Test task") + context = "Additional context" + + mock_memories["stm"].search.return_value = ["Recent insight"] + mock_memories["ltm"].search.return_value = [ + {"metadata": {"suggestions": ["Historical data"]}} + ] + mock_memories["em"].search.return_value = [{"context": "Entity context"}] + mock_memories["um"].search.return_value = [{"memory": "User memory"}] + + result = contextual_memory_other.build_context_for_task(task, context) + + assert "Recent Insights:" in result + assert "Historical Data:" in result + assert "Entities:" in result + assert "User memories/preferences:" not in result + + +def test_build_context_for_task_none_provider(contextual_memory_none, mock_memories): + task = MagicMock(description="Test task") + context = "Additional context" + + mock_memories["stm"].search.return_value = ["Recent insight"] + mock_memories["ltm"].search.return_value = [ + {"metadata": {"suggestions": ["Historical data"]}} + ] + mock_memories["em"].search.return_value = [{"context": "Entity context"}] + mock_memories["um"].search.return_value = [{"memory": "User memory"}] + + result = contextual_memory_none.build_context_for_task(task, context) + + assert "Recent Insights:" in result + assert "Historical Data:" in result + assert "Entities:" in result + assert "User memories/preferences:" not in result + + +def test_fetch_entity_context_mem0(contextual_memory_mem0, mock_memories): + mock_memories["em"].search.return_value = [ + {"memory": "Entity 1"}, + {"memory": "Entity 2"}, + ] + result = contextual_memory_mem0._fetch_entity_context("query") + expected_result = "Entities:\n- Entity 1\n- Entity 2" + assert result == expected_result + + +def test_fetch_entity_context_other_provider(contextual_memory_other, mock_memories): + mock_memories["em"].search.return_value = [ + {"context": "Entity 1"}, + {"context": "Entity 2"}, + ] + result = contextual_memory_other._fetch_entity_context("query") + expected_result = "Entities:\n- Entity 1\n- Entity 2" + assert result == expected_result + + +def test_user_memories_only_for_mem0(contextual_memory_mem0, mock_memories): + mock_memories["um"].search.return_value = [{"memory": "User memory"}] + + # Test for mem0 provider + result_mem0 = contextual_memory_mem0._fetch_user_memories("query") + assert "User memories/preferences:" in result_mem0 + assert "User memory" in result_mem0 + + # Additional test to ensure user memories are included/excluded in the full context + task = MagicMock(description="Test task") + context = "Additional context" + mock_memories["stm"].search.return_value = ["Recent insight"] + mock_memories["ltm"].search.return_value = [ + {"metadata": {"suggestions": ["Historical data"]}} + ] + mock_memories["em"].search.return_value = [{"memory": "Entity memory"}] + + full_context_mem0 = contextual_memory_mem0.build_context_for_task(task, context) + assert "User memories/preferences:" in full_context_mem0 + assert "User memory" in full_context_mem0 diff --git a/tests/memory/entity_memory_test.py b/tests/memory/entity_memory_test.py new file mode 100644 index 000000000..0d45ac6ee --- /dev/null +++ b/tests/memory/entity_memory_test.py @@ -0,0 +1,119 @@ +# tests/memory/test_entity_memory.py + +from unittest.mock import MagicMock, patch + +import pytest +from crewai.memory.entity.entity_memory import EntityMemory +from crewai.memory.entity.entity_memory_item import EntityMemoryItem +from crewai.memory.storage.mem0_storage import Mem0Storage +from crewai.memory.storage.rag_storage import RAGStorage + + +@pytest.fixture +def mock_rag_storage(): + """Fixture to create a mock RAGStorage instance""" + return MagicMock(spec=RAGStorage) + + +@pytest.fixture +def mock_mem0_storage(): + """Fixture to create a mock Mem0Storage instance""" + return MagicMock(spec=Mem0Storage) + + +@pytest.fixture +def entity_memory_rag(mock_rag_storage): + """Fixture to create an EntityMemory instance with RAGStorage""" + with patch( + "crewai.memory.entity.entity_memory.RAGStorage", return_value=mock_rag_storage + ): + return EntityMemory() + + +@pytest.fixture +def entity_memory_mem0(mock_mem0_storage): + """Fixture to create an EntityMemory instance with Mem0Storage""" + with patch( + "crewai.memory.entity.entity_memory.Mem0Storage", return_value=mock_mem0_storage + ): + return EntityMemory(memory_provider="mem0") + + +def test_save_rag_storage(entity_memory_rag, mock_rag_storage): + item = EntityMemoryItem( + name="John Doe", + type="Person", + description="A software engineer", + relationships="Works at TechCorp", + ) + entity_memory_rag.save(item) + + expected_data = "John Doe(Person): A software engineer" + mock_rag_storage.save.assert_called_once_with(expected_data, item.metadata) + + +def test_save_mem0_storage(entity_memory_mem0, mock_mem0_storage): + item = EntityMemoryItem( + name="John Doe", + type="Person", + description="A software engineer", + relationships="Works at TechCorp", + ) + entity_memory_mem0.save(item) + + expected_data = """ + Remember details about the following entity: + Name: John Doe + Type: Person + Entity Description: A software engineer + """ + mock_mem0_storage.save.assert_called_once_with(expected_data, item.metadata) + + +def test_search(entity_memory_rag, mock_rag_storage): + query = "software engineer" + limit = 5 + filters = {"type": "Person"} + score_threshold = 0.7 + + entity_memory_rag.search(query, limit, filters, score_threshold) + + mock_rag_storage.search.assert_called_once_with( + query=query, limit=limit, filters=filters, score_threshold=score_threshold + ) + + +def test_reset(entity_memory_rag, mock_rag_storage): + entity_memory_rag.reset() + mock_rag_storage.reset.assert_called_once() + + +def test_reset_error(entity_memory_rag, mock_rag_storage): + mock_rag_storage.reset.side_effect = Exception("Reset error") + + with pytest.raises(Exception) as exc_info: + entity_memory_rag.reset() + + assert ( + str(exc_info.value) + == "An error occurred while resetting the entity memory: Reset error" + ) + + +@pytest.mark.parametrize("memory_provider", [None, "other"]) +def test_init_with_rag_storage(memory_provider): + with patch("crewai.memory.entity.entity_memory.RAGStorage") as mock_rag_storage: + EntityMemory(memory_provider=memory_provider) + mock_rag_storage.assert_called_once() + + +def test_init_with_mem0_storage(): + with patch("crewai.memory.entity.entity_memory.Mem0Storage") as mock_mem0_storage: + EntityMemory(memory_provider="mem0") + mock_mem0_storage.assert_called_once() + + +def test_init_with_custom_storage(): + custom_storage = MagicMock() + entity_memory = EntityMemory(storage=custom_storage) + assert entity_memory.storage == custom_storage diff --git a/tests/memory/long_term_memory_test.py b/tests/memory/long_term_memory_test.py index 3639054e3..829819f0c 100644 --- a/tests/memory/long_term_memory_test.py +++ b/tests/memory/long_term_memory_test.py @@ -1,29 +1,125 @@ -import pytest +# tests/memory/long_term_memory_test.py +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest from crewai.memory.long_term.long_term_memory import LongTermMemory from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem +from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage @pytest.fixture -def long_term_memory(): - """Fixture to create a LongTermMemory instance""" - return LongTermMemory() +def mock_storage(): + """Fixture to create a mock LTMSQLiteStorage instance""" + return MagicMock(spec=LTMSQLiteStorage) -def test_save_and_search(long_term_memory): +@pytest.fixture +def long_term_memory(mock_storage): + """Fixture to create a LongTermMemory instance with mock storage""" + return LongTermMemory(storage=mock_storage) + + +def test_save(long_term_memory, mock_storage): memory = LongTermMemoryItem( agent="test_agent", task="test_task", expected_output="test_output", - datetime="test_datetime", + datetime="2023-01-01 12:00:00", quality=0.5, - metadata={"task": "test_task", "quality": 0.5}, + metadata={"additional_info": "test_info"}, ) long_term_memory.save(memory) - find = long_term_memory.search("test_task", latest_n=5)[0] - assert find["score"] == 0.5 - assert find["datetime"] == "test_datetime" - assert find["metadata"]["agent"] == "test_agent" - assert find["metadata"]["quality"] == 0.5 - assert find["metadata"]["task"] == "test_task" - assert find["metadata"]["expected_output"] == "test_output" + + expected_metadata = { + "additional_info": "test_info", + "agent": "test_agent", + "expected_output": "test_output", + "quality": 0.5, # Include quality in expected metadata + } + mock_storage.save.assert_called_once_with( + task_description="test_task", + score=0.5, + metadata=expected_metadata, + datetime="2023-01-01 12:00:00", + ) + + +def test_search(long_term_memory, mock_storage): + mock_storage.load.return_value = [ + { + "metadata": { + "agent": "test_agent", + "expected_output": "test_output", + "task": "test_task", + }, + "datetime": "2023-01-01 12:00:00", + "score": 0.5, + } + ] + + result = long_term_memory.search("test_task", latest_n=5) + + mock_storage.load.assert_called_once_with("test_task", 5) + assert len(result) == 1 + assert result[0]["metadata"]["agent"] == "test_agent" + assert result[0]["metadata"]["expected_output"] == "test_output" + assert result[0]["metadata"]["task"] == "test_task" + assert result[0]["datetime"] == "2023-01-01 12:00:00" + assert result[0]["score"] == 0.5 + + +def test_save_with_minimal_metadata(long_term_memory, mock_storage): + memory = LongTermMemoryItem( + agent="minimal_agent", + task="minimal_task", + expected_output="minimal_output", + datetime="2023-01-01 12:00:00", + quality=0.3, + metadata={}, + ) + long_term_memory.save(memory) + + expected_metadata = { + "agent": "minimal_agent", + "expected_output": "minimal_output", + "quality": 0.3, # Include quality in expected metadata + } + mock_storage.save.assert_called_once_with( + task_description="minimal_task", + score=0.3, + metadata=expected_metadata, + datetime="2023-01-01 12:00:00", + ) + + +def test_reset(long_term_memory, mock_storage): + long_term_memory.reset() + mock_storage.reset.assert_called_once() + + +def test_search_with_no_results(long_term_memory, mock_storage): + mock_storage.load.return_value = [] + result = long_term_memory.search("nonexistent_task") + assert result == [] + + +def test_init_with_default_storage(): + with patch( + "crewai.memory.long_term.long_term_memory.LTMSQLiteStorage" + ) as mock_storage_class: + LongTermMemory() + mock_storage_class.assert_called_once() + + +def test_init_with_custom_storage(): + custom_storage = MagicMock() + memory = LongTermMemory(storage=custom_storage) + assert memory.storage == custom_storage + + +@pytest.mark.parametrize("latest_n", [1, 3, 5, 10]) +def test_search_with_different_latest_n(long_term_memory, mock_storage, latest_n): + long_term_memory.search("test_task", latest_n=latest_n) + mock_storage.load.assert_called_once_with("test_task", latest_n)