mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
Improve UserMemory implementation based on code review feedback
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -96,12 +96,9 @@ class ContextualMemory:
|
||||
user_memories = self.um.search(query)
|
||||
if not user_memories:
|
||||
return ""
|
||||
|
||||
# Check if the memory provider is mem0 by looking at the storage type
|
||||
is_mem0 = hasattr(self.um.storage, "__class__") and self.um.storage.__class__.__name__ == "Mem0Storage"
|
||||
|
||||
formatted_memories = "\n".join(
|
||||
f"- {result['memory'] if is_mem0 else result['context']}"
|
||||
f"- {result['memory'] if self.um._memory_provider == 'mem0' else result['context']}"
|
||||
for result in user_memories
|
||||
)
|
||||
return f"User memories/preferences:\n{formatted_memories}"
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
@@ -12,10 +14,29 @@ class UserMemory(Memory):
|
||||
MemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None, memory_config=None):
|
||||
_memory_provider: Optional[str] = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
crew=None,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
storage: Optional[Any] = None,
|
||||
path: Optional[str] = None,
|
||||
memory_config: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Initialize UserMemory with the specified storage provider.
|
||||
|
||||
Args:
|
||||
crew: Optional crew object that may contain memory configuration
|
||||
embedder_config: Optional configuration for the embedder
|
||||
storage: Optional pre-configured storage instance
|
||||
path: Optional path for storage
|
||||
memory_config: Optional explicit memory configuration
|
||||
"""
|
||||
# Get memory provider from crew or directly from memory_config
|
||||
memory_provider = None
|
||||
if hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
memory_provider = crew.memory_config.get("provider")
|
||||
elif memory_config is not None:
|
||||
memory_provider = memory_config.get("provider")
|
||||
@@ -40,33 +61,55 @@ class UserMemory(Memory):
|
||||
path=path,
|
||||
)
|
||||
)
|
||||
super().__init__(storage)
|
||||
super().__init__(storage=storage)
|
||||
self._memory_provider = memory_provider
|
||||
|
||||
def save(
|
||||
self,
|
||||
value,
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
) -> None:
|
||||
if self._is_mem0_storage():
|
||||
"""
|
||||
Save user memory data with appropriate formatting based on the storage provider.
|
||||
|
||||
Args:
|
||||
value: The data to save
|
||||
metadata: Optional metadata to associate with the memory
|
||||
agent: Optional agent name to associate with the memory
|
||||
"""
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"Remember the details about the user: {value}"
|
||||
else:
|
||||
data = value
|
||||
super().save(data, metadata)
|
||||
|
||||
def _is_mem0_storage(self) -> bool:
|
||||
"""Check if the storage is Mem0Storage by checking its class name."""
|
||||
return hasattr(self.storage, "__class__") and self.storage.__class__.__name__ == "Mem0Storage"
|
||||
super().save(data, metadata, agent)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
):
|
||||
results = self.storage.search(
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Search for user memories that match the query.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
limit: Maximum number of results to return
|
||||
score_threshold: Minimum similarity score for results
|
||||
|
||||
Returns:
|
||||
List of matching memory items
|
||||
"""
|
||||
return self.storage.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
return results
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the user memory storage."""
|
||||
try:
|
||||
self.storage.reset()
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while resetting the user memory: {e}")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -34,11 +34,11 @@ def test_user_memory_provider_selection(mock_memory_client, mock_mem0_storage):
|
||||
mock_rag.assert_called_once()
|
||||
|
||||
|
||||
@patch('crewai.memory.user.user_memory.UserMemory._is_mem0_storage')
|
||||
def test_user_memory_save_formatting(mock_is_mem0):
|
||||
@patch('crewai.memory.user.user_memory.UserMemory._memory_provider', new_callable=PropertyMock)
|
||||
def test_user_memory_save_formatting(mock_memory_provider):
|
||||
"""Test that UserMemory formats data correctly based on provider."""
|
||||
# Test with mem0 provider
|
||||
mock_is_mem0.return_value = True
|
||||
mock_memory_provider.return_value = "mem0"
|
||||
with patch('crewai.memory.memory.Memory.save') as mock_save:
|
||||
user_memory = UserMemory()
|
||||
user_memory.save("test data")
|
||||
@@ -48,7 +48,7 @@ def test_user_memory_save_formatting(mock_is_mem0):
|
||||
assert "Remember the details about the user: test data" in args[0]
|
||||
|
||||
# Test with RAG provider
|
||||
mock_is_mem0.return_value = False
|
||||
mock_memory_provider.return_value = None
|
||||
with patch('crewai.memory.memory.Memory.save') as mock_save:
|
||||
user_memory = UserMemory()
|
||||
user_memory.save("test data")
|
||||
|
||||
Reference in New Issue
Block a user