fix: support to add memories to Mem0 with agent_id (#3217)

* fix: support to add memories to Mem0 with agent_id

* feat: removing memory_type checkings from Mem0Storage

* feat: ensure agent_id is always present while saving memory into Mem0

* fix: use OR operator when querying Mem0 memories with both user_id and agent_id
This commit is contained in:
Lucas Gomide
2025-07-30 12:56:46 -03:00
committed by GitHub
parent 498e8dc6e8
commit 34c3075fdb
2 changed files with 145 additions and 44 deletions

View File

@@ -1,7 +1,8 @@
import os import os
from typing import Any, Dict, List from typing import Any, Dict, List
from collections import defaultdict
from mem0 import Memory, MemoryClient from mem0 import Memory, MemoryClient
from crewai.utilities.chromadb import sanitize_collection_name
from crewai.memory.storage.interface import Storage from crewai.memory.storage.interface import Storage
@@ -70,26 +71,32 @@ class Mem0Storage(Storage):
""" """
Returns: Returns:
dict: A filter dictionary containing AND conditions for querying data. dict: A filter dictionary containing AND conditions for querying data.
- Includes user_id if memory_type is 'external'. - Includes user_id and agent_id if both are present.
- Includes user_id if only user_id is present.
- Includes agent_id if only agent_id is present.
- Includes run_id if memory_type is 'short_term' and mem0_run_id is present. - Includes run_id if memory_type is 'short_term' and mem0_run_id is present.
""" """
filter = { filter = defaultdict(list)
"AND": []
}
# Add user_id condition if the memory type is external
if self.memory_type == "external":
filter["AND"].append({"user_id": self.config.get("user_id", "")})
# Add run_id condition if the memory type is short_term and a run ID is set
if self.memory_type == "short_term" and self.mem0_run_id: if self.memory_type == "short_term" and self.mem0_run_id:
filter["AND"].append({"run_id": self.mem0_run_id}) filter["AND"].append({"run_id": self.mem0_run_id})
else:
user_id = self.config.get("user_id", "")
agent_id = self.config.get("agent_id", "")
if user_id and agent_id:
filter["OR"].append({"user_id": user_id})
filter["OR"].append({"agent_id": agent_id})
elif user_id:
filter["AND"].append({"user_id": user_id})
elif agent_id:
filter["AND"].append({"agent_id": agent_id})
return filter return filter
def save(self, value: Any, metadata: Dict[str, Any]) -> None: def save(self, value: Any, metadata: Dict[str, Any]) -> None:
user_id = self.config.get("user_id", "") user_id = self.config.get("user_id", "")
assistant_message = [{"role" : "assistant","content" : value}] assistant_message = [{"role" : "assistant","content" : value}]
base_metadata = { base_metadata = {
"short_term": "short_term", "short_term": "short_term",
@@ -104,31 +111,32 @@ class Mem0Storage(Storage):
"infer": self.infer "infer": self.infer
} }
if self.memory_type == "external": # MemoryClient-specific overrides
if isinstance(self.memory, MemoryClient):
params["includes"] = self.includes
params["excludes"] = self.excludes
params["output_format"] = "v1.1"
params["version"] = "v2"
if self.memory_type == "short_term" and self.mem0_run_id:
params["run_id"] = self.mem0_run_id
if user_id:
params["user_id"] = user_id params["user_id"] = user_id
if agent_id := self.config.get("agent_id", self._get_agent_name()):
if params: params["agent_id"] = agent_id
# MemoryClient-specific overrides
if isinstance(self.memory, MemoryClient):
params["includes"] = self.includes
params["excludes"] = self.excludes
params["output_format"] = "v1.1"
params["version"]="v2"
if self.memory_type == "short_term": self.memory.add(assistant_message, **params)
params["run_id"] = self.mem0_run_id
self.memory.add(assistant_message, **params)
def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]: def search(self,query: str,limit: int = 3,score_threshold: float = 0.35) -> List[Any]:
params = { params = {
"query": query, "query": query,
"limit": limit, "limit": limit,
"version": "v2", "version": "v2",
"output_format": "v1.1" "output_format": "v1.1"
} }
if user_id := self.config.get("user_id", ""): if user_id := self.config.get("user_id", ""):
params["user_id"] = user_id params["user_id"] = user_id
@@ -138,7 +146,7 @@ class Mem0Storage(Storage):
"entities": {"type": "entity"}, "entities": {"type": "entity"},
"external": {"type": "external"}, "external": {"type": "external"},
} }
if self.memory_type in memory_type_map: if self.memory_type in memory_type_map:
params["metadata"] = memory_type_map[self.memory_type] params["metadata"] = memory_type_map[self.memory_type]
if self.memory_type == "short_term": if self.memory_type == "short_term":
@@ -151,11 +159,28 @@ class Mem0Storage(Storage):
params['threshold'] = score_threshold params['threshold'] = score_threshold
if isinstance(self.memory, Memory): if isinstance(self.memory, Memory):
del params["metadata"], params["version"], params["run_id"], params['output_format'] del params["metadata"], params["version"], params['output_format']
if params.get("run_id"):
del params["run_id"]
results = self.memory.search(**params) results = self.memory.search(**params)
return [r for r in results["results"]] return [r for r in results["results"]]
def reset(self): def reset(self):
if self.memory: if self.memory:
self.memory.reset() self.memory.reset()
def _sanitize_role(self, role: str) -> str:
"""
Sanitizes agent roles to ensure valid directory names.
"""
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def _get_agent_name(self) -> str:
if not self.crew:
return ""
agents = self.crew.agents
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
return sanitize_collection_name(name=agents, max_collection_length=MAX_AGENT_ID_LENGTH_MEM0)

