Fix Memory OSS compatibility and implement code review suggestions

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-05-07 21:53:07 +00:00
parent 85ff023b76
commit c53cd2775a
2 changed files with 258 additions and 33 deletions

View File

@@ -12,8 +12,26 @@ class Mem0Storage(Storage):
Supports Mem0 v2 API with run_id for associating memories with specific conversation Supports Mem0 v2 API with run_id for associating memories with specific conversation
sessions. By default, uses v2 API which is recommended for better context management. sessions. By default, uses v2 API which is recommended for better context management.
Args:
type: The type of memory storage ("user", "short_term", "long_term", "entities", "external")
crew: The crew instance this storage is associated with
config: Optional configuration dictionary that overrides crew.memory_config
Configuration options:
version: API version to use ("v1.1" or "v2", defaults to "v2")
run_id: Optional session identifier for associating memories with specific conversations
api_key: Mem0 API key (defaults to MEM0_API_KEY environment variable)
user_id: User identifier (required for "user" memory type)
org_id: Optional organization ID for Mem0 API
project_id: Optional project ID for Mem0 API
local_mem0_config: Optional configuration for local Mem0 instance
""" """
SUPPORTED_VERSIONS = ["v1.1", "v2"]
DEFAULT_VERSION = "v2"
def __init__(self, type, crew=None, config=None): def __init__(self, type, crew=None, config=None):
super().__init__() super().__init__()
supported_types = ["user", "short_term", "long_term", "entities", "external"] supported_types = ["user", "short_term", "long_term", "entities", "external"]
@@ -30,9 +48,10 @@ class Mem0Storage(Storage):
self.memory_config = self.config or getattr(crew, "memory_config", {}) or {} self.memory_config = self.config or getattr(crew, "memory_config", {}) or {}
config = self._get_config() config = self._get_config()
self.version = config.get("version", "v2") self.version = config.get("version", self.DEFAULT_VERSION)
self.run_id = config.get("run_id") self.run_id = config.get("run_id")
self._validate_config()
# User ID is required for user memory type "user" since it's used as a unique identifier for the user. # User ID is required for user memory type "user" since it's used as a unique identifier for the user.
user_id = self._get_user_id() user_id = self._get_user_id()
@@ -60,16 +79,70 @@ class Mem0Storage(Storage):
else: else:
self.memory = Memory() self.memory = Memory()
def _validate_config(self) -> None:
"""
Validate configuration parameters.
Raises:
ValueError: If the version is not supported
"""
if self.version not in self.SUPPORTED_VERSIONS:
raise ValueError(
f"Unsupported version: {self.version}. "
f"Please use one of: {', '.join(self.SUPPORTED_VERSIONS)}"
)
if self.run_id is not None and not isinstance(self.run_id, str):
raise ValueError("run_id must be a string")
def _build_params(self, base_params: Dict[str, Any], method: str = "add") -> Dict[str, Any]:
    """
    Assemble the keyword arguments for a memory backend call.

    Works on a shallow copy, so the caller's ``base_params`` dictionary is
    not modified.

    Args:
        base_params: Parameters common to the call being made
        method: Name of the call being prepared ("add" or "search")

    Returns:
        Dict[str, Any]: The parameters to pass to the memory backend
    """
    call_params = base_params.copy()

    # MemoryClient calls carry the API version, plus run_id when one is set.
    if isinstance(self.memory, MemoryClient):
        call_params["version"] = self.version
        if self.run_id:
            call_params["run_id"] = self.run_id
        return call_params

    # For the Memory (OSS) backend, searches are issued without the metadata filter.
    if method == "search" and isinstance(self.memory, Memory):
        call_params.pop("metadata", None)

    return call_params
