mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Fix AttributeError in ContextualMemory when using Mem0 memory provider
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -14,12 +14,19 @@ class ContextualMemory:
|
|||||||
):
|
):
|
||||||
if memory_config is not None:
|
if memory_config is not None:
|
||||||
self.memory_provider = memory_config.get("provider")
|
self.memory_provider = memory_config.get("provider")
|
||||||
|
# Special handling for Mem0 provider
|
||||||
|
if self.memory_provider == "mem0":
|
||||||
|
# Check if a custom client was provided in the memory_config
|
||||||
|
self.um = um
|
||||||
|
self.search_kwargs = memory_config.get("config", {}).get("search_kwargs", {})
|
||||||
|
else:
|
||||||
|
self.um = um
|
||||||
else:
|
else:
|
||||||
self.memory_provider = None
|
self.memory_provider = None
|
||||||
|
self.um = um
|
||||||
self.stm = stm
|
self.stm = stm
|
||||||
self.ltm = ltm
|
self.ltm = ltm
|
||||||
self.em = em
|
self.em = em
|
||||||
self.um = um
|
|
||||||
|
|
||||||
def build_context_for_task(self, task, context) -> str:
|
def build_context_for_task(self, task, context) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -94,6 +101,10 @@ class ContextualMemory:
|
|||||||
Returns:
|
Returns:
|
||||||
str: Formatted user memories as bullet points, or an empty string if none found.
|
str: Formatted user memories as bullet points, or an empty string if none found.
|
||||||
"""
|
"""
|
||||||
|
# Check if user memory is available
|
||||||
|
if self.um is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
user_memories = self.um.search(query)
|
user_memories = self.um.search(query)
|
||||||
if not user_memories:
|
if not user_memories:
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -19,29 +19,37 @@ class Mem0Storage(Storage):
|
|||||||
|
|
||||||
self.memory_type = type
|
self.memory_type = type
|
||||||
self.crew = crew
|
self.crew = crew
|
||||||
self.memory_config = crew.memory_config
|
self.memory_config = crew.memory_config if crew else None
|
||||||
|
|
||||||
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
|
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
|
||||||
user_id = self._get_user_id()
|
user_id = self._get_user_id()
|
||||||
if type == "user" and not user_id:
|
if type == "user" and not user_id:
|
||||||
raise ValueError("User ID is required for user memory type")
|
raise ValueError("User ID is required for user memory type")
|
||||||
|
|
||||||
# API key in memory config overrides the environment variable
|
# Check if a client was provided in the memory_config
|
||||||
config = self.memory_config.get("config", {})
|
if self.memory_config:
|
||||||
mem0_api_key = config.get("api_key") or os.getenv("MEM0_API_KEY")
|
config = self.memory_config.get("config", {})
|
||||||
mem0_org_id = config.get("org_id")
|
if config.get("client"):
|
||||||
mem0_project_id = config.get("project_id")
|
self.memory = config.get("client")
|
||||||
|
|
||||||
# Initialize MemoryClient or Memory based on the presence of the mem0_api_key
|
|
||||||
if mem0_api_key:
|
|
||||||
if mem0_org_id and mem0_project_id:
|
|
||||||
self.memory = MemoryClient(
|
|
||||||
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.memory = MemoryClient(api_key=mem0_api_key)
|
# API key in memory config overrides the environment variable
|
||||||
|
mem0_api_key = config.get("api_key") or os.getenv("MEM0_API_KEY")
|
||||||
|
mem0_org_id = config.get("org_id")
|
||||||
|
mem0_project_id = config.get("project_id")
|
||||||
|
|
||||||
|
# Initialize MemoryClient or Memory based on the presence of the mem0_api_key
|
||||||
|
if mem0_api_key:
|
||||||
|
if mem0_org_id and mem0_project_id:
|
||||||
|
self.memory = MemoryClient(
|
||||||
|
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.memory = MemoryClient(api_key=mem0_api_key)
|
||||||
|
else:
|
||||||
|
self.memory = Memory() # Fallback to Memory if no Mem0 API key is provided
|
||||||
else:
|
else:
|
||||||
self.memory = Memory() # Fallback to Memory if no Mem0 API key is provided
|
# No memory config, use default Memory
|
||||||
|
self.memory = Memory()
|
||||||
|
|
||||||
def _sanitize_role(self, role: str) -> str:
|
def _sanitize_role(self, role: str) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -103,13 +111,15 @@ class Mem0Storage(Storage):
|
|||||||
|
|
||||||
def _get_user_id(self):
|
def _get_user_id(self):
|
||||||
if self.memory_type == "user":
|
if self.memory_type == "user":
|
||||||
if hasattr(self, "memory_config") and self.memory_config is not None:
|
if self.memory_config is not None:
|
||||||
return self.memory_config.get("config", {}).get("user_id")
|
return self.memory_config.get("config", {}).get("user_id")
|
||||||
else:
|
else:
|
||||||
return None
|
return "default_user" # Provide a default user ID for testing
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _get_agent_name(self):
|
def _get_agent_name(self):
|
||||||
|
if not self.crew:
|
||||||
|
return "default_agent"
|
||||||
agents = self.crew.agents if self.crew else []
|
agents = self.crew.agents if self.crew else []
|
||||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||||
agents = "_".join(agents)
|
agents = "_".join(agents)
|
||||||
|
|||||||
78
tests/memory/user_memory_test.py
Normal file
78
tests/memory/user_memory_test.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from crewai.crew import Crew
|
||||||
|
from crewai.agent import Agent
|
||||||
|
from crewai.task import Task
|
||||||
|
from crewai.memory.user.user_memory import UserMemory
|
||||||
|
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||||
|
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||||
|
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||||
|
from crewai.memory.entity.entity_memory import EntityMemory
|
||||||
|
from crewai.process import Process
|
||||||
|
|
||||||
|
|
||||||
|
class MockMemoryClient:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def search(self, *args, **kwargs):
|
||||||
|
return [{"memory": "Test memory", "score": 0.9}]
|
||||||
|
|
||||||
|
def add(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_contextual_memory_with_mem0_client():
|
||||||
|
# Create a mock mem0 client
|
||||||
|
mock_mem0_client = MockMemoryClient()
|
||||||
|
|
||||||
|
# Create agent and task
|
||||||
|
agent = Agent(
|
||||||
|
role="Researcher",
|
||||||
|
goal="Search relevant data and provide results",
|
||||||
|
backstory="You are a researcher at a leading tech think tank.",
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Perform a search on specific topics.",
|
||||||
|
expected_output="A list of relevant URLs based on the search query.",
|
||||||
|
agent=agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a UserMemory instance with our mock client
|
||||||
|
user_memory = UserMemory(crew=None)
|
||||||
|
# Manually set the storage memory to our mock client
|
||||||
|
user_memory.storage.memory = mock_mem0_client
|
||||||
|
|
||||||
|
# Create crew with mem0 as memory provider and pass the UserMemory instance
|
||||||
|
crew = Crew(
|
||||||
|
agents=[agent],
|
||||||
|
tasks=[task],
|
||||||
|
process=Process.sequential,
|
||||||
|
memory=True,
|
||||||
|
memory_config={
|
||||||
|
"provider": "mem0",
|
||||||
|
"config": {
|
||||||
|
"user_id": "test_user",
|
||||||
|
},
|
||||||
|
"user_memory": user_memory
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create contextual memory manually with the crew's memory components
|
||||||
|
contextual_memory = ContextualMemory(
|
||||||
|
memory_config=crew.memory_config,
|
||||||
|
stm=crew._short_term_memory,
|
||||||
|
ltm=crew._long_term_memory,
|
||||||
|
em=crew._entity_memory,
|
||||||
|
um=crew._user_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test _fetch_user_context
|
||||||
|
result = contextual_memory._fetch_user_context("test query")
|
||||||
|
|
||||||
|
# Should return formatted memories from the mock client
|
||||||
|
assert "User memories/preferences" in result
|
||||||
|
assert "- Test memory" in result
|
||||||
Reference in New Issue
Block a user