diff --git a/src/crewai/memory/contextual/contextual_memory.py b/src/crewai/memory/contextual/contextual_memory.py index 3ecd59ee1..4f395366e 100644 --- a/src/crewai/memory/contextual/contextual_memory.py +++ b/src/crewai/memory/contextual/contextual_memory.py @@ -96,12 +96,9 @@ class ContextualMemory: user_memories = self.um.search(query) if not user_memories: return "" - - # Check if the memory provider is mem0 by looking at the storage type - is_mem0 = hasattr(self.um.storage, "__class__") and self.um.storage.__class__.__name__ == "Mem0Storage" formatted_memories = "\n".join( - f"- {result['memory'] if is_mem0 else result['context']}" + f"- {result['memory'] if self.um._memory_provider == 'mem0' else result['context']}" for result in user_memories ) return f"User memories/preferences:\n{formatted_memories}" diff --git a/src/crewai/memory/user/user_memory.py b/src/crewai/memory/user/user_memory.py index 163be2097..325df51cc 100644 --- a/src/crewai/memory/user/user_memory.py +++ b/src/crewai/memory/user/user_memory.py @@ -1,4 +1,6 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional + +from pydantic import PrivateAttr from crewai.memory.memory import Memory from crewai.memory.storage.rag_storage import RAGStorage @@ -12,10 +14,29 @@ class UserMemory(Memory): MemoryItem instances. """ - def __init__(self, crew=None, embedder_config=None, storage=None, path=None, memory_config=None): + _memory_provider: Optional[str] = PrivateAttr() + + def __init__( + self, + crew=None, + embedder_config: Optional[Dict[str, Any]] = None, + storage: Optional[Any] = None, + path: Optional[str] = None, + memory_config: Optional[Dict[str, Any]] = None + ): + """ + Initialize UserMemory with the specified storage provider. + + Args: + crew: Optional crew object that may contain memory configuration + embedder_config: Optional configuration for the embedder + storage: Optional pre-configured storage instance + path: Optional path for storage + memory_config: Optional explicit memory configuration + """ # Get memory provider from crew or directly from memory_config memory_provider = None - if hasattr(crew, "memory_config") and crew.memory_config is not None: + if crew and hasattr(crew, "memory_config") and crew.memory_config is not None: memory_provider = crew.memory_config.get("provider") elif memory_config is not None: memory_provider = memory_config.get("provider") @@ -40,33 +61,55 @@ class UserMemory(Memory): path=path, ) ) - super().__init__(storage) + super().__init__(storage=storage) + self._memory_provider = memory_provider def save( self, - value, + value: Any, metadata: Optional[Dict[str, Any]] = None, agent: Optional[str] = None, ) -> None: - if self._is_mem0_storage(): + """ + Save user memory data with appropriate formatting based on the storage provider. + + Args: + value: The data to save + metadata: Optional metadata to associate with the memory + agent: Optional agent name to associate with the memory + """ + if self._memory_provider == "mem0": data = f"Remember the details about the user: {value}" else: data = value - super().save(data, metadata) - - def _is_mem0_storage(self) -> bool: - """Check if the storage is Mem0Storage by checking its class name.""" - return hasattr(self.storage, "__class__") and self.storage.__class__.__name__ == "Mem0Storage" + super().save(data, metadata, agent) def search( self, query: str, limit: int = 3, score_threshold: float = 0.35, - ): - results = self.storage.search( + ) -> List[Any]: + """ + Search for user memories that match the query. + + Args: + query: The search query + limit: Maximum number of results to return + score_threshold: Minimum similarity score for results + + Returns: + List of matching memory items + """ + return self.storage.search( query=query, limit=limit, score_threshold=score_threshold, ) - return results + + def reset(self) -> None: + """Reset the user memory storage.""" + try: + self.storage.reset() + except Exception as e: + raise Exception(f"An error occurred while resetting the user memory: {e}") diff --git a/tests/memory/user_memory_test.py b/tests/memory/user_memory_test.py index e1850209d..1df7f194c 100644 --- a/tests/memory/user_memory_test.py +++ b/tests/memory/user_memory_test.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import patch, PropertyMock import pytest @@ -34,11 +34,11 @@ def test_user_memory_provider_selection(mock_memory_client, mock_mem0_storage): mock_rag.assert_called_once() -@patch('crewai.memory.user.user_memory.UserMemory._is_mem0_storage') -def test_user_memory_save_formatting(mock_is_mem0): +@patch('crewai.memory.user.user_memory.UserMemory._memory_provider', new_callable=PropertyMock) +def test_user_memory_save_formatting(mock_memory_provider): """Test that UserMemory formats data correctly based on provider.""" # Test with mem0 provider - mock_is_mem0.return_value = True + mock_memory_provider.return_value = "mem0" with patch('crewai.memory.memory.Memory.save') as mock_save: user_memory = UserMemory() user_memory.save("test data") @@ -48,7 +48,7 @@ def test_user_memory_save_formatting(mock_is_mem0): assert "Remember the details about the user: test data" in args[0] # Test with RAG provider - mock_is_mem0.return_value = False + mock_memory_provider.return_value = None with patch('crewai.memory.memory.Memory.save') as mock_save: user_memory = UserMemory() user_memory.save("test data")