From a41145fd7ec2e5a37d4912e56f10d764a8bcc2e3 Mon Sep 17 00:00:00 2001 From: Lucas Gomide Date: Thu, 24 Jul 2025 10:08:56 -0300 Subject: [PATCH] wip --- src/crewai/memory/storage/mem0_storage.py | 26 ++++++++++------ tests/storage/test_mem0_storage.py | 37 ++++++++++++++++------- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/crewai/memory/storage/mem0_storage.py b/src/crewai/memory/storage/mem0_storage.py index e0e8af890..e688952f3 100644 --- a/src/crewai/memory/storage/mem0_storage.py +++ b/src/crewai/memory/storage/mem0_storage.py @@ -70,7 +70,7 @@ class Mem0Storage(Storage): """ Returns: dict: A filter dictionary containing AND conditions for querying data. - - Includes user_id if memory_type is 'external'. + - Includes user_id, if memory_type is 'external'. - Includes run_id if memory_type is 'short_term' and mem0_run_id is present. """ filter = { @@ -79,7 +79,14 @@ class Mem0Storage(Storage): # Add user_id condition if the memory type is external if self.memory_type == "external": - filter["AND"].append({"user_id": self.config.get("user_id", "")}) + user_id = self.config.get("user_id", "") + agent_id = self.config.get("agent_id", "") + + if user_id: + filter["AND"].append({"user_id": user_id}) + + if agent_id: + filter["AND"].append({"agent_id": agent_id}) # Add run_id condition if the memory type is short_term and a run ID is set if self.memory_type == "short_term" and self.mem0_run_id: @@ -89,7 +96,8 @@ class Mem0Storage(Storage): def save(self, value: Any, metadata: Dict[str, Any]) -> None: user_id = self.config.get("user_id", "") - assistant_message = [{"role" : "assistant","content" : value}] + agent_id = self.config.get("agent_id", "") + assistant_message = [{"role" : "assistant","content" : value}] base_metadata = { "short_term": "short_term", @@ -106,8 +114,8 @@ class Mem0Storage(Storage): if self.memory_type == "external": params["user_id"] = user_id + params["agent_id"] = agent_id - if params: # MemoryClient-specific overrides if isinstance(self.memory, MemoryClient): @@ -123,12 +131,12 @@ class Mem0Storage(Storage): def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]: params = { - "query": query, - "limit": limit, + "query": query, + "limit": limit, "version": "v2", "output_format": "v1.1" } - + if user_id := self.config.get("user_id", ""): params["user_id"] = user_id @@ -138,7 +146,7 @@ class Mem0Storage(Storage): "entities": {"type": "entity"}, "external": {"type": "external"}, } - + if self.memory_type in memory_type_map: params["metadata"] = memory_type_map[self.memory_type] if self.memory_type == "short_term": @@ -155,7 +163,7 @@ class Mem0Storage(Storage): results = self.memory.search(**params) return [r for r in results["results"]] - + def reset(self): if self.memory: self.memory.reset() diff --git a/tests/storage/test_mem0_storage.py b/tests/storage/test_mem0_storage.py index 6c4cf3c6e..a4eace2e5 100644 --- a/tests/storage/test_mem0_storage.py +++ b/tests/storage/test_mem0_storage.py @@ -191,13 +191,13 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config): """Test save method for different memory types""" mem0_storage, _, _ = mem0_storage_with_mocked_config mem0_storage.memory.add = MagicMock() - + # Test short_term memory type (already set in fixture) test_value = "This is a test memory" test_metadata = {"key": "value"} - + mem0_storage.save(test_value, test_metadata) - + mem0_storage.memory.add.assert_called_once_with( [{'role': 'assistant' , 'content': test_value}], infer=True, @@ -209,13 +209,13 @@ def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_co """Test save method for different memory types""" mem0_storage = mem0_storage_with_memory_client_using_config_from_crew mem0_storage.memory.add = MagicMock() - + # Test short_term memory type (already set in fixture) test_value = "This is a test memory" test_metadata = {"key": "value"} - + mem0_storage.save(test_value, test_metadata) - + mem0_storage.memory.add.assert_called_once_with( [{'role': 'assistant' , 'content': test_value}], infer=True, @@ -237,10 +237,10 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config): results = mem0_storage.search("test query", limit=5, score_threshold=0.5) mem0_storage.memory.search.assert_called_once_with( - query="test query", - limit=5, + query="test query", + limit=5, user_id="test_user", - filters={'AND': [{'run_id': 'my_run_id'}]}, + filters={'AND': [{'run_id': 'my_run_id'}]}, threshold=0.5 ) @@ -257,8 +257,8 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_ results = mem0_storage.search("test query", limit=5, score_threshold=0.5) mem0_storage.memory.search.assert_called_once_with( - query="test query", - limit=5, + query="test query", + limit=5, metadata={"type": "short_term"}, user_id="test_user", version='v2', @@ -270,3 +270,18 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_ assert len(results) == 2 assert results[0]["content"] == "Result 1" + +def test_save_memory_using_agent_entity(): + config = { + "agent_id": "agent-123", + } + mem0_storage = Mem0Storage(type="short_term", config=config) + + mem0_storage.save("test memory", {"key": "value"}) + + mem0_storage.memory.add.assert_called_once_with( + [{'role': 'assistant' , 'content': 'test memory'}], + infer=True, + metadata={"type": "short_term", "key": "value"}, + agent_id="agent-123", + )