Compare commits

..

1 Commits

Author SHA1 Message Date
Devin AI
4748597667 Add support for memory distinguished by custom key (resolves #2584)
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-11 07:56:30 +00:00
11 changed files with 145 additions and 276 deletions

View File

@@ -1,3 +1,5 @@
from typing import Optional
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
@@ -38,7 +40,7 @@ class EntityMemory(Memory):
)
super().__init__(storage)
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
def save(self, item: EntityMemoryItem, custom_key: Optional[str] = None) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
"""Saves an entity item into the SQLite storage."""
if self.memory_provider == "mem0":
data = f"""
@@ -49,7 +51,7 @@ class EntityMemory(Memory):
"""
else:
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)
super().save(data, item.metadata, custom_key=custom_key)
def reset(self) -> None:
try:

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
@@ -19,9 +19,12 @@ class LongTermMemory(Memory):
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage)
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
def save(self, item: LongTermMemoryItem, custom_key: Optional[str] = None) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
metadata = item.metadata
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
if custom_key:
metadata.update({"custom_key": custom_key})
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
task_description=item.task,
score=metadata["quality"],
@@ -29,8 +32,8 @@ class LongTermMemory(Memory):
datetime=item.datetime,
)
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
def search(self, task: str, latest_n: int = 3, custom_key: Optional[str] = None) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
return self.storage.load(task, latest_n, custom_key) # type: ignore # BUG?: "Storage" has no attribute "load"
def reset(self) -> None:
self.storage.reset()

View File

@@ -5,7 +5,10 @@ from crewai.memory.storage.rag_storage import RAGStorage
class Memory:
"""
Base class for memory, now supporting agent tags and generic metadata.
Base class for memory, now supporting agent tags, generic metadata, and custom keys.
Custom keys allow scoping memories to specific entities (users, accounts, sessions),
retrieving memories contextually, and preventing data leakage across logical boundaries.
"""
def __init__(self, storage: RAGStorage):
@@ -16,10 +19,13 @@ class Memory:
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
custom_key: Optional[str] = None,
) -> None:
metadata = metadata or {}
if agent:
metadata["agent"] = agent
if custom_key:
metadata["custom_key"] = custom_key
self.storage.save(value, metadata)
@@ -28,7 +34,12 @@ class Memory:
query: str,
limit: int = 3,
score_threshold: float = 0.35,
custom_key: Optional[str] = None,
) -> List[Any]:
filter_dict = None
if custom_key:
filter_dict = {"custom_key": {"$eq": custom_key}}
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
query=query, limit=limit, score_threshold=score_threshold, filter=filter_dict
)

View File

@@ -46,22 +46,31 @@ class ShortTermMemory(Memory):
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
custom_key: Optional[str] = None,
) -> None:
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
if self.memory_provider == "mem0":
item.data = f"Remember the following insights from Agent run: {item.data}"
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
super().save(value=item.data, metadata=item.metadata, agent=item.agent, custom_key=custom_key)
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
custom_key: Optional[str] = None,
):
filter_dict = None
if custom_key:
filter_dict = {"custom_key": {"$eq": custom_key}}
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
query=query,
limit=limit,
score_threshold=score_threshold,
filter=filter_dict
)
def reset(self) -> None:
try:

View File

