Compare commits

...

3 Commits

Author SHA1 Message Date
Devin AI
b16551bd3d Fix import sorting in test file
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-23 21:32:56 +00:00
Devin AI
2682397c19 Fix import sorting in test file
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-23 21:31:45 +00:00
Devin AI
08b27bcbf7 Fix AttributeError in ContextualMemory when using Mem0 memory provider
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-23 21:30:06 +00:00
3 changed files with 118 additions and 18 deletions

View File

@@ -14,12 +14,19 @@ class ContextualMemory:
):
if memory_config is not None:
self.memory_provider = memory_config.get("provider")
# Special handling for Mem0 provider
if self.memory_provider == "mem0":
# Check if a custom client was provided in the memory_config
self.um = um
self.search_kwargs = memory_config.get("config", {}).get("search_kwargs", {})
else:
self.um = um
else:
self.memory_provider = None
self.um = um
self.stm = stm
self.ltm = ltm
self.em = em
self.um = um
def build_context_for_task(self, task, context) -> str:
"""
@@ -94,6 +101,10 @@ class ContextualMemory:
Returns:
str: Formatted user memories as bullet points, or an empty string if none found.
"""
# Check if user memory is available
if self.um is None:
return ""
user_memories = self.um.search(query)
if not user_memories:
return ""

View File

@@ -19,29 +19,37 @@ class Mem0Storage(Storage):
self.memory_type = type
self.crew = crew
self.memory_config = crew.memory_config
self.memory_config = crew.memory_config if crew else None
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
user_id = self._get_user_id()
if type == "user" and not user_id:
raise ValueError("User ID is required for user memory type")
# API key in memory config overrides the environment variable
config = self.memory_config.get("config", {})
mem0_api_key = config.get("api_key") or os.getenv("MEM0_API_KEY")
mem0_org_id = config.get("org_id")
mem0_project_id = config.get("project_id")
# Initialize MemoryClient or Memory based on the presence of the mem0_api_key
if mem0_api_key:
if mem0_org_id and mem0_project_id:
self.memory = MemoryClient(
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
)
# Check if a client was provided in the memory_config
if self.memory_config:
config = self.memory_config.get("config", {})
if config.get("client"):
self.memory = config.get("client")
else:
self.memory = MemoryClient(api_key=mem0_api_key)
# API key in memory config overrides the environment variable
mem0_api_key = config.get("api_key") or os.getenv("MEM0_API_KEY")
mem0_org_id = config.get("org_id")
mem0_project_id = config.get("project_id")
# Initialize MemoryClient or Memory based on the presence of the mem0_api_key
if mem0_api_key:
if mem0_org_id and mem0_project_id:
self.memory = MemoryClient(
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
)
else:
self.memory = MemoryClient(api_key=mem0_api_key)
else:
self.memory = Memory() # Fallback to Memory if no Mem0 API key is provided
else:
self.memory = Memory() # Fallback to Memory if no Mem0 API key is provided
# No memory config, use default Memory
self.memory = Memory()
def _sanitize_role(self, role: str) -> str:
"""
@@ -103,13 +111,15 @@ class Mem0Storage(Storage):
def _get_user_id(self):
if self.memory_type == "user":
if hasattr(self, "memory_config") and self.memory_config is not None:
if self.memory_config is not None:
return self.memory_config.get("config", {}).get("user_id")
else:
return None
return "default_user" # Provide a default user ID for testing
return None
def _get_agent_name(self):
if not self.crew:
return "default_agent"
agents = self.crew.agents if self.crew else []
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)

View File

@@ -0,0 +1,79 @@
from unittest.mock import MagicMock, patch
import pytest
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.memory.user.user_memory import UserMemory
from crewai.process import Process
from crewai.task import Task
class MockMemoryClient:
def __init__(self, *args, **kwargs):
pass
def search(self, *args, **kwargs):
return [{"memory": "Test memory", "score": 0.9}]
def add(self, *args, **kwargs):
pass
def test_contextual_memory_with_mem0_client():
# Create a mock mem0 client
mock_mem0_client = MockMemoryClient()
# Create agent and task
agent = Agent(
role="Researcher",
goal="Search relevant data and provide results",
backstory="You are a researcher at a leading tech think tank.",
verbose=True,
)
task = Task(
description="Perform a search on specific topics.",
expected_output="A list of relevant URLs based on the search query.",
agent=agent,
)
# Create a UserMemory instance with our mock client
user_memory = UserMemory(crew=None)
# Manually set the storage memory to our mock client
user_memory.storage.memory = mock_mem0_client
# Create crew with mem0 as memory provider and pass the UserMemory instance
crew = Crew(
agents=[agent],
tasks=[task],
process=Process.sequential,
memory=True,
memory_config={
"provider": "mem0",
"config": {
"user_id": "test_user",
},
"user_memory": user_memory
},
)
# Create contextual memory manually with the crew's memory components
contextual_memory = ContextualMemory(
memory_config=crew.memory_config,
stm=crew._short_term_memory,
ltm=crew._long_term_memory,
em=crew._entity_memory,
um=crew._user_memory,
)
# Test _fetch_user_context
result = contextual_memory._fetch_user_context("test query")
# Should return formatted memories from the mock client
assert "User memories/preferences" in result
assert "- Test memory" in result