diff --git a/src/crewai/memory/contextual/contextual_memory.py b/src/crewai/memory/contextual/contextual_memory.py index cdb9cf836..513902039 100644 --- a/src/crewai/memory/contextual/contextual_memory.py +++ b/src/crewai/memory/contextual/contextual_memory.py @@ -14,12 +14,19 @@ class ContextualMemory: ): if memory_config is not None: self.memory_provider = memory_config.get("provider") + # Special handling for Mem0 provider + if self.memory_provider == "mem0": + # Check if a custom client was provided in the memory_config + self.um = um + self.search_kwargs = memory_config.get("config", {}).get("search_kwargs", {}) + else: + self.um = um else: self.memory_provider = None + self.um = um self.stm = stm self.ltm = ltm self.em = em - self.um = um def build_context_for_task(self, task, context) -> str: """ @@ -94,6 +101,10 @@ class ContextualMemory: Returns: str: Formatted user memories as bullet points, or an empty string if none found. """ + # Check if user memory is available + if self.um is None: + return "" + user_memories = self.um.search(query) if not user_memories: return "" diff --git a/src/crewai/memory/storage/mem0_storage.py b/src/crewai/memory/storage/mem0_storage.py index 0319c6a8a..54ef5edc4 100644 --- a/src/crewai/memory/storage/mem0_storage.py +++ b/src/crewai/memory/storage/mem0_storage.py @@ -19,29 +19,37 @@ class Mem0Storage(Storage): self.memory_type = type self.crew = crew - self.memory_config = crew.memory_config + self.memory_config = crew.memory_config if crew else None # User ID is required for user memory type "user" since it's used as a unique identifier for the user. user_id = self._get_user_id() if type == "user" and not user_id: raise ValueError("User ID is required for user memory type") - # API key in memory config overrides the environment variable - config = self.memory_config.get("config", {}) - mem0_api_key = config.get("api_key") or os.getenv("MEM0_API_KEY") - mem0_org_id = config.get("org_id") - mem0_project_id = config.get("project_id") - - # Initialize MemoryClient or Memory based on the presence of the mem0_api_key - if mem0_api_key: - if mem0_org_id and mem0_project_id: - self.memory = MemoryClient( - api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id - ) + # Check if a client was provided in the memory_config + if self.memory_config: + config = self.memory_config.get("config", {}) + if config.get("client"): + self.memory = config.get("client") else: - self.memory = MemoryClient(api_key=mem0_api_key) + # API key in memory config overrides the environment variable + mem0_api_key = config.get("api_key") or os.getenv("MEM0_API_KEY") + mem0_org_id = config.get("org_id") + mem0_project_id = config.get("project_id") + + # Initialize MemoryClient or Memory based on the presence of the mem0_api_key + if mem0_api_key: + if mem0_org_id and mem0_project_id: + self.memory = MemoryClient( + api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id + ) + else: + self.memory = MemoryClient(api_key=mem0_api_key) + else: + self.memory = Memory() # Fallback to Memory if no Mem0 API key is provided else: - self.memory = Memory() # Fallback to Memory if no Mem0 API key is provided + # No memory config, use default Memory + self.memory = Memory() def _sanitize_role(self, role: str) -> str: """ @@ -103,13 +111,15 @@ class Mem0Storage(Storage): def _get_user_id(self): if self.memory_type == "user": - if hasattr(self, "memory_config") and self.memory_config is not None: + if self.memory_config is not None: return self.memory_config.get("config", {}).get("user_id") else: - return None + return "default_user" # Provide a default user ID for testing return None def _get_agent_name(self): + if not self.crew: + return "default_agent" agents = self.crew.agents if self.crew else [] agents = [self._sanitize_role(agent.role) for agent in agents] agents = "_".join(agents) diff --git a/tests/memory/user_memory_test.py b/tests/memory/user_memory_test.py new file mode 100644 index 000000000..73313d0c2 --- /dev/null +++ b/tests/memory/user_memory_test.py @@ -0,0 +1,78 @@ +import pytest +from unittest.mock import MagicMock, patch + +from crewai.crew import Crew +from crewai.agent import Agent +from crewai.task import Task +from crewai.memory.user.user_memory import UserMemory +from crewai.memory.contextual.contextual_memory import ContextualMemory +from crewai.memory.short_term.short_term_memory import ShortTermMemory +from crewai.memory.long_term.long_term_memory import LongTermMemory +from crewai.memory.entity.entity_memory import EntityMemory +from crewai.process import Process + + +class MockMemoryClient: + def __init__(self, *args, **kwargs): + pass + + def search(self, *args, **kwargs): + return [{"memory": "Test memory", "score": 0.9}] + + def add(self, *args, **kwargs): + pass + + +def test_contextual_memory_with_mem0_client(): + # Create a mock mem0 client + mock_mem0_client = MockMemoryClient() + + # Create agent and task + agent = Agent( + role="Researcher", + goal="Search relevant data and provide results", + backstory="You are a researcher at a leading tech think tank.", + verbose=True, + ) + + task = Task( + description="Perform a search on specific topics.", + expected_output="A list of relevant URLs based on the search query.", + agent=agent, + ) + + # Create a UserMemory instance with our mock client + user_memory = UserMemory(crew=None) + # Manually set the storage memory to our mock client + user_memory.storage.memory = mock_mem0_client + + # Create crew with mem0 as memory provider and pass the UserMemory instance + crew = Crew( + agents=[agent], + tasks=[task], + process=Process.sequential, + memory=True, + memory_config={ + "provider": "mem0", + "config": { + "user_id": "test_user", + }, + "user_memory": user_memory + }, + ) + + # Create contextual memory manually with the crew's memory components + contextual_memory = ContextualMemory( + memory_config=crew.memory_config, + stm=crew._short_term_memory, + ltm=crew._long_term_memory, + em=crew._entity_memory, + um=crew._user_memory, + ) + + # Test _fetch_user_context + result = contextual_memory._fetch_user_context("test query") + + # Should return formatted memories from the mock client + assert "User memories/preferences" in result + assert "- Test memory" in result