@@ -70,22 +70,31 @@ class LTMSQLiteStorage:
)
def load(
self, task_description: str, latest_n: int
self, task_description: str, latest_n: int, custom_key: Optional[str] = None
) -> Optional[List[Dict[str, Any]]]:
"""Queries the LTM table by task description with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
f"""
query = """
SELECT metadata, datetime, score
FROM long_term_memories
WHERE task_description = ?
"""
params = [task_description]
if custom_key:
query += " AND json_extract(metadata, '$.custom_key') = ?"
params.append(custom_key)
query += f"""
ORDER BY datetime DESC, score ASC
LIMIT {latest_n}
""", # nosec
(task_description,),
)
"""
cursor.execute(query, params)
rows = cursor.fetchall()
if rows:
return [

View File

@@ -120,7 +120,11 @@ class RAGStorage(BaseRAGStorage):
try:
with suppress_logging():
response = self.collection.query(query_texts=query, n_results=limit)
response = self.collection.query(
query_texts=query,
n_results=limit,
where=filter
)
results = []
for i in range(len(response["ids"][0])):

View File

@@ -26,20 +26,27 @@ class UserMemory(Memory):
value,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
custom_key: Optional[str] = None,
) -> None:
# TODO: Change this function since we want to take care of the case where we save memories for the usr
data = f"Remember the details about the user: {value}"
super().save(data, metadata)
super().save(data, metadata, custom_key=custom_key)
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
custom_key: Optional[str] = None,
):
filter_dict = None
if custom_key:
filter_dict = {"custom_key": {"$eq": custom_key}}
results = self.storage.search(
query=query,
limit=limit,
score_threshold=score_threshold,
filter=filter_dict,
)
return results

View File

@@ -1,95 +1,42 @@
from typing import Optional
import sys
from enum import Enum
class Color(Enum):
"""Enum for text colors in terminal output."""
PURPLE = "\033[95m"
RED = "\033[91m"
GREEN = "\033[92m"
BLUE = "\033[94m"
YELLOW = "\033[93m"
BOLD = "\033[1m"
RESET = "\033[00m"
class Printer:
"""
Utility class for printing formatted text to stdout.
Uses direct stdout writing for compatibility with asynchronous environments.
"""
def print(self, content: str, color: Optional[str] = None) -> None:
"""
Print content with optional color formatting.
Args:
content: The text to print
color: Optional color name (e.g., "purple", "bold_green")
"""
output = content
def print(self, content: str, color: Optional[str] = None):
if color == "purple":
output = self._format_purple(content)
self._print_purple(content)
elif color == "red":
output = self._format_red(content)
self._print_red(content)
elif color == "bold_green":
output = self._format_bold_green(content)
self._print_bold_green(content)
elif color == "bold_purple":
output = self._format_bold_purple(content)
self._print_bold_purple(content)
elif color == "bold_blue":
output = self._format_bold_blue(content)
self._print_bold_blue(content)
elif color == "yellow":
output = self._format_yellow(content)
self._print_yellow(content)
elif color == "bold_yellow":
output = self._format_bold_yellow(content)
try:
sys.stdout.write(f"{output}\n")
sys.stdout.flush()
except IOError:
pass
self._print_bold_yellow(content)
else:
print(content)
def _format_text(self, content: str, color: Color, bold: bool = False) -> str:
"""
Format text with color and optional bold styling.
Args:
content: The text to format
color: The color to apply
bold: Whether to apply bold formatting
Returns:
Formatted text string
"""
if bold:
return f"{Color.BOLD.value}{color.value} {content}{Color.RESET.value}"
return f"{color.value} {content}{Color.RESET.value}"
def _print_bold_purple(self, content):
print("\033[1m\033[95m {}\033[00m".format(content))
def _format_bold_purple(self, content: str) -> str:
"""Format text as bold purple."""
return self._format_text(content, Color.PURPLE, bold=True)
def _print_bold_green(self, content):
print("\033[1m\033[92m {}\033[00m".format(content))
def _format_bold_green(self, content: str) -> str:
"""Format text as bold green."""
return self._format_text(content, Color.GREEN, bold=True)
def _print_purple(self, content):
print("\033[95m {}\033[00m".format(content))
def _format_purple(self, content: str) -> str:
"""Format text as purple."""
return self._format_text(content, Color.PURPLE)
def _print_red(self, content):
print("\033[91m {}\033[00m".format(content))
def _format_red(self, content: str) -> str:
"""Format text as red."""
return self._format_text(content, Color.RED)
def _print_bold_blue(self, content):
print("\033[1m\033[94m {}\033[00m".format(content))
def _format_bold_blue(self, content: str) -> str:
"""Format text as bold blue."""
return self._format_text(content, Color.BLUE, bold=True)
def _print_yellow(self, content):
print("\033[93m {}\033[00m".format(content))
def _format_yellow(self, content: str) -> str:
"""Format text as yellow."""
return self._format_text(content, Color.YELLOW)
def _format_bold_yellow(self, content: str) -> str:
"""Format text as bold yellow."""
return self._format_text(content, Color.YELLOW, bold=True)
def _print_bold_yellow(self, content):
print("\033[1m\033[93m {}\033[00m".format(content))

View File

@@ -0,0 +1,57 @@
import pytest
from unittest.mock import patch, MagicMock
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task
@pytest.fixture
def short_term_memory():
"""Fixture to create a ShortTermMemory instance"""
agent = Agent(
role="Researcher",
goal="Search relevant data and provide results",
backstory="You are a researcher at a leading tech think tank.",
tools=[],
verbose=True,
)
task = Task(
description="Perform a search on specific topics.",
expected_output="A list of relevant URLs based on the search query.",
agent=agent,
)
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
def test_save_with_custom_key(short_term_memory):
"""Test that save method correctly passes custom_key to storage"""
with patch.object(short_term_memory.storage, 'save') as mock_save:
short_term_memory.save(
value="Test data",
metadata={"task": "test_task"},
agent="test_agent",
custom_key="user123",
)
called_args = mock_save.call_args[0]
called_kwargs = mock_save.call_args[1]
assert "custom_key" in called_args[1]
assert called_args[1]["custom_key"] == "user123"
def test_search_with_custom_key(short_term_memory):
"""Test that search method correctly passes custom_key to storage"""
expected_results = [{"context": "Test data", "metadata": {"custom_key": "user123"}, "score": 0.95}]
with patch.object(short_term_memory.storage, 'search', return_value=expected_results) as mock_search:
results = short_term_memory.search("test query", custom_key="user123")
mock_search.assert_called_once()
filter_arg = mock_search.call_args[1].get('filter')
assert filter_arg == {"custom_key": {"$eq": "user123"}}
assert results == expected_results

View File

@@ -1,92 +0,0 @@
import sys
import unittest
from unittest.mock import patch
import asyncio
import pytest
from io import StringIO
try:
import fastapi
from fastapi import FastAPI
from fastapi.testclient import TestClient
try:
from httpx import AsyncClient
ASYNC_CLIENT_AVAILABLE = True
except ImportError:
ASYNC_CLIENT_AVAILABLE = False
FASTAPI_AVAILABLE = True
except ImportError:
FASTAPI_AVAILABLE = False
ASYNC_CLIENT_AVAILABLE = False
from crewai.utilities.logger import Logger
@unittest.skipIf(not FASTAPI_AVAILABLE, "FastAPI not installed")
class TestFastAPILogger(unittest.TestCase):
"""Test suite for Logger class in FastAPI context."""
def setUp(self):
"""Set up test environment before each test."""
if not FASTAPI_AVAILABLE:
self.skipTest("FastAPI not installed")
self.app = FastAPI()
self.logger = Logger(verbose=True)
@self.app.get("/")
async def root():
self.logger.log("info", "This is a test log message from FastAPI")
return {"message": "Hello World"}
@self.app.get("/error")
async def error_route():
self.logger.log("error", "This is an error log message from FastAPI")
return {"error": "Test error"}
self.client = TestClient(self.app)
self.output = StringIO()
self.old_stdout = sys.stdout
sys.stdout = self.output
def tearDown(self):
"""Clean up test environment after each test."""
sys.stdout = self.old_stdout
def test_logger_in_fastapi_context(self):
"""Test that logger works in FastAPI context."""
response = self.client.get("/")
output = self.output.getvalue()
self.assertIn("[INFO]: This is a test log message from FastAPI", output)
self.assertIn("\n", output)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"message": "Hello World"})
@pytest.mark.parametrize("route,log_level,expected_message", [
("/", "info", "This is a test log message from FastAPI"),
("/error", "error", "This is an error log message from FastAPI")
])
def test_multiple_routes(self, route, log_level, expected_message):
"""Test logging from different routes with different log levels."""
response = self.client.get(route)
output = self.output.getvalue()
self.assertIn(f"[{log_level.upper()}]: {expected_message}", output)
self.assertEqual(response.status_code, 200)
@unittest.skipIf(not ASYNC_CLIENT_AVAILABLE, "AsyncClient not available")
@pytest.mark.asyncio
async def test_async_logger_in_fastapi(self):
"""Test logger in async context using AsyncClient."""
self.output = StringIO()
sys.stdout = self.output
async with AsyncClient(app=self.app, base_url="http://test") as ac:
response = await ac.get("/")
self.assertEqual(response.status_code, 200)
output = self.output.getvalue()
self.assertIn("[INFO]: This is a test log message from FastAPI", output)

View File

@@ -1,88 +0,0 @@
import sys
import unittest
import threading
from unittest.mock import patch
from io import StringIO
import pytest
from crewai.utilities.logger import Logger
class TestLogger(unittest.TestCase):
"""Test suite for the Logger class."""
def setUp(self):
"""Set up test environment before each test."""
self.logger = Logger(verbose=True)
self.output = StringIO()
self.old_stdout = sys.stdout
sys.stdout = self.output
def tearDown(self):
"""Clean up test environment after each test."""
sys.stdout = self.old_stdout
def test_log_in_sync_context(self):
"""Test logging in a regular synchronous context."""
self.logger.log("info", "Test message")
output = self.output.getvalue()
self.assertIn("[INFO]: Test message", output)
self.assertIn("\n", output)
@patch('sys.stdout.flush')
def test_stdout_is_flushed(self, mock_flush):
"""Test that stdout is properly flushed after writing."""
self.logger.log("info", "Test message")
mock_flush.assert_called_once()
@pytest.mark.parametrize("log_level,message", [
("info", "Info message"),
("error", "Error message"),
("warning", "Warning message"),
("debug", "Debug message")
])
def test_multiple_log_levels(self, log_level, message):
"""Test logging with different log levels."""
self.logger.log(log_level, message)
output = self.output.getvalue()
self.assertIn(f"[{log_level.upper()}]: {message}", output)
def test_thread_safety(self):
"""Test that logger is thread-safe."""
messages = []
for i in range(10):
messages.append(f"Message {i}")
threads = []
for message in messages:
thread = threading.Thread(
target=lambda msg: self.logger.log("info", msg),
args=(message,)
)
threads.append(thread)
for thread in threads:
thread.start()
for thread in threads:
thread.join()
output = self.output.getvalue()
for message in messages:
self.assertIn(message, output)
class TestFastAPICompatibility(unittest.TestCase):
"""Test compatibility with FastAPI."""
def test_import_in_fastapi(self):
"""Test that logger can be imported in a FastAPI context."""
try:
import fastapi
from crewai.utilities.logger import Logger
logger = Logger(verbose=True)
self.assertTrue(True)
except ImportError:
self.skipTest("FastAPI not installed")
except Exception as e:
self.fail(f"Unexpected error: {e}")