diff --git a/src/crewai/memory/storage/mem0_storage.py b/src/crewai/memory/storage/mem0_storage.py index 8e6e1f65d..521134cbc 100644 --- a/src/crewai/memory/storage/mem0_storage.py +++ b/src/crewai/memory/storage/mem0_storage.py @@ -1,7 +1,8 @@ import os from typing import Any, Dict, List - +from collections import defaultdict from mem0 import Memory, MemoryClient +from crewai.utilities.chromadb import sanitize_collection_name from crewai.memory.storage.interface import Storage @@ -70,26 +71,32 @@ class Mem0Storage(Storage): """ Returns: dict: A filter dictionary containing AND conditions for querying data. - - Includes user_id if memory_type is 'external'. + - Includes user_id and agent_id if both are present. + - Includes user_id if only user_id is present. + - Includes agent_id if only agent_id is present. - Includes run_id if memory_type is 'short_term' and mem0_run_id is present. """ - filter = { - "AND": [] - } + filter = defaultdict(list) - # Add user_id condition if the memory type is external - if self.memory_type == "external": - filter["AND"].append({"user_id": self.config.get("user_id", "")}) - - # Add run_id condition if the memory type is short_term and a run ID is set if self.memory_type == "short_term" and self.mem0_run_id: filter["AND"].append({"run_id": self.mem0_run_id}) + else: + user_id = self.config.get("user_id", "") + agent_id = self.config.get("agent_id", "") + + if user_id and agent_id: + filter["OR"].append({"user_id": user_id}) + filter["OR"].append({"agent_id": agent_id}) + elif user_id: + filter["AND"].append({"user_id": user_id}) + elif agent_id: + filter["AND"].append({"agent_id": agent_id}) return filter def save(self, value: Any, metadata: Dict[str, Any]) -> None: user_id = self.config.get("user_id", "") - assistant_message = [{"role" : "assistant","content" : value}] + assistant_message = [{"role" : "assistant","content" : value}] base_metadata = { "short_term": "short_term", @@ -104,31 +111,32 @@ class Mem0Storage(Storage): "infer": self.infer } - if self.memory_type == "external": + # MemoryClient-specific overrides + if isinstance(self.memory, MemoryClient): + params["includes"] = self.includes + params["excludes"] = self.excludes + params["output_format"] = "v1.1" + params["version"] = "v2" + + if self.memory_type == "short_term" and self.mem0_run_id: + params["run_id"] = self.mem0_run_id + + if user_id: params["user_id"] = user_id - - if params: - # MemoryClient-specific overrides - if isinstance(self.memory, MemoryClient): - params["includes"] = self.includes - params["excludes"] = self.excludes - params["output_format"] = "v1.1" - params["version"]="v2" + if agent_id := self.config.get("agent_id", self._get_agent_name()): + params["agent_id"] = agent_id - if self.memory_type == "short_term": - params["run_id"] = self.mem0_run_id - - self.memory.add(assistant_message, **params) + self.memory.add(assistant_message, **params) def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]: params = { - "query": query, - "limit": limit, + "query": query, + "limit": limit, "version": "v2", "output_format": "v1.1" } - + if user_id := self.config.get("user_id", ""): params["user_id"] = user_id @@ -138,7 +146,7 @@ class Mem0Storage(Storage): "entities": {"type": "entity"}, "external": {"type": "external"}, } - + if self.memory_type in memory_type_map: params["metadata"] = memory_type_map[self.memory_type] if self.memory_type == "short_term": @@ -151,11 +159,28 @@ class Mem0Storage(Storage): params['threshold'] = score_threshold if isinstance(self.memory, Memory): - del params["metadata"], params["version"], params["run_id"], params['output_format'] + del params["metadata"], params["version"], params['output_format'] + if params.get("run_id"): + del params["run_id"] results = self.memory.search(**params) return [r for r in results["results"]] - + def reset(self): if self.memory: self.memory.reset() + + def _sanitize_role(self, role: str) -> str: + """ + Sanitizes agent roles to ensure valid directory names. + """ + return role.replace("\n", "").replace(" ", "_").replace("/", "_") + + def _get_agent_name(self) -> str: + if not self.crew: + return "" + + agents = self.crew.agents + agents = [self._sanitize_role(agent.role) for agent in agents] + agents = "_".join(agents) + return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0) diff --git a/tests/storage/test_mem0_storage.py b/tests/storage/test_mem0_storage.py index 76de5f63d..cd491c86c 100644 --- a/tests/storage/test_mem0_storage.py +++ b/tests/storage/test_mem0_storage.py @@ -191,17 +191,39 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config): """Test save method for different memory types""" mem0_storage, _, _ = mem0_storage_with_mocked_config mem0_storage.memory.add = MagicMock() - + # Test short_term memory type (already set in fixture) test_value = "This is a test memory" test_metadata = {"key": "value"} - + mem0_storage.save(test_value, test_metadata) - + mem0_storage.memory.add.assert_called_once_with( - [{'role': 'assistant' , 'content': test_value}], + [{"role": "assistant" , "content": test_value}], infer=True, metadata={"type": "short_term", "key": "value"}, + run_id="my_run_id", + user_id="test_user", + agent_id='Test_Agent' + ) + +def test_save_method_with_multiple_agents(mem0_storage_with_mocked_config): + mem0_storage, _, _ = mem0_storage_with_mocked_config + mem0_storage.crew.agents = [MagicMock(role="Test Agent"), MagicMock(role="Test Agent 2"), MagicMock(role="Test Agent 3")] + mem0_storage.memory.add = MagicMock() + + test_value = "This is a test memory" + test_metadata = {"key": "value"} + + mem0_storage.save(test_value, test_metadata) + + mem0_storage.memory.add.assert_called_once_with( + [{"role": "assistant" , "content": test_value}], + infer=True, + metadata={"type": "short_term", "key": "value"}, + run_id="my_run_id", + user_id="test_user", + agent_id='Test_Agent_Test_Agent_2_Test_Agent_3' ) @@ -209,13 +231,13 @@ def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_co """Test save method for different memory types""" mem0_storage = mem0_storage_with_memory_client_using_config_from_crew mem0_storage.memory.add = MagicMock() - + # Test short_term memory type (already set in fixture) test_value = "This is a test memory" test_metadata = {"key": "value"} - + mem0_storage.save(test_value, test_metadata) - + mem0_storage.memory.add.assert_called_once_with( [{'role': 'assistant' , 'content': test_value}], infer=True, @@ -224,7 +246,9 @@ def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_co run_id="my_run_id", includes="include1", excludes="exclude1", - output_format='v1.1' + output_format='v1.1', + user_id='test_user', + agent_id='Test_Agent' ) @@ -237,10 +261,10 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config): results = mem0_storage.search("test query", limit=5, score_threshold=0.5) mem0_storage.memory.search.assert_called_once_with( - query="test query", - limit=5, + query="test query", + limit=5, user_id="test_user", - filters={'AND': [{'run_id': 'my_run_id'}]}, + filters={'AND': [{'run_id': 'my_run_id'}]}, threshold=0.5 ) @@ -257,8 +281,8 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_ results = mem0_storage.search("test query", limit=5, score_threshold=0.5) mem0_storage.memory.search.assert_called_once_with( - query="test query", - limit=5, + query="test query", + limit=5, metadata={"type": "short_term"}, user_id="test_user", version='v2', @@ -286,4 +310,56 @@ def test_mem0_storage_default_infer_value(mock_mem0_memory_client): ) mem0_storage = Mem0Storage(type="short_term", crew=crew) - assert mem0_storage.infer is True \ No newline at end of file + assert mem0_storage.infer is True + +def test_save_memory_using_agent_entity(mock_mem0_memory_client): + config = { + "agent_id": "agent-123", + } + + mock_memory = MagicMock(spec=Memory) + with patch.object(Memory, "__new__", return_value=mock_memory): + mem0_storage = Mem0Storage(type="external", config=config) + mem0_storage.save("test memory", {"key": "value"}) + mem0_storage.memory.add.assert_called_once_with( + [{'role': 'assistant' , 'content': 'test memory'}], + infer=True, + metadata={"type": "external", "key": "value"}, + agent_id="agent-123", + ) + +def test_search_method_with_agent_entity(): + mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123"}) + mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]} + mem0_storage.memory.search = MagicMock(return_value=mock_results) + + results = mem0_storage.search("test query", limit=5, score_threshold=0.5) + + mem0_storage.memory.search.assert_called_once_with( + query="test query", + limit=5, + filters={"AND": [{"agent_id": "agent-123"}]}, + threshold=0.5, + ) + + assert len(results) == 2 + assert results[0]["content"] == "Result 1" + + +def test_search_method_with_agent_id_and_user_id(): + mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123", "user_id": "user-123"}) + mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]} + mem0_storage.memory.search = MagicMock(return_value=mock_results) + + results = mem0_storage.search("test query", limit=5, score_threshold=0.5) + + mem0_storage.memory.search.assert_called_once_with( + query="test query", + limit=5, + user_id='user-123', + filters={"OR": [{"user_id": "user-123"}, {"agent_id": "agent-123"}]}, + threshold=0.5, + ) + + assert len(results) == 2 + assert results[0]["content"] == "Result 1"