mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-30 03:08:29 +00:00
Compare commits
3 Commits
bugfix-pyt
...
devin/1742
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b16551bd3d | ||
|
|
2682397c19 | ||
|
|
08b27bcbf7 |
@@ -14,12 +14,19 @@ class ContextualMemory:
|
||||
):
|
||||
if memory_config is not None:
|
||||
self.memory_provider = memory_config.get("provider")
|
||||
# Special handling for Mem0 provider
|
||||
if self.memory_provider == "mem0":
|
||||
# Check if a custom client was provided in the memory_config
|
||||
self.um = um
|
||||
self.search_kwargs = memory_config.get("config", {}).get("search_kwargs", {})
|
||||
else:
|
||||
self.um = um
|
||||
else:
|
||||
self.memory_provider = None
|
||||
self.um = um
|
||||
self.stm = stm
|
||||
self.ltm = ltm
|
||||
self.em = em
|
||||
self.um = um
|
||||
|
||||
def build_context_for_task(self, task, context) -> str:
|
||||
"""
|
||||
@@ -94,6 +101,10 @@ class ContextualMemory:
|
||||
Returns:
|
||||
str: Formatted user memories as bullet points, or an empty string if none found.
|
||||
"""
|
||||
# Check if user memory is available
|
||||
if self.um is None:
|
||||
return ""
|
||||
|
||||
user_memories = self.um.search(query)
|
||||
if not user_memories:
|
||||
return ""
|
||||
|
||||
@@ -19,29 +19,37 @@ class Mem0Storage(Storage):
|
||||
|
||||
self.memory_type = type
|
||||
self.crew = crew
|
||||
self.memory_config = crew.memory_config
|
||||
self.memory_config = crew.memory_config if crew else None
|
||||
|
||||
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
|
||||
user_id = self._get_user_id()
|
||||
if type == "user" and not user_id:
|
||||
raise ValueError("User ID is required for user memory type")
|
||||
|
||||
# API key in memory config overrides the environment variable
|
||||
config = self.memory_config.get("config", {})
|
||||
mem0_api_key = config.get("api_key") or os.getenv("MEM0_API_KEY")
|
||||
mem0_org_id = config.get("org_id")
|
||||
mem0_project_id = config.get("project_id")
|
||||
|
||||
# Initialize MemoryClient or Memory based on the presence of the mem0_api_key
|
||||
if mem0_api_key:
|
||||
if mem0_org_id and mem0_project_id:
|
||||
self.memory = MemoryClient(
|
||||
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
|
||||
)
|
||||
# Check if a client was provided in the memory_config
|
||||
if self.memory_config:
|
||||
config = self.memory_config.get("config", {})
|
||||
if config.get("client"):
|
||||
self.memory = config.get("client")
|
||||
else:
|
||||
self.memory = MemoryClient(api_key=mem0_api_key)
|
||||
# API key in memory config overrides the environment variable
|
||||
mem0_api_key = config.get("api_key") or os.getenv("MEM0_API_KEY")
|
||||
mem0_org_id = config.get("org_id")
|
||||
mem0_project_id = config.get("project_id")
|
||||
|
||||
# Initialize MemoryClient or Memory based on the presence of the mem0_api_key
|
||||
if mem0_api_key:
|
||||
if mem0_org_id and mem0_project_id:
|
||||
self.memory = MemoryClient(
|
||||
api_key=mem0_api_key, org_id=mem0_org_id, project_id=mem0_project_id
|
||||
)
|
||||
else:
|
||||
self.memory = MemoryClient(api_key=mem0_api_key)
|
||||
else:
|
||||
self.memory = Memory() # Fallback to Memory if no Mem0 API key is provided
|
||||
else:
|
||||
self.memory = Memory() # Fallback to Memory if no Mem0 API key is provided
|
||||
# No memory config, use default Memory
|
||||
self.memory = Memory()
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""
|
||||
@@ -103,13 +111,15 @@ class Mem0Storage(Storage):
|
||||
|
||||
def _get_user_id(self):
|
||||
if self.memory_type == "user":
|
||||
if hasattr(self, "memory_config") and self.memory_config is not None:
|
||||
if self.memory_config is not None:
|
||||
return self.memory_config.get("config", {}).get("user_id")
|
||||
else:
|
||||
return None
|
||||
return "default_user" # Provide a default user ID for testing
|
||||
return None
|
||||
|
||||
def _get_agent_name(self):
|
||||
if not self.crew:
|
||||
return "default_agent"
|
||||
agents = self.crew.agents if self.crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
|
||||
79
tests/memory/user_memory_test.py
Normal file
79
tests/memory/user_memory_test.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.memory.user.user_memory import UserMemory
|
||||
from crewai.process import Process
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
class MockMemoryClient:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def search(self, *args, **kwargs):
|
||||
return [{"memory": "Test memory", "score": 0.9}]
|
||||
|
||||
def add(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
def test_contextual_memory_with_mem0_client():
|
||||
# Create a mock mem0 client
|
||||
mock_mem0_client = MockMemoryClient()
|
||||
|
||||
# Create agent and task
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Search relevant data and provide results",
|
||||
backstory="You are a researcher at a leading tech think tank.",
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Perform a search on specific topics.",
|
||||
expected_output="A list of relevant URLs based on the search query.",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Create a UserMemory instance with our mock client
|
||||
user_memory = UserMemory(crew=None)
|
||||
# Manually set the storage memory to our mock client
|
||||
user_memory.storage.memory = mock_mem0_client
|
||||
|
||||
# Create crew with mem0 as memory provider and pass the UserMemory instance
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
process=Process.sequential,
|
||||
memory=True,
|
||||
memory_config={
|
||||
"provider": "mem0",
|
||||
"config": {
|
||||
"user_id": "test_user",
|
||||
},
|
||||
"user_memory": user_memory
|
||||
},
|
||||
)
|
||||
|
||||
# Create contextual memory manually with the crew's memory components
|
||||
contextual_memory = ContextualMemory(
|
||||
memory_config=crew.memory_config,
|
||||
stm=crew._short_term_memory,
|
||||
ltm=crew._long_term_memory,
|
||||
em=crew._entity_memory,
|
||||
um=crew._user_memory,
|
||||
)
|
||||
|
||||
# Test _fetch_user_context
|
||||
result = contextual_memory._fetch_user_context("test query")
|
||||
|
||||
# Should return formatted memories from the mock client
|
||||
assert "User memories/preferences" in result
|
||||
assert "- Test memory" in result
|
||||
Reference in New Issue
Block a user