From c53cd2775a8639e8d5e94c4275c59c214015bb6e Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Wed, 7 May 2025 21:53:07 +0000 Subject: [PATCH] Fix Memory OSS compatibility and implement code review suggestions Co-Authored-By: Joe Moura --- src/crewai/memory/storage/mem0_storage.py | 158 +++++++++++++++++----- tests/storage/test_mem0_storage_v2.py | 133 +++++++++++++++++- 2 files changed, 258 insertions(+), 33 deletions(-) diff --git a/src/crewai/memory/storage/mem0_storage.py b/src/crewai/memory/storage/mem0_storage.py index 03588d2dd..b860528a6 100644 --- a/src/crewai/memory/storage/mem0_storage.py +++ b/src/crewai/memory/storage/mem0_storage.py @@ -12,8 +12,26 @@ class Mem0Storage(Storage): Supports Mem0 v2 API with run_id for associating memories with specific conversation sessions. By default, uses v2 API which is recommended for better context management. + + Args: + type: The type of memory storage ("user", "short_term", "long_term", "entities", "external") + crew: The crew instance this storage is associated with + config: Optional configuration dictionary that overrides crew.memory_config + + Configuration options: + version: API version to use ("v1.1" or "v2", defaults to "v2") + run_id: Optional session identifier for associating memories with specific conversations + api_key: Mem0 API key (defaults to MEM0_API_KEY environment variable) + user_id: User identifier (required for "user" memory type) + org_id: Optional organization ID for Mem0 API + project_id: Optional project ID for Mem0 API + local_mem0_config: Optional configuration for local Mem0 instance """ + SUPPORTED_VERSIONS = ["v1.1", "v2"] + + DEFAULT_VERSION = "v2" + def __init__(self, type, crew=None, config=None): super().__init__() supported_types = ["user", "short_term", "long_term", "entities", "external"] @@ -30,9 +48,10 @@ class Mem0Storage(Storage): self.memory_config = self.config or getattr(crew, "memory_config", {}) or {} config = self._get_config() - self.version = config.get("version", "v2") - + self.version = config.get("version", self.DEFAULT_VERSION) self.run_id = config.get("run_id") + + self._validate_config() # User ID is required for user memory type "user" since it's used as a unique identifier for the user. user_id = self._get_user_id() @@ -60,16 +79,70 @@ class Mem0Storage(Storage): else: self.memory = Memory() + def _validate_config(self) -> None: + """ + Validate configuration parameters. + + Raises: + ValueError: If the version is not supported + """ + if self.version not in self.SUPPORTED_VERSIONS: + raise ValueError( + f"Unsupported version: {self.version}. " + f"Please use one of: {', '.join(self.SUPPORTED_VERSIONS)}" + ) + + if self.run_id is not None and not isinstance(self.run_id, str): + raise ValueError("run_id must be a string") + + def _build_params(self, base_params: Dict[str, Any], method: str = "add") -> Dict[str, Any]: + """ + Centralize parameter building for API calls. + + Args: + base_params: Base parameters to build upon + method: The method being called ("add" or "search") + + Returns: + Dict[str, Any]: Complete parameters for API call + """ + params = base_params.copy() + + # Add version and run_id for MemoryClient + if isinstance(self.memory, MemoryClient): + params["version"] = self.version + + if self.run_id: + params["run_id"] = self.run_id + elif isinstance(self.memory, Memory) and method == "search" and "metadata" in params: + del params["metadata"] + + return params + def _sanitize_role(self, role: str) -> str: """ Sanitizes agent roles to ensure valid directory names. + + Args: + role: The role name to sanitize + + Returns: + str: Sanitized role name """ return role.replace("\n", "").replace(" ", "_").replace("/", "_") def save(self, value: Any, metadata: Dict[str, Any]) -> None: + """ + Save a memory item. + + Args: + value: The memory content to save + metadata: Additional metadata for the memory + """ user_id = self._get_user_id() agent_name = self._get_agent_name() params = None + if self.memory_type == "short_term": params = { "agent_id": agent_name, @@ -96,12 +169,7 @@ class Mem0Storage(Storage): } if params: - if isinstance(self.memory, MemoryClient): - params["version"] = self.version - - if self.run_id: - params["run_id"] = self.run_id - + params = self._build_params(params, method="add") self.memory.add(value, **params) def search( @@ -110,42 +178,61 @@ class Mem0Storage(Storage): limit: int = 3, score_threshold: float = 0.35, ) -> List[Any]: - params = {"query": query, "limit": limit} + """ + Search for memories. + + Args: + query: The search query + limit: Maximum number of results to return + score_threshold: Minimum score for results to be included + + Returns: + List[Any]: List of memory items that match the query + """ + base_params = {"query": query, "limit": limit} if user_id := self._get_user_id(): - params["user_id"] = user_id + base_params["user_id"] = user_id agent_name = self._get_agent_name() if self.memory_type == "short_term": - params["agent_id"] = agent_name - params["metadata"] = {"type": "short_term"} + base_params["agent_id"] = agent_name + base_params["metadata"] = {"type": "short_term"} elif self.memory_type == "long_term": - params["agent_id"] = agent_name - params["metadata"] = {"type": "long_term"} + base_params["agent_id"] = agent_name + base_params["metadata"] = {"type": "long_term"} elif self.memory_type == "entities": - params["agent_id"] = agent_name - params["metadata"] = {"type": "entity"} + base_params["agent_id"] = agent_name + base_params["metadata"] = {"type": "entity"} elif self.memory_type == "external": - params["agent_id"] = agent_name - params["metadata"] = {"type": "external"} + base_params["agent_id"] = agent_name + base_params["metadata"] = {"type": "external"} - # Add version and run_id for MemoryClient - if isinstance(self.memory, MemoryClient): - params["version"] = self.version - - if self.run_id: - params["run_id"] = self.run_id - # Discard the filters for Memory (OSS version) - elif isinstance(self.memory, Memory): - if "metadata" in params: - del params["metadata"] - + params = self._build_params(base_params, method="search") results = self.memory.search(**params) - return [r for r in results["results"] if r["score"] >= score_threshold] + + if isinstance(results, dict) and "results" in results: + return [r for r in results["results"] if r["score"] >= score_threshold] + elif isinstance(results, list): + return [r for r in results if r["score"] >= score_threshold] + else: + return [] def _get_user_id(self) -> str: + """ + Get the user ID from configuration. + + Returns: + str: User ID or empty string if not found + """ return self._get_config().get("user_id", "") def _get_agent_name(self) -> str: + """ + Get the agent name from the crew. + + Returns: + str: Agent name or empty string if not found + """ if not self.crew: return "" @@ -155,8 +242,17 @@ class Mem0Storage(Storage): return agents def _get_config(self) -> Dict[str, Any]: + """ + Get the configuration from either config or memory_config. + + Returns: + Dict[str, Any]: Configuration dictionary + """ return self.config or getattr(self, "memory_config", {}).get("config", {}) or {} - def reset(self): + def reset(self) -> None: + """ + Reset the memory. + """ if self.memory: self.memory.reset() diff --git a/tests/storage/test_mem0_storage_v2.py b/tests/storage/test_mem0_storage_v2.py index 9b8ccb458..b76029ad6 100644 --- a/tests/storage/test_mem0_storage_v2.py +++ b/tests/storage/test_mem0_storage_v2.py @@ -65,24 +65,54 @@ def mem0_storage_with_run_id(mock_mem0_memory_client): return mem0_storage, mock_mem0_memory_client +@pytest.fixture +def mem0_storage_with_v1_api(mock_mem0_memory_client): + """Fixture to create a Mem0Storage instance with v1.1 API configuration""" + + with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client): + crew = MockCrew( + memory_config={ + "provider": "mem0", + "config": { + "user_id": "test_user", + "api_key": "ABCDEFGH", + "version": "v1.1", # Explicitly set to v1.1 + }, + } + ) + + mem0_storage = Mem0Storage(type="short_term", crew=crew) + return mem0_storage, mock_mem0_memory_client + + +@pytest.mark.v2_api def test_mem0_storage_v2_initialization(mem0_storage_with_v2_api): """Test that Mem0Storage initializes correctly with v2 API configuration""" mem0_storage, _ = mem0_storage_with_v2_api assert mem0_storage.version == "v2" - assert mem0_storage.run_id is None +@pytest.mark.v2_api def test_mem0_storage_with_run_id_initialization(mem0_storage_with_run_id): """Test that Mem0Storage initializes correctly with run_id configuration""" mem0_storage, _ = mem0_storage_with_run_id assert mem0_storage.version == "v2" - assert mem0_storage.run_id == "test-session-123" +@pytest.mark.v1_api +def test_mem0_storage_v1_initialization(mem0_storage_with_v1_api): + """Test that Mem0Storage initializes correctly with v1.1 API configuration""" + mem0_storage, _ = mem0_storage_with_v1_api + + assert mem0_storage.version == "v1.1" + assert mem0_storage.run_id is None + + +@pytest.mark.v2_api def test_save_method_with_v2_api(mem0_storage_with_v2_api): """Test save method with v2 API""" mem0_storage, mock_memory_client = mem0_storage_with_v2_api @@ -102,6 +132,7 @@ def test_save_method_with_v2_api(mem0_storage_with_v2_api): assert call_args["metadata"] == {"type": "short_term", "key": "value"} +@pytest.mark.v2_api def test_save_method_with_run_id(mem0_storage_with_run_id): """Test save method with run_id""" mem0_storage, mock_memory_client = mem0_storage_with_run_id @@ -121,6 +152,7 @@ def test_save_method_with_run_id(mem0_storage_with_run_id): assert call_args["metadata"] == {"type": "short_term", "key": "value"} +@pytest.mark.v2_api def test_search_method_with_v2_api(mem0_storage_with_v2_api): """Test search method with v2 API""" mem0_storage, mock_memory_client = mem0_storage_with_v2_api @@ -141,6 +173,7 @@ def test_search_method_with_v2_api(mem0_storage_with_v2_api): assert results[0]["content"] == "Result 1" +@pytest.mark.v2_api def test_search_method_with_run_id(mem0_storage_with_run_id): """Test search method with run_id""" mem0_storage, mock_memory_client = mem0_storage_with_run_id @@ -159,3 +192,99 @@ def test_search_method_with_run_id(mem0_storage_with_run_id): assert len(results) == 1 assert results[0]["content"] == "Result 1" + + +@pytest.mark.v2_api +def test_search_method_with_different_result_formats(mem0_storage_with_v2_api): + """Test search method with different result formats""" + mem0_storage, mock_memory_client = mem0_storage_with_v2_api + + mock_results_dict = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]} + mock_memory_client.search = MagicMock(return_value=mock_results_dict) + + results = mem0_storage.search("test query", limit=5, score_threshold=0.5) + assert len(results) == 1 + assert results[0]["content"] == "Result 1" + + mock_results_list = [{"score": 0.9, "content": "Result 3"}, {"score": 0.4, "content": "Result 4"}] + mock_memory_client.search = MagicMock(return_value=mock_results_list) + + results = mem0_storage.search("test query", limit=5, score_threshold=0.5) + assert len(results) == 1 + assert results[0]["content"] == "Result 3" + + mock_memory_client.search = MagicMock(return_value="unexpected format") + + results = mem0_storage.search("test query", limit=5, score_threshold=0.5) + assert len(results) == 0 + + +@pytest.mark.parametrize("run_id", [None, "", "test-123", "a" * 256]) +@pytest.mark.v2_api +def test_run_id_edge_cases(mock_mem0_memory_client, run_id): + """Test edge cases for run_id parameter""" + with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client): + crew = MockCrew( + memory_config={ + "provider": "mem0", + "config": { + "user_id": "test_user", + "api_key": "ABCDEFGH", + "version": "v2", + "run_id": run_id, + }, + } + ) + + if run_id == "": + mem0_storage = Mem0Storage(type="short_term", crew=crew) + assert mem0_storage.run_id == "" + + mock_mem0_memory_client.add = MagicMock() + mem0_storage.save("test", {}) + assert "run_id" not in mock_mem0_memory_client.add.call_args[1] + else: + mem0_storage = Mem0Storage(type="short_term", crew=crew) + assert mem0_storage.run_id == run_id + + if run_id is not None: + mock_mem0_memory_client.add = MagicMock() + mem0_storage.save("test", {}) + assert mock_mem0_memory_client.add.call_args[1].get("run_id") == run_id + + +def test_invalid_version_handling(mock_mem0_memory_client): + """Test handling of invalid version""" + with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client): + crew = MockCrew( + memory_config={ + "provider": "mem0", + "config": { + "user_id": "test_user", + "api_key": "ABCDEFGH", + "version": "invalid", + }, + } + ) + + with pytest.raises(ValueError, match="Unsupported version"): + Mem0Storage(type="short_term", crew=crew) + + +def test_invalid_run_id_type(mock_mem0_memory_client): + """Test handling of invalid run_id type""" + with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client): + crew = MockCrew( + memory_config={ + "provider": "mem0", + "config": { + "user_id": "test_user", + "api_key": "ABCDEFGH", + "version": "v2", + "run_id": 123, # Not a string + }, + } + ) + + with pytest.raises(ValueError, match="run_id must be a string"): + Mem0Storage(type="short_term", crew=crew)