def _sanitize_role(self, role: str) -> str: def _sanitize_role(self, role: str) -> str:
""" """
Sanitizes agent roles to ensure valid directory names. Sanitizes agent roles to ensure valid directory names.
Args:
role: The role name to sanitize
Returns:
str: Sanitized role name
""" """
return role.replace("\n", "").replace(" ", "_").replace("/", "_") return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def save(self, value: Any, metadata: Dict[str, Any]) -> None: def save(self, value: Any, metadata: Dict[str, Any]) -> None:
"""
Save a memory item.
Args:
value: The memory content to save
metadata: Additional metadata for the memory
"""
user_id = self._get_user_id() user_id = self._get_user_id()
agent_name = self._get_agent_name() agent_name = self._get_agent_name()
params = None params = None
if self.memory_type == "short_term": if self.memory_type == "short_term":
params = { params = {
"agent_id": agent_name, "agent_id": agent_name,
@@ -96,12 +169,7 @@ class Mem0Storage(Storage):
} }
if params: if params:
if isinstance(self.memory, MemoryClient): params = self._build_params(params, method="add")
params["version"] = self.version
if self.run_id:
params["run_id"] = self.run_id
self.memory.add(value, **params) self.memory.add(value, **params)
def search( def search(
@@ -110,42 +178,61 @@ class Mem0Storage(Storage):
limit: int = 3, limit: int = 3,
score_threshold: float = 0.35, score_threshold: float = 0.35,
) -> List[Any]: ) -> List[Any]:
params = {"query": query, "limit": limit} """
Search for memories.
Args:
query: The search query
limit: Maximum number of results to return
score_threshold: Minimum score for results to be included
Returns:
List[Any]: List of memory items that match the query
"""
base_params = {"query": query, "limit": limit}
if user_id := self._get_user_id(): if user_id := self._get_user_id():
params["user_id"] = user_id base_params["user_id"] = user_id
agent_name = self._get_agent_name() agent_name = self._get_agent_name()
if self.memory_type == "short_term": if self.memory_type == "short_term":
params["agent_id"] = agent_name base_params["agent_id"] = agent_name
params["metadata"] = {"type": "short_term"} base_params["metadata"] = {"type": "short_term"}
elif self.memory_type == "long_term": elif self.memory_type == "long_term":
params["agent_id"] = agent_name base_params["agent_id"] = agent_name
params["metadata"] = {"type": "long_term"} base_params["metadata"] = {"type": "long_term"}
elif self.memory_type == "entities": elif self.memory_type == "entities":
params["agent_id"] = agent_name base_params["agent_id"] = agent_name
params["metadata"] = {"type": "entity"} base_params["metadata"] = {"type": "entity"}
elif self.memory_type == "external": elif self.memory_type == "external":
params["agent_id"] = agent_name base_params["agent_id"] = agent_name
params["metadata"] = {"type": "external"} base_params["metadata"] = {"type": "external"}
# Add version and run_id for MemoryClient params = self._build_params(base_params, method="search")
if isinstance(self.memory, MemoryClient):
params["version"] = self.version
if self.run_id:
params["run_id"] = self.run_id
# Discard the filters for Memory (OSS version)
elif isinstance(self.memory, Memory):
if "metadata" in params:
del params["metadata"]
results = self.memory.search(**params) results = self.memory.search(**params)
return [r for r in results["results"] if r["score"] >= score_threshold]
if isinstance(results, dict) and "results" in results:
return [r for r in results["results"] if r["score"] >= score_threshold]
elif isinstance(results, list):
return [r for r in results if r["score"] >= score_threshold]
else:
return []
def _get_user_id(self) -> str: def _get_user_id(self) -> str:
"""
Get the user ID from configuration.
Returns:
str: User ID or empty string if not found
"""
return self._get_config().get("user_id", "") return self._get_config().get("user_id", "")
def _get_agent_name(self) -> str: def _get_agent_name(self) -> str:
"""
Get the agent name from the crew.
Returns:
str: Agent name or empty string if not found
"""
if not self.crew: if not self.crew:
return "" return ""
@@ -155,8 +242,17 @@ class Mem0Storage(Storage):
return agents return agents
def _get_config(self) -> Dict[str, Any]: def _get_config(self) -> Dict[str, Any]:
"""
Get the configuration from either config or memory_config.
Returns:
Dict[str, Any]: Configuration dictionary
"""
return self.config or getattr(self, "memory_config", {}).get("config", {}) or {} return self.config or getattr(self, "memory_config", {}).get("config", {}) or {}
def reset(self) -> None:
    """
    Reset the underlying memory backend, if one is set.
    """
    if not self.memory:
        return
    self.memory.reset()