View File

@@ -191,17 +191,39 @@ def test_save_method_with_memory_oss(mem0_storage_with_mocked_config):
"""Test save method for different memory types""" """Test save method for different memory types"""
mem0_storage, _, _ = mem0_storage_with_mocked_config mem0_storage, _, _ = mem0_storage_with_mocked_config
mem0_storage.memory.add = MagicMock() mem0_storage.memory.add = MagicMock()
# Test short_term memory type (already set in fixture) # Test short_term memory type (already set in fixture)
test_value = "This is a test memory" test_value = "This is a test memory"
test_metadata = {"key": "value"} test_metadata = {"key": "value"}
mem0_storage.save(test_value, test_metadata) mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with( mem0_storage.memory.add.assert_called_once_with(
[{'role': 'assistant' , 'content': test_value}], [{"role": "assistant" , "content": test_value}],
infer=True, infer=True,
metadata={"type": "short_term", "key": "value"}, metadata={"type": "short_term", "key": "value"},
run_id="my_run_id",
user_id="test_user",
agent_id='Test_Agent'
)
def test_save_method_with_multiple_agents(mem0_storage_with_mocked_config):
mem0_storage, _, _ = mem0_storage_with_mocked_config
mem0_storage.crew.agents = [MagicMock(role="Test Agent"), MagicMock(role="Test Agent 2"), MagicMock(role="Test Agent 3")]
mem0_storage.memory.add = MagicMock()
test_value = "This is a test memory"
test_metadata = {"key": "value"}
mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with(
[{"role": "assistant" , "content": test_value}],
infer=True,
metadata={"type": "short_term", "key": "value"},
run_id="my_run_id",
user_id="test_user",
agent_id='Test_Agent_Test_Agent_2_Test_Agent_3'
) )
@@ -209,13 +231,13 @@ def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_co
"""Test save method for different memory types""" """Test save method for different memory types"""
mem0_storage = mem0_storage_with_memory_client_using_config_from_crew mem0_storage = mem0_storage_with_memory_client_using_config_from_crew
mem0_storage.memory.add = MagicMock() mem0_storage.memory.add = MagicMock()
# Test short_term memory type (already set in fixture) # Test short_term memory type (already set in fixture)
test_value = "This is a test memory" test_value = "This is a test memory"
test_metadata = {"key": "value"} test_metadata = {"key": "value"}
mem0_storage.save(test_value, test_metadata) mem0_storage.save(test_value, test_metadata)
mem0_storage.memory.add.assert_called_once_with( mem0_storage.memory.add.assert_called_once_with(
[{'role': 'assistant' , 'content': test_value}], [{'role': 'assistant' , 'content': test_value}],
infer=True, infer=True,
@@ -224,7 +246,9 @@ def test_save_method_with_memory_client(mem0_storage_with_memory_client_using_co
run_id="my_run_id", run_id="my_run_id",
includes="include1", includes="include1",
excludes="exclude1", excludes="exclude1",
output_format='v1.1' output_format='v1.1',
user_id='test_user',
agent_id='Test_Agent'
) )
@@ -237,10 +261,10 @@ def test_search_method_with_memory_oss(mem0_storage_with_mocked_config):
results = mem0_storage.search("test query", limit=5, score_threshold=0.5) results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with( mem0_storage.memory.search.assert_called_once_with(
query="test query", query="test query",
limit=5, limit=5,
user_id="test_user", user_id="test_user",
filters={'AND': [{'run_id': 'my_run_id'}]}, filters={'AND': [{'run_id': 'my_run_id'}]},
threshold=0.5 threshold=0.5
) )
@@ -257,8 +281,8 @@ def test_search_method_with_memory_client(mem0_storage_with_memory_client_using_
results = mem0_storage.search("test query", limit=5, score_threshold=0.5) results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with( mem0_storage.memory.search.assert_called_once_with(
query="test query", query="test query",
limit=5, limit=5,
metadata={"type": "short_term"}, metadata={"type": "short_term"},
user_id="test_user", user_id="test_user",
version='v2', version='v2',
@@ -286,4 +310,56 @@ def test_mem0_storage_default_infer_value(mock_mem0_memory_client):
) )
mem0_storage = Mem0Storage(type="short_term", crew=crew) mem0_storage = Mem0Storage(type="short_term", crew=crew)
assert mem0_storage.infer is True assert mem0_storage.infer is True
def test_save_memory_using_agent_entity(mock_mem0_memory_client):
config = {
"agent_id": "agent-123",
}
mock_memory = MagicMock(spec=Memory)
with patch.object(Memory, "__new__", return_value=mock_memory):
mem0_storage = Mem0Storage(type="external", config=config)
mem0_storage.save("test memory", {"key": "value"})
mem0_storage.memory.add.assert_called_once_with(
[{'role': 'assistant' , 'content': 'test memory'}],
infer=True,
metadata={"type": "external", "key": "value"},
agent_id="agent-123",
)
def test_search_method_with_agent_entity():
mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123"})
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
filters={"AND": [{"agent_id": "agent-123"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["content"] == "Result 1"
def test_search_method_with_agent_id_and_user_id():
mem0_storage = Mem0Storage(type="external", config={"agent_id": "agent-123", "user_id": "user-123"})
mock_results = {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]}
mem0_storage.memory.search = MagicMock(return_value=mock_results)
results = mem0_storage.search("test query", limit=5, score_threshold=0.5)
mem0_storage.memory.search.assert_called_once_with(
query="test query",
limit=5,
user_id='user-123',
filters={"OR": [{"user_id": "user-123"}, {"agent_id": "agent-123"}]},
threshold=0.5,
)
assert len(results) == 2
assert results[0]["content"] == "Result 1"