From 81bb54550ccfd5545c96f84d63fc8ff7feaa0960 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 13 Mar 2025 16:16:20 +0000 Subject: [PATCH] Fix #2364: Allow UserMemory to work with custom providers Co-Authored-By: Joe Moura --- .../memory/contextual/contextual_memory.py | 9 ++- src/crewai/memory/user/user_memory.py | 45 ++++++++++++--- tests/memory/user_memory_test.py | 57 +++++++++++++++++++ 3 files changed, 99 insertions(+), 12 deletions(-) create mode 100644 tests/memory/user_memory_test.py diff --git a/src/crewai/memory/contextual/contextual_memory.py b/src/crewai/memory/contextual/contextual_memory.py index cdb9cf836..3ecd59ee1 100644 --- a/src/crewai/memory/contextual/contextual_memory.py +++ b/src/crewai/memory/contextual/contextual_memory.py @@ -35,8 +35,7 @@ class ContextualMemory: context.append(self._fetch_ltm_context(task.description)) context.append(self._fetch_stm_context(query)) context.append(self._fetch_entity_context(query)) - if self.memory_provider == "mem0": - context.append(self._fetch_user_context(query)) + context.append(self._fetch_user_context(query)) return "\n".join(filter(None, context)) def _fetch_stm_context(self, query) -> str: @@ -98,7 +97,11 @@ class ContextualMemory: if not user_memories: return "" + # Check if the memory provider is mem0 by looking at the storage type + is_mem0 = hasattr(self.um.storage, "__class__") and self.um.storage.__class__.__name__ == "Mem0Storage" + formatted_memories = "\n".join( - f"- {result['memory']}" for result in user_memories + f"- {result['memory'] if is_mem0 else result['context']}" + for result in user_memories ) return f"User memories/preferences:\n{formatted_memories}" diff --git a/src/crewai/memory/user/user_memory.py b/src/crewai/memory/user/user_memory.py index 24e5fe035..163be2097 100644 --- a/src/crewai/memory/user/user_memory.py +++ b/src/crewai/memory/user/user_memory.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Optional from crewai.memory.memory import Memory +from crewai.memory.storage.rag_storage import RAGStorage class UserMemory(Memory): @@ -11,14 +12,34 @@ class UserMemory(Memory): MemoryItem instances. """ - def __init__(self, crew=None): - try: - from crewai.memory.storage.mem0_storage import Mem0Storage - except ImportError: - raise ImportError( - "Mem0 is not installed. Please install it with `pip install mem0ai`." + def __init__(self, crew=None, embedder_config=None, storage=None, path=None, memory_config=None): + # Get memory provider from crew or directly from memory_config + memory_provider = None + if hasattr(crew, "memory_config") and crew.memory_config is not None: + memory_provider = crew.memory_config.get("provider") + elif memory_config is not None: + memory_provider = memory_config.get("provider") + + if memory_provider == "mem0": + try: + from crewai.memory.storage.mem0_storage import Mem0Storage + except ImportError: + raise ImportError( + "Mem0 is not installed. Please install it with `pip install mem0ai`." + ) + storage = Mem0Storage(type="user", crew=crew) + else: + storage = ( + storage + if storage + else RAGStorage( + type="user", + allow_reset=True, + embedder_config=embedder_config, + crew=crew, + path=path, + ) ) - storage = Mem0Storage(type="user", crew=crew) super().__init__(storage) def save( @@ -27,9 +48,15 @@ class UserMemory(Memory): metadata: Optional[Dict[str, Any]] = None, agent: Optional[str] = None, ) -> None: - # TODO: Change this function since we want to take care of the case where we save memories for the usr - data = f"Remember the details about the user: {value}" + if self._is_mem0_storage(): + data = f"Remember the details about the user: {value}" + else: + data = value super().save(data, metadata) + + def _is_mem0_storage(self) -> bool: + """Check if the storage is Mem0Storage by checking its class name.""" + return hasattr(self.storage, "__class__") and self.storage.__class__.__name__ == "Mem0Storage" def search( self, diff --git a/tests/memory/user_memory_test.py b/tests/memory/user_memory_test.py new file mode 100644 index 000000000..1c36b6e7a --- /dev/null +++ b/tests/memory/user_memory_test.py @@ -0,0 +1,57 @@ +import pytest +from unittest.mock import patch + +from crewai.memory.user.user_memory import UserMemory +from crewai.memory.storage.rag_storage import RAGStorage + + +@patch('crewai.memory.storage.mem0_storage.Mem0Storage') +@patch('crewai.memory.storage.mem0_storage.MemoryClient') +def test_user_memory_provider_selection(mock_memory_client, mock_mem0_storage): + """Test that UserMemory selects the correct storage provider based on config.""" + # Setup - Mock Mem0Storage to avoid API key requirement + mock_mem0_storage.return_value = mock_mem0_storage + + # Test with mem0 provider + with patch('crewai.memory.user.user_memory.RAGStorage'): + # Create UserMemory with mem0 provider + memory_config = {"provider": "mem0"} + user_memory = UserMemory(memory_config=memory_config) + + # Verify Mem0Storage was used + mock_mem0_storage.assert_called_once() + + # Reset mocks + mock_mem0_storage.reset_mock() + + # Test with default provider (RAGStorage) + with patch('crewai.memory.user.user_memory.RAGStorage', return_value=mock_mem0_storage) as mock_rag: + # Create UserMemory with no provider specified + user_memory = UserMemory() + + # Verify RAGStorage was used + mock_rag.assert_called_once() + + +@patch('crewai.memory.user.user_memory.UserMemory._is_mem0_storage') +def test_user_memory_save_formatting(mock_is_mem0): + """Test that UserMemory formats data correctly based on provider.""" + # Test with mem0 provider + mock_is_mem0.return_value = True + with patch('crewai.memory.memory.Memory.save') as mock_save: + user_memory = UserMemory() + user_memory.save("test data") + + # Verify data was formatted for mem0 + args, _ = mock_save.call_args + assert "Remember the details about the user: test data" in args[0] + + # Test with RAG provider + mock_is_mem0.return_value = False + with patch('crewai.memory.memory.Memory.save') as mock_save: + user_memory = UserMemory() + user_memory.save("test data") + + # Verify data was not formatted + args, _ = mock_save.call_args + assert args[0] == "test data"