View File

@@ -65,24 +65,54 @@ def mem0_storage_with_run_id(mock_mem0_memory_client):
return mem0_storage, mock_mem0_memory_client return mem0_storage, mock_mem0_memory_client
@pytest.fixture
def mem0_storage_with_v1_api(mock_mem0_memory_client):
    """Fixture to create a Mem0Storage instance with v1.1 API configuration"""
    v1_settings = {
        "user_id": "test_user",
        "api_key": "ABCDEFGH",
        "version": "v1.1",  # Explicitly set to v1.1
    }
    memory_config = {"provider": "mem0", "config": v1_settings}

    # MemoryClient construction is patched so the storage receives the mock client.
    with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
        crew = MockCrew(memory_config=memory_config)
        storage = Mem0Storage(type="short_term", crew=crew)
        return storage, mock_mem0_memory_client
@pytest.mark.v2_api
def test_mem0_storage_v2_initialization(mem0_storage_with_v2_api):
    """Test that Mem0Storage initializes correctly with v2 API configuration"""
    storage, _client = mem0_storage_with_v2_api

    assert storage.version == "v2"
    # The v2 fixture is expected to leave run_id unset.
    assert storage.run_id is None
@pytest.mark.v2_api
def test_mem0_storage_with_run_id_initialization(mem0_storage_with_run_id):
    """Test that Mem0Storage initializes correctly with run_id configuration"""
    storage, _client = mem0_storage_with_run_id

    assert storage.version == "v2"
    assert storage.run_id == "test-session-123"
@pytest.mark.v1_api
def test_mem0_storage_v1_initialization(mem0_storage_with_v1_api):
    """Test that Mem0Storage initializes correctly with v1.1 API configuration"""
    storage, _client = mem0_storage_with_v1_api

    assert storage.version == "v1.1"
    # The v1.1 fixture does not configure a run_id.
    assert storage.run_id is None
