diff --git a/src/crewai/memory/storage/mem0_storage.py b/src/crewai/memory/storage/mem0_storage.py index 8e6e1f65d..f5176719a 100644 --- a/src/crewai/memory/storage/mem0_storage.py +++ b/src/crewai/memory/storage/mem0_storage.py @@ -69,17 +69,25 @@ class Mem0Storage(Storage): def _create_filter_for_search(self): """ Returns: - dict: A filter dictionary containing AND conditions for querying data. - - Includes user_id if memory_type is 'external'. - - Includes run_id if memory_type is 'short_term' and mem0_run_id is present. + dict: A filter dictionary containing conditions for querying data. + - Uses OR logic when both user_id and agent_id are present + - Uses AND logic for other conditions like run_id """ - filter = { - "AND": [] - } - - # Add user_id condition if the memory type is external - if self.memory_type == "external": - filter["AND"].append({"user_id": self.config.get("user_id", "")}) + user_id = self.config.get("user_id", "") + agent_id = self.config.get("agent_id", "") + + filter = {"AND": []} + + id_conditions = [] + if user_id: + id_conditions.append({"user_id": user_id}) + if agent_id: + id_conditions.append({"agent_id": agent_id}) + + if len(id_conditions) > 1: + filter["AND"].append({"OR": id_conditions}) + elif len(id_conditions) == 1: + filter["AND"].append(id_conditions[0]) # Add run_id condition if the memory type is short_term and a run ID is set if self.memory_type == "short_term" and self.mem0_run_id: @@ -89,6 +97,7 @@ class Mem0Storage(Storage): def save(self, value: Any, metadata: Dict[str, Any]) -> None: user_id = self.config.get("user_id", "") + agent_id = metadata.get("agent", "") assistant_message = [{"role" : "assistant","content" : value}] base_metadata = { @@ -104,8 +113,11 @@ class Mem0Storage(Storage): "infer": self.infer } - if self.memory_type == "external": + if user_id: params["user_id"] = user_id + + if agent_id: + params["agent_id"] = agent_id if params: @@ -129,8 +141,14 @@ class Mem0Storage(Storage): "output_format": "v1.1" } - if user_id := self.config.get("user_id", ""): + user_id = self.config.get("user_id", "") + agent_id = self.config.get("agent_id", "") + + if user_id: params["user_id"] = user_id + + if agent_id: + params["agent_id"] = agent_id memory_type_map = { "short_term": {"type": "short_term"}, @@ -156,6 +174,12 @@ class Mem0Storage(Storage): results = self.memory.search(**params) return [r for r in results["results"]] + def set_agent_id(self, agent_id: str) -> None: + """Set the agent_id for this memory storage instance.""" + if not self.config: + self.config = {} + self.config["agent_id"] = agent_id + def reset(self): if self.memory: self.memory.reset() diff --git a/tests/storage/test_mem0_storage.py b/tests/storage/test_mem0_storage.py index 76de5f63d..ac7e83a4a 100644 --- a/tests/storage/test_mem0_storage.py +++ b/tests/storage/test_mem0_storage.py @@ -202,6 +202,7 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config): [{'role': 'assistant' , 'content': test_value}], infer=True, metadata={"type": "short_term", "key": "value"}, + user_id="test_user" ) @@ -220,6 +221,7 @@ def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_co [{'role': 'assistant' , 'content': test_value}], infer=True, metadata={"type": "short_term", "key": "value"}, + user_id="test_user", version="v2", run_id="my_run_id", includes="include1", @@ -237,10 +239,10 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config): results = mem0_storage.search("test query", limit=5, score_threshold=0.5) mem0_storage.memory.search.assert_called_once_with( - query="test query", - limit=5, + query="test query", + limit=5, user_id="test_user", - filters={'AND': [{'run_id': 'my_run_id'}]}, + filters={'AND': [{'user_id': 'test_user'}, {'run_id': 'my_run_id'}]}, threshold=0.5 ) @@ -257,14 +259,14 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_ results = mem0_storage.search("test query", limit=5, score_threshold=0.5) mem0_storage.memory.search.assert_called_once_with( - query="test query", - limit=5, + query="test query", + limit=5, metadata={"type": "short_term"}, user_id="test_user", version='v2', run_id="my_run_id", output_format='v1.1', - filters={'AND': [{'run_id': 'my_run_id'}]}, + filters={'AND': [{'user_id': 'test_user'}, {'run_id': 'my_run_id'}]}, threshold=0.5 ) @@ -286,4 +288,111 @@ def test_mem0_storage_default_infer_value(mock_mem0_memory_client): ) mem0_storage = Mem0Storage(type="short_term", crew=crew) - assert mem0_storage.infer is True \ No newline at end of file + assert mem0_storage.infer is True + + +def test_save_with_agent_id_from_metadata(mem0_storage_with_memory_client_using_config_from_crew): + """Test that agent_id is extracted from metadata and used in save operation""" + mem0_storage = mem0_storage_with_memory_client_using_config_from_crew + mem0_storage.memory.add = MagicMock() + + test_value = "This is a test memory from agent" + test_metadata = {"agent": "test_agent_123", "key": "value"} + + mem0_storage.save(test_value, test_metadata) + + call_args = mem0_storage.memory.add.call_args + assert "agent_id" in call_args[1] + assert call_args[1]["agent_id"] == "test_agent_123" + + +def test_search_with_agent_id_in_config(mem0_storage_with_memory_client_using_config_from_crew): + """Test search method includes agent_id when present in config""" + mem0_storage = mem0_storage_with_memory_client_using_config_from_crew + mem0_storage.config["agent_id"] = "test_agent_456" + + mock_results = {"results": [{"score": 0.9, "content": "Result 1"}]} + mem0_storage.memory.search = MagicMock(return_value=mock_results) + + results = mem0_storage.search("test query") + + call_args = mem0_storage.memory.search.call_args[1] + assert "agent_id" in call_args + assert call_args["agent_id"] == "test_agent_456" + + +def test_filter_uses_or_logic_with_both_user_and_agent_id(mock_mem0_memory_client): + """Test that filter uses OR logic when both user_id and agent_id are present""" + crew = MockCrew(memory_config={"provider": "mem0", "config": {"user_id": "test_user", "api_key": "ABCDEFGH"}}) + + with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client): + mem0_storage = Mem0Storage(type="external", crew=crew) + mem0_storage.config["user_id"] = "test_user" + mem0_storage.config["agent_id"] = "test_agent" + + filter_result = mem0_storage._create_filter_for_search() + + assert "AND" in filter_result + assert len(filter_result["AND"]) == 1 + assert "OR" in filter_result["AND"][0] + or_conditions = filter_result["AND"][0]["OR"] + assert {"user_id": "test_user"} in or_conditions + assert {"agent_id": "test_agent"} in or_conditions + + +def test_filter_uses_single_condition_with_only_agent_id(mock_mem0_memory_client): + """Test that filter uses single condition when only agent_id is present""" + crew = MockCrew(memory_config={"provider": "mem0", "config": {"api_key": "ABCDEFGH"}}) + + with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client): + mem0_storage = Mem0Storage(type="external", crew=crew) + mem0_storage.config["agent_id"] = "test_agent" + + filter_result = mem0_storage._create_filter_for_search() + + assert "AND" in filter_result + assert {"agent_id": "test_agent"} in filter_result["AND"] + + +def test_set_agent_id_method(mock_mem0_memory_client): + """Test the set_agent_id method""" + crew = MockCrew(memory_config={"provider": "mem0", "config": {"api_key": "ABCDEFGH"}}) + + with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client): + mem0_storage = Mem0Storage(type="external", crew=crew) + mem0_storage.set_agent_id("new_agent_123") + + assert mem0_storage.config["agent_id"] == "new_agent_123" + + +def test_save_with_both_user_id_and_agent_id(mem0_storage_with_memory_client_using_config_from_crew): + """Test save method with both user_id and agent_id""" + mem0_storage = mem0_storage_with_memory_client_using_config_from_crew + mem0_storage.memory.add = MagicMock() + + test_value = "This is a test memory" + test_metadata = {"agent": "test_agent_789", "key": "value"} + + mem0_storage.save(test_value, test_metadata) + + call_args = mem0_storage.memory.add.call_args[1] + assert "user_id" in call_args + assert call_args["user_id"] == "test_user" + assert "agent_id" in call_args + assert call_args["agent_id"] == "test_agent_789" + + +def test_save_without_agent_id_in_metadata(mem0_storage_with_memory_client_using_config_from_crew): + """Test save method when no agent_id is in metadata""" + mem0_storage = mem0_storage_with_memory_client_using_config_from_crew + mem0_storage.memory.add = MagicMock() + + test_value = "This is a test memory" + test_metadata = {"key": "value"} + + mem0_storage.save(test_value, test_metadata) + + call_args = mem0_storage.memory.add.call_args[1] + assert "user_id" in call_args + assert call_args["user_id"] == "test_user" + assert "agent_id" not in call_args