mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-27 00:58:13 +00:00
Compare commits
5 Commits
llm-event-
...
devin/1741
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c6461d2d69 | ||
|
|
aa41cc0bdc | ||
|
|
f1da364e70 | ||
|
|
f7a9265e35 | ||
|
|
81bb54550c |
@@ -35,8 +35,7 @@ class ContextualMemory:
|
|||||||
context.append(self._fetch_ltm_context(task.description))
|
context.append(self._fetch_ltm_context(task.description))
|
||||||
context.append(self._fetch_stm_context(query))
|
context.append(self._fetch_stm_context(query))
|
||||||
context.append(self._fetch_entity_context(query))
|
context.append(self._fetch_entity_context(query))
|
||||||
if self.memory_provider == "mem0":
|
context.append(self._fetch_user_context(query))
|
||||||
context.append(self._fetch_user_context(query))
|
|
||||||
return "\n".join(filter(None, context))
|
return "\n".join(filter(None, context))
|
||||||
|
|
||||||
def _fetch_stm_context(self, query) -> str:
|
def _fetch_stm_context(self, query) -> str:
|
||||||
@@ -99,6 +98,7 @@ class ContextualMemory:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
formatted_memories = "\n".join(
|
formatted_memories = "\n".join(
|
||||||
f"- {result['memory']}" for result in user_memories
|
f"- {result['memory'] if self.um._memory_provider == 'mem0' else result['context']}"
|
||||||
|
for result in user_memories
|
||||||
)
|
)
|
||||||
return f"User memories/preferences:\n{formatted_memories}"
|
return f"User memories/preferences:\n{formatted_memories}"
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import PrivateAttr
|
||||||
|
|
||||||
from crewai.memory.memory import Memory
|
from crewai.memory.memory import Memory
|
||||||
|
from crewai.memory.storage.rag_storage import RAGStorage
|
||||||
|
|
||||||
|
|
||||||
class UserMemory(Memory):
|
class UserMemory(Memory):
|
||||||
@@ -11,35 +14,102 @@ class UserMemory(Memory):
|
|||||||
MemoryItem instances.
|
MemoryItem instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, crew=None):
|
_memory_provider: Optional[str] = PrivateAttr()
|
||||||
try:
|
|
||||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
def __init__(
|
||||||
except ImportError:
|
self,
|
||||||
raise ImportError(
|
crew=None,
|
||||||
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
embedder_config: Optional[Dict[str, Any]] = None,
|
||||||
|
storage: Optional[Any] = None,
|
||||||
|
path: Optional[str] = None,
|
||||||
|
memory_config: Optional[Dict[str, Any]] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize UserMemory with the specified storage provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
crew: Optional crew object that may contain memory configuration
|
||||||
|
embedder_config: Optional configuration for the embedder
|
||||||
|
storage: Optional pre-configured storage instance
|
||||||
|
path: Optional path for storage
|
||||||
|
memory_config: Optional explicit memory configuration
|
||||||
|
"""
|
||||||
|
# Get memory provider from crew or directly from memory_config
|
||||||
|
memory_provider = None
|
||||||
|
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||||
|
memory_provider = crew.memory_config.get("provider")
|
||||||
|
elif memory_config is not None:
|
||||||
|
memory_provider = memory_config.get("provider")
|
||||||
|
|
||||||
|
if memory_provider == "mem0":
|
||||||
|
try:
|
||||||
|
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"Mem0 is not installed. Please install it with `pip install mem0ai`."
|
||||||
|
)
|
||||||
|
storage = Mem0Storage(type="user", crew=crew)
|
||||||
|
else:
|
||||||
|
storage = (
|
||||||
|
storage
|
||||||
|
if storage
|
||||||
|
else RAGStorage(
|
||||||
|
type="user",
|
||||||
|
allow_reset=True,
|
||||||
|
embedder_config=embedder_config,
|
||||||
|
crew=crew,
|
||||||
|
path=path,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
storage = Mem0Storage(type="user", crew=crew)
|
super().__init__(storage=storage)
|
||||||
super().__init__(storage)
|
self._memory_provider = memory_provider
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
value,
|
value: Any,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
agent: Optional[str] = None,
|
agent: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO: Change this function since we want to take care of the case where we save memories for the usr
|
"""
|
||||||
data = f"Remember the details about the user: {value}"
|
Save user memory data with appropriate formatting based on the storage provider.
|
||||||
super().save(data, metadata)
|
|
||||||
|
Args:
|
||||||
|
value: The data to save
|
||||||
|
metadata: Optional metadata to associate with the memory
|
||||||
|
agent: Optional agent name to associate with the memory
|
||||||
|
"""
|
||||||
|
if self._memory_provider == "mem0":
|
||||||
|
data = f"Remember the details about the user: {value}"
|
||||||
|
else:
|
||||||
|
data = value
|
||||||
|
super().save(data, metadata, agent)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
limit: int = 3,
|
limit: int = 3,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.35,
|
||||||
):
|
) -> List[Any]:
|
||||||
results = self.storage.search(
|
"""
|
||||||
|
Search for user memories that match the query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query
|
||||||
|
limit: Maximum number of results to return
|
||||||
|
score_threshold: Minimum similarity score for results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching memory items
|
||||||
|
"""
|
||||||
|
return self.storage.search(
|
||||||
query=query,
|
query=query,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
)
|
)
|
||||||
return results
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset the user memory storage."""
|
||||||
|
try:
|
||||||
|
self.storage.reset()
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"An error occurred while resetting the user memory: {e}")
|
||||||
|
|||||||
58
tests/memory/user_memory_test.py
Normal file
58
tests/memory/user_memory_test.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
from unittest.mock import PropertyMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.memory.storage.rag_storage import RAGStorage
|
||||||
|
from crewai.memory.user.user_memory import UserMemory
|
||||||
|
|
||||||
|
|
||||||
|
@patch('crewai.memory.storage.mem0_storage.Mem0Storage')
|
||||||
|
@patch('crewai.memory.storage.mem0_storage.MemoryClient')
|
||||||
|
def test_user_memory_provider_selection(mock_memory_client, mock_mem0_storage):
|
||||||
|
"""Test that UserMemory selects the correct storage provider based on config."""
|
||||||
|
# Setup - Mock Mem0Storage to avoid API key requirement
|
||||||
|
mock_mem0_storage.return_value = mock_mem0_storage
|
||||||
|
|
||||||
|
# Test with mem0 provider
|
||||||
|
with patch('crewai.memory.user.user_memory.RAGStorage'):
|
||||||
|
# Create UserMemory with mem0 provider
|
||||||
|
memory_config = {"provider": "mem0"}
|
||||||
|
user_memory = UserMemory(memory_config=memory_config)
|
||||||
|
|
||||||
|
# Verify Mem0Storage was used
|
||||||
|
mock_mem0_storage.assert_called_once()
|
||||||
|
|
||||||
|
# Reset mocks
|
||||||
|
mock_mem0_storage.reset_mock()
|
||||||
|
|
||||||
|
# Test with default provider (RAGStorage)
|
||||||
|
with patch('crewai.memory.user.user_memory.RAGStorage', return_value=mock_mem0_storage) as mock_rag:
|
||||||
|
# Create UserMemory with no provider specified
|
||||||
|
user_memory = UserMemory()
|
||||||
|
|
||||||
|
# Verify RAGStorage was used
|
||||||
|
mock_rag.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch('crewai.memory.user.user_memory.UserMemory._memory_provider', new_callable=PropertyMock)
|
||||||
|
def test_user_memory_save_formatting(mock_memory_provider):
|
||||||
|
"""Test that UserMemory formats data correctly based on provider."""
|
||||||
|
# Test with mem0 provider
|
||||||
|
mock_memory_provider.return_value = "mem0"
|
||||||
|
with patch('crewai.memory.memory.Memory.save') as mock_save:
|
||||||
|
user_memory = UserMemory()
|
||||||
|
user_memory.save("test data")
|
||||||
|
|
||||||
|
# Verify data was formatted for mem0
|
||||||
|
args, _ = mock_save.call_args
|
||||||
|
assert "Remember the details about the user: test data" in args[0]
|
||||||
|
|
||||||
|
# Test with RAG provider
|
||||||
|
mock_memory_provider.return_value = None
|
||||||
|
with patch('crewai.memory.memory.Memory.save') as mock_save:
|
||||||
|
user_memory = UserMemory()
|
||||||
|
user_memory.save("test data")
|
||||||
|
|
||||||
|
# Verify data was not formatted
|
||||||
|
args, _ = mock_save.call_args
|
||||||
|
assert args[0] == "test data"
|
||||||
Reference in New Issue
Block a user