Compare commits

...

5 Commits

Author SHA1 Message Date
Devin AI
c6461d2d69 Fix import sorting in user_memory_test.py with Ruff auto-fix
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-13 16:26:46 +00:00
Devin AI
aa41cc0bdc Fix import sorting in user_memory_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-13 16:25:33 +00:00
Devin AI
f1da364e70 Improve UserMemory implementation based on code review feedback
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-13 16:24:24 +00:00
Devin AI
f7a9265e35 Fix import sorting in user_memory_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-13 16:17:51 +00:00
Devin AI
81bb54550c Fix #2364: Allow UserMemory to work with custom providers
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-13 16:16:20 +00:00
3 changed files with 148 additions and 20 deletions

View File

@@ -35,8 +35,7 @@ class ContextualMemory:
context.append(self._fetch_ltm_context(task.description))
context.append(self._fetch_stm_context(query))
context.append(self._fetch_entity_context(query))
if self.memory_provider == "mem0":
context.append(self._fetch_user_context(query))
context.append(self._fetch_user_context(query))
return "\n".join(filter(None, context))
def _fetch_stm_context(self, query) -> str:
@@ -97,8 +96,9 @@ class ContextualMemory:
user_memories = self.um.search(query)
if not user_memories:
return ""
formatted_memories = "\n".join(
f"- {result['memory']}" for result in user_memories
f"- {result['memory'] if self.um._memory_provider == 'mem0' else result['context']}"
for result in user_memories
)
return f"User memories/preferences:\n{formatted_memories}"

View File

@@ -1,6 +1,9 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from pydantic import PrivateAttr
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
class UserMemory(Memory):
@@ -11,35 +14,102 @@ class UserMemory(Memory):
MemoryItem instances.
"""
def __init__(self, crew=None):
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
_memory_provider: Optional[str] = PrivateAttr()
def __init__(
self,
crew=None,
embedder_config: Optional[Dict[str, Any]] = None,
storage: Optional[Any] = None,
path: Optional[str] = None,
memory_config: Optional[Dict[str, Any]] = None
):
"""
Initialize UserMemory with the specified storage provider.
Args:
crew: Optional crew object that may contain memory configuration
embedder_config: Optional configuration for the embedder
storage: Optional pre-configured storage instance
path: Optional path for storage
memory_config: Optional explicit memory configuration
"""
# Get memory provider from crew or directly from memory_config
memory_provider = None
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
memory_provider = crew.memory_config.get("provider")
elif memory_config is not None:
memory_provider = memory_config.get("provider")
if memory_provider == "mem0":
try:
from crewai.memory.storage.mem0_storage import Mem0Storage
except ImportError:
raise ImportError(
"Mem0 is not installed. Please install it with `pip install mem0ai`."
)
storage = Mem0Storage(type="user", crew=crew)
else:
storage = (
storage
if storage
else RAGStorage(
type="user",
allow_reset=True,
embedder_config=embedder_config,
crew=crew,
path=path,
)
)
storage = Mem0Storage(type="user", crew=crew)
super().__init__(storage)
super().__init__(storage=storage)
self._memory_provider = memory_provider
def save(
self,
value,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
) -> None:
# TODO: Change this function since we want to take care of the case where we save memories for the usr
data = f"Remember the details about the user: {value}"
super().save(data, metadata)
"""
Save user memory data with appropriate formatting based on the storage provider.
Args:
value: The data to save
metadata: Optional metadata to associate with the memory
agent: Optional agent name to associate with the memory
"""
if self._memory_provider == "mem0":
data = f"Remember the details about the user: {value}"
else:
data = value
super().save(data, metadata, agent)
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
):
results = self.storage.search(
) -> List[Any]:
"""
Search for user memories that match the query.
Args:
query: The search query
limit: Maximum number of results to return
score_threshold: Minimum similarity score for results
Returns:
List of matching memory items
"""
return self.storage.search(
query=query,
limit=limit,
score_threshold=score_threshold,
)
return results
def reset(self) -> None:
"""Reset the user memory storage."""
try:
self.storage.reset()
except Exception as e:
raise Exception(f"An error occurred while resetting the user memory: {e}")

View File

@@ -0,0 +1,58 @@
from unittest.mock import PropertyMock, patch
import pytest
from crewai.memory.storage.rag_storage import RAGStorage
from crewai.memory.user.user_memory import UserMemory
@patch('crewai.memory.storage.mem0_storage.Mem0Storage')
@patch('crewai.memory.storage.mem0_storage.MemoryClient')
def test_user_memory_provider_selection(mock_memory_client, mock_mem0_storage):
"""Test that UserMemory selects the correct storage provider based on config."""
# Setup - Mock Mem0Storage to avoid API key requirement
mock_mem0_storage.return_value = mock_mem0_storage
# Test with mem0 provider
with patch('crewai.memory.user.user_memory.RAGStorage'):
# Create UserMemory with mem0 provider
memory_config = {"provider": "mem0"}
user_memory = UserMemory(memory_config=memory_config)
# Verify Mem0Storage was used
mock_mem0_storage.assert_called_once()
# Reset mocks
mock_mem0_storage.reset_mock()
# Test with default provider (RAGStorage)
with patch('crewai.memory.user.user_memory.RAGStorage', return_value=mock_mem0_storage) as mock_rag:
# Create UserMemory with no provider specified
user_memory = UserMemory()
# Verify RAGStorage was used
mock_rag.assert_called_once()
@patch('crewai.memory.user.user_memory.UserMemory._memory_provider', new_callable=PropertyMock)
def test_user_memory_save_formatting(mock_memory_provider):
"""Test that UserMemory formats data correctly based on provider."""
# Test with mem0 provider
mock_memory_provider.return_value = "mem0"
with patch('crewai.memory.memory.Memory.save') as mock_save:
user_memory = UserMemory()
user_memory.save("test data")
# Verify data was formatted for mem0
args, _ = mock_save.call_args
assert "Remember the details about the user: test data" in args[0]
# Test with RAG provider
mock_memory_provider.return_value = None
with patch('crewai.memory.memory.Memory.save') as mock_save:
user_memory = UserMemory()
user_memory.save("test data")
# Verify data was not formatted
args, _ = mock_save.call_args
assert args[0] == "test data"