@pytest.mark.v2_api
def test_save_method_with_v2_api(mem0_storage_with_v2_api): def test_save_method_with_v2_api(mem0_storage_with_v2_api):
"""Test save method with v2 API""" """Test save method with v2 API"""
mem0_storage, mock_memory_client = mem0_storage_with_v2_api mem0_storage, mock_memory_client = mem0_storage_with_v2_api
@@ -102,6 +132,7 @@ def test_save_method_with_v2_api(mem0_storage_with_v2_api):
assert call_args["metadata"] == {"type": "short_term", "key": "value"} assert call_args["metadata"] == {"type": "short_term", "key": "value"}
@pytest.mark.v2_api
def test_save_method_with_run_id(mem0_storage_with_run_id): def test_save_method_with_run_id(mem0_storage_with_run_id):
"""Test save method with run_id""" """Test save method with run_id"""
mem0_storage, mock_memory_client = mem0_storage_with_run_id mem0_storage, mock_memory_client = mem0_storage_with_run_id
@@ -121,6 +152,7 @@ def test_save_method_with_run_id(mem0_storage_with_run_id):
assert call_args["metadata"] == {"type": "short_term", "key": "value"} assert call_args["metadata"] == {"type": "short_term", "key": "value"}
@pytest.mark.v2_api
def test_search_method_with_v2_api(mem0_storage_with_v2_api): def test_search_method_with_v2_api(mem0_storage_with_v2_api):
"""Test search method with v2 API""" """Test search method with v2 API"""
mem0_storage, mock_memory_client = mem0_storage_with_v2_api mem0_storage, mock_memory_client = mem0_storage_with_v2_api
@@ -141,6 +173,7 @@ def test_search_method_with_v2_api(mem0_storage_with_v2_api):
assert results[0]["content"] == "Result 1" assert results[0]["content"] == "Result 1"
@pytest.mark.v2_api
def test_search_method_with_run_id(mem0_storage_with_run_id): def test_search_method_with_run_id(mem0_storage_with_run_id):
"""Test search method with run_id""" """Test search method with run_id"""
mem0_storage, mock_memory_client = mem0_storage_with_run_id mem0_storage, mock_memory_client = mem0_storage_with_run_id
@@ -159,3 +192,99 @@ def test_search_method_with_run_id(mem0_storage_with_run_id):
assert len(results) == 1 assert len(results) == 1
assert results[0]["content"] == "Result 1" assert results[0]["content"] == "Result 1"
@pytest.mark.v2_api
def test_search_method_with_different_result_formats(mem0_storage_with_v2_api):
    """Test search method with different result formats"""
    mem0_storage, mock_memory_client = mem0_storage_with_v2_api

    # Each case pairs a backend return value with the content of the single
    # item expected to clear the 0.5 threshold (None means no results at all).
    cases = [
        (
            {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]},
            "Result 1",
        ),
        (
            [{"score": 0.9, "content": "Result 3"}, {"score": 0.4, "content": "Result 4"}],
            "Result 3",
        ),
        ("unexpected format", None),
    ]

    for backend_response, expected_content in cases:
        mock_memory_client.search = MagicMock(return_value=backend_response)

        results = mem0_storage.search("test query", limit=5, score_threshold=0.5)

        if expected_content is None:
            assert len(results) == 0
        else:
            assert len(results) == 1
            assert results[0]["content"] == expected_content
@pytest.mark.parametrize("run_id", [None, "", "test-123", "a" * 256])
@pytest.mark.v2_api
def test_run_id_edge_cases(mock_mem0_memory_client, run_id):
    """Test edge cases for run_id parameter"""
    memory_config = {
        "provider": "mem0",
        "config": {
            "user_id": "test_user",
            "api_key": "ABCDEFGH",
            "version": "v2",
            "run_id": run_id,
        },
    }

    with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
        crew = MockCrew(memory_config=memory_config)
        mem0_storage = Mem0Storage(type="short_term", crew=crew)

        # Whatever was configured (including "" and None) is stored as-is.
        assert mem0_storage.run_id == run_id

        # With no run_id configured there is nothing further to check on save.
        if run_id is None:
            return

        mock_mem0_memory_client.add = MagicMock()
        mem0_storage.save("test", {})
        sent_kwargs = mock_mem0_memory_client.add.call_args[1]

        if run_id == "":
            # An empty run_id is not forwarded to the client.
            assert "run_id" not in sent_kwargs
        else:
            assert sent_kwargs.get("run_id") == run_id
def test_invalid_version_handling(mock_mem0_memory_client):
    """Test handling of invalid version"""
    memory_config = {
        "provider": "mem0",
        "config": {
            "user_id": "test_user",
            "api_key": "ABCDEFGH",
            "version": "invalid",
        },
    }

    with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
        crew = MockCrew(memory_config=memory_config)

        # Construction itself is expected to reject the unknown version.
        with pytest.raises(ValueError, match="Unsupported version"):
            Mem0Storage(type="short_term", crew=crew)
def test_invalid_run_id_type(mock_mem0_memory_client):
    """Test handling of invalid run_id type"""
    memory_config = {
        "provider": "mem0",
        "config": {
            "user_id": "test_user",
            "api_key": "ABCDEFGH",
            "version": "v2",
            "run_id": 123,  # Not a string
        },
    }

    with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
        crew = MockCrew(memory_config=memory_config)

        # Construction itself is expected to reject the non-string run_id.
        with pytest.raises(ValueError, match="run_id must be a string"):
            Mem0Storage(type="short_term", crew=crew)