fix: support to add memories to Mem0 with agent_id (#3217)

* fix: support to add memories to Mem0 with agent_id

* feat: removing memory_type checkings from Mem0Storage

* feat: ensure agent_id is always present while saving memory into Mem0

* fix: use OR operator when querying Mem0 memories with both user_id and agent_id
This commit is contained in:
Lucas Gomide
2025-07-30 12:56:46 -03:00
committed by GitHub
parent 498e8dc6e8
commit 34c3075fdb
2 changed files with 145 additions and 44 deletions

View File

@@ -1,7 +1,8 @@
import os
from typing import Any, Dict, List
from collections import defaultdict
from mem0 import Memory, MemoryClient
from crewai.utilities.chromadb import sanitize_collection_name
from crewai.memory.storage.interface import Storage
@@ -70,20 +71,26 @@ class Mem0Storage(Storage):
"""
Returns:
dict: A filter dictionary containing AND conditions for querying data.
- Includes user_id if memory_type is 'external'.
- Includes user_id and agent_id if both are present.
- Includes user_id if only user_id is present.
- Includes agent_id if only agent_id is present.
- Includes run_id if memory_type is 'short_term' and mem0_run_id is present.
"""
filter = {
"AND": []
}
filter = defaultdict(list)
# Add user_id condition if the memory type is external
if self.memory_type == "external":
filter["AND"].append({"user_id": self.config.get("user_id", "")})
# Add run_id condition if the memory type is short_term and a run ID is set
if self.memory_type == "short_term" and self.mem0_run_id:
filter["AND"].append({"run_id": self.mem0_run_id})
else:
user_id = self.config.get("user_id", "")
agent_id = self.config.get("agent_id", "")
if user_id and agent_id:
filter["OR"].append({"user_id": user_id})
filter["OR"].append({"agent_id": agent_id})
elif user_id:
filter["AND"].append({"user_id": user_id})
elif agent_id:
filter["AND"].append({"agent_id": agent_id})
return filter
@@ -104,21 +111,22 @@ class Mem0Storage(Storage):
"infer": self.infer
}
if self.memory_type == "external":
params["user_id"] = user_id
if params:
# MemoryClient-specific overrides
if isinstance(self.memory, MemoryClient):
params["includes"] = self.includes
params["excludes"] = self.excludes
params["output_format"] = "v1.1"
params["version"]="v2"
params["version"] = "v2"
if self.memory_type == "short_term":
if self.memory_type == "short_term" and self.mem0_run_id:
params["run_id"] = self.mem0_run_id
if user_id:
params["user_id"] = user_id
if agent_id := self.config.get("agent_id", self._get_agent_name()):
params["agent_id"] = agent_id
self.memory.add(assistant_message, **params)
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]:
@@ -151,7 +159,9 @@ class Mem0Storage(Storage):
params['threshold'] = score_threshold
if isinstance(self.memory, Memory):
del params["metadata"], params["version"], params["run_id"], params['output_format']
del params["metadata"], params["version"], params['output_format']
if params.get("run_id"):
del params["run_id"]
results = self.memory.search(**params)
return [r for r in results["results"]]
@@ -159,3 +169,18 @@ class Mem0Storage(Storage):
def reset(self):
if self.memory:
self.memory.reset()
def _sanitize_role(self, role: str) -> str:
"""
Sanitizes agent roles to ensure valid directory names.
"""
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def _get_agent_name(self) -> str:
if not self.crew:
return ""
agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)

View File

@@ -199,9 +199,31 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[{'role': 'assistant' , 'content': test_value}],
[{"role": "assistant" , "content": test_value}],
infer=True,
metadata={"type": "short_term", "key": "value"},
run_id="my_run_id",
user_id="test_user",
agent_id='Test_Agent'
)
def test_save_method_with_multiple_agents(mem0_storage_with_mocked_config):
mem0_storage, _, _ = mem0_storage_with_mocked_config
mem0_storage.crew.agents = [MagicMock(role="Test Agent"), MagicMock(role="Test Agent 2"), MagicMock(role="Test Agent 3")]
mem0_storage.memory.add = MagicMock()
test_value = "This is a test memory"
test_metadata = {"key": "value"}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[{"role": "assistant" , "content": test_value}],
infer=True,
metadata={"type": "short_term", "key": "value"},
run_id="my_run_id",
user_id="test_user",
agent_id='Test_Agent_Test_Agent_2_Test_Agent_3'
)
@@ -224,7 +246,9 @@ def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_co
run_id="my_run_id",
includes="include1",
excludes="exclude1",
output_format='v1.1'
output_format='v1.1',
user_id='test_user',
agent_id='Test_Agent'
)
@@ -287,3 +311,55 @@ def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
mem0_storage = Mem0Storage(type="short_term", crew=crew)
assert mem0_storage.infer is True
def test_save_memory_using_agent_entity(mock_mem0_memory_client):
config = {
"agent_id": "agent-123",
}
mock_memory = MagicMock(spec=Memory)
with patch.object(Memory, "__new__", return_value=mock_memory):
mem0_storage = Mem0Storage(type="external", config=config)
mem0_storage.save("test memory", {"key": "value"})
mem0_storage.memory.add.assert_called_once_with(
[{'role': 'assistant' , 'content': 'test memory'}],
infer=True,
metadata={"type": "external", "key": "value"},
agent_id="agent-123",
)
def test_search_method_with_agent_entity():
mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123"})
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
filters={"AND": [{"agent_id": "agent-123"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["content"] == "Result 1"
def test_search_method_with_agent_id_and_user_id():
mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123", "user_id": "user-123"})
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
user_id='user-123',
filters={"OR": [{"user_id": "user-123"}, {"agent_id": "agent-123"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["content"] == "Result 1"