mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 23:28:30 +00:00
Compare commits
1 Commits
devin/1742
...
devin/1744
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4748597667 |
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
@@ -38,7 +40,7 @@ class EntityMemory(Memory):
|
||||
)
|
||||
super().__init__(storage)
|
||||
|
||||
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
def save(self, item: EntityMemoryItem, custom_key: Optional[str] = None) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
"""Saves an entity item into the SQLite storage."""
|
||||
if self.memory_provider == "mem0":
|
||||
data = f"""
|
||||
@@ -49,7 +51,7 @@ class EntityMemory(Memory):
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
super().save(data, item.metadata)
|
||||
super().save(data, item.metadata, custom_key=custom_key)
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.memory import Memory
|
||||
@@ -19,9 +19,12 @@ class LongTermMemory(Memory):
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage)
|
||||
|
||||
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
def save(self, item: LongTermMemoryItem, custom_key: Optional[str] = None) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||
metadata = item.metadata
|
||||
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
|
||||
if custom_key:
|
||||
metadata.update({"custom_key": custom_key})
|
||||
|
||||
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
@@ -29,8 +32,8 @@ class LongTermMemory(Memory):
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
||||
def search(self, task: str, latest_n: int = 3, custom_key: Optional[str] = None) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
return self.storage.load(task, latest_n, custom_key) # type: ignore # BUG?: "Storage" has no attribute "load"
|
||||
|
||||
def reset(self) -> None:
|
||||
self.storage.reset()
|
||||
|
||||
@@ -5,7 +5,10 @@ from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
class Memory:
|
||||
"""
|
||||
Base class for memory, now supporting agent tags and generic metadata.
|
||||
Base class for memory, now supporting agent tags, generic metadata, and custom keys.
|
||||
|
||||
Custom keys allow scoping memories to specific entities (users, accounts, sessions),
|
||||
retrieving memories contextually, and preventing data leakage across logical boundaries.
|
||||
"""
|
||||
|
||||
def __init__(self, storage: RAGStorage):
|
||||
@@ -16,10 +19,13 @@ class Memory:
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
custom_key: Optional[str] = None,
|
||||
) -> None:
|
||||
metadata = metadata or {}
|
||||
if agent:
|
||||
metadata["agent"] = agent
|
||||
if custom_key:
|
||||
metadata["custom_key"] = custom_key
|
||||
|
||||
self.storage.save(value, metadata)
|
||||
|
||||
@@ -28,7 +34,12 @@ class Memory:
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
custom_key: Optional[str] = None,
|
||||
) -> List[Any]:
|
||||
filter_dict = None
|
||||
if custom_key:
|
||||
filter_dict = {"custom_key": {"$eq": custom_key}}
|
||||
|
||||
return self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
query=query, limit=limit, score_threshold=score_threshold, filter=filter_dict
|
||||
)
|
||||
|
||||
@@ -46,22 +46,31 @@ class ShortTermMemory(Memory):
|
||||
value: Any,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
custom_key: Optional[str] = None,
|
||||
) -> None:
|
||||
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
|
||||
if self.memory_provider == "mem0":
|
||||
item.data = f"Remember the following insights from Agent run: {item.data}"
|
||||
|
||||
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
|
||||
super().save(value=item.data, metadata=item.metadata, agent=item.agent, custom_key=custom_key)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
custom_key: Optional[str] = None,
|
||||
):
|
||||
filter_dict = None
|
||||
if custom_key:
|
||||
filter_dict = {"custom_key": {"$eq": custom_key}}
|
||||
|
||||
return self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
filter=filter_dict
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
|
||||
@@ -70,22 +70,31 @@ class LTMSQLiteStorage:
|
||||
)
|
||||
|
||||
def load(
|
||||
self, task_description: str, latest_n: int
|
||||
self, task_description: str, latest_n: int, custom_key: Optional[str] = None
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""Queries the LTM table by task description with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
f"""
|
||||
|
||||
query = """
|
||||
SELECT metadata, datetime, score
|
||||
FROM long_term_memories
|
||||
WHERE task_description = ?
|
||||
"""
|
||||
|
||||
params = [task_description]
|
||||
|
||||
if custom_key:
|
||||
query += " AND json_extract(metadata, '$.custom_key') = ?"
|
||||
params.append(custom_key)
|
||||
|
||||
query += f"""
|
||||
ORDER BY datetime DESC, score ASC
|
||||
LIMIT {latest_n}
|
||||
""", # nosec
|
||||
(task_description,),
|
||||
)
|
||||
"""
|
||||
|
||||
cursor.execute(query, params)
|
||||
rows = cursor.fetchall()
|
||||
if rows:
|
||||
return [
|
||||
|
||||
@@ -120,7 +120,11 @@ class RAGStorage(BaseRAGStorage):
|
||||
|
||||
try:
|
||||
with suppress_logging():
|
||||
response = self.collection.query(query_texts=query, n_results=limit)
|
||||
response = self.collection.query(
|
||||
query_texts=query,
|
||||
n_results=limit,
|
||||
where=filter
|
||||
)
|
||||
|
||||
results = []
|
||||
for i in range(len(response["ids"][0])):
|
||||
|
||||
@@ -26,20 +26,27 @@ class UserMemory(Memory):
|
||||
value,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
agent: Optional[str] = None,
|
||||
custom_key: Optional[str] = None,
|
||||
) -> None:
|
||||
# TODO: Change this function since we want to take care of the case where we save memories for the usr
|
||||
data = f"Remember the details about the user: {value}"
|
||||
super().save(data, metadata)
|
||||
super().save(data, metadata, custom_key=custom_key)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 3,
|
||||
score_threshold: float = 0.35,
|
||||
custom_key: Optional[str] = None,
|
||||
):
|
||||
filter_dict = None
|
||||
if custom_key:
|
||||
filter_dict = {"custom_key": {"$eq": custom_key}}
|
||||
|
||||
results = self.storage.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
filter=filter_dict,
|
||||
)
|
||||
return results
|
||||
|
||||
@@ -1,95 +1,42 @@
|
||||
from typing import Optional
|
||||
import sys
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Color(Enum):
|
||||
"""Enum for text colors in terminal output."""
|
||||
PURPLE = "\033[95m"
|
||||
RED = "\033[91m"
|
||||
GREEN = "\033[92m"
|
||||
BLUE = "\033[94m"
|
||||
YELLOW = "\033[93m"
|
||||
BOLD = "\033[1m"
|
||||
RESET = "\033[00m"
|
||||
|
||||
|
||||
class Printer:
|
||||
"""
|
||||
Utility class for printing formatted text to stdout.
|
||||
Uses direct stdout writing for compatibility with asynchronous environments.
|
||||
"""
|
||||
|
||||
def print(self, content: str, color: Optional[str] = None) -> None:
|
||||
"""
|
||||
Print content with optional color formatting.
|
||||
|
||||
Args:
|
||||
content: The text to print
|
||||
color: Optional color name (e.g., "purple", "bold_green")
|
||||
"""
|
||||
output = content
|
||||
def print(self, content: str, color: Optional[str] = None):
|
||||
if color == "purple":
|
||||
output = self._format_purple(content)
|
||||
self._print_purple(content)
|
||||
elif color == "red":
|
||||
output = self._format_red(content)
|
||||
self._print_red(content)
|
||||
elif color == "bold_green":
|
||||
output = self._format_bold_green(content)
|
||||
self._print_bold_green(content)
|
||||
elif color == "bold_purple":
|
||||
output = self._format_bold_purple(content)
|
||||
self._print_bold_purple(content)
|
||||
elif color == "bold_blue":
|
||||
output = self._format_bold_blue(content)
|
||||
self._print_bold_blue(content)
|
||||
elif color == "yellow":
|
||||
output = self._format_yellow(content)
|
||||
self._print_yellow(content)
|
||||
elif color == "bold_yellow":
|
||||
output = self._format_bold_yellow(content)
|
||||
|
||||
try:
|
||||
sys.stdout.write(f"{output}\n")
|
||||
sys.stdout.flush()
|
||||
except IOError:
|
||||
pass
|
||||
self._print_bold_yellow(content)
|
||||
else:
|
||||
print(content)
|
||||
|
||||
def _format_text(self, content: str, color: Color, bold: bool = False) -> str:
|
||||
"""
|
||||
Format text with color and optional bold styling.
|
||||
|
||||
Args:
|
||||
content: The text to format
|
||||
color: The color to apply
|
||||
bold: Whether to apply bold formatting
|
||||
|
||||
Returns:
|
||||
Formatted text string
|
||||
"""
|
||||
if bold:
|
||||
return f"{Color.BOLD.value}{color.value} {content}{Color.RESET.value}"
|
||||
return f"{color.value} {content}{Color.RESET.value}"
|
||||
def _print_bold_purple(self, content):
|
||||
print("\033[1m\033[95m {}\033[00m".format(content))
|
||||
|
||||
def _format_bold_purple(self, content: str) -> str:
|
||||
"""Format text as bold purple."""
|
||||
return self._format_text(content, Color.PURPLE, bold=True)
|
||||
def _print_bold_green(self, content):
|
||||
print("\033[1m\033[92m {}\033[00m".format(content))
|
||||
|
||||
def _format_bold_green(self, content: str) -> str:
|
||||
"""Format text as bold green."""
|
||||
return self._format_text(content, Color.GREEN, bold=True)
|
||||
def _print_purple(self, content):
|
||||
print("\033[95m {}\033[00m".format(content))
|
||||
|
||||
def _format_purple(self, content: str) -> str:
|
||||
"""Format text as purple."""
|
||||
return self._format_text(content, Color.PURPLE)
|
||||
def _print_red(self, content):
|
||||
print("\033[91m {}\033[00m".format(content))
|
||||
|
||||
def _format_red(self, content: str) -> str:
|
||||
"""Format text as red."""
|
||||
return self._format_text(content, Color.RED)
|
||||
def _print_bold_blue(self, content):
|
||||
print("\033[1m\033[94m {}\033[00m".format(content))
|
||||
|
||||
def _format_bold_blue(self, content: str) -> str:
|
||||
"""Format text as bold blue."""
|
||||
return self._format_text(content, Color.BLUE, bold=True)
|
||||
def _print_yellow(self, content):
|
||||
print("\033[93m {}\033[00m".format(content))
|
||||
|
||||
def _format_yellow(self, content: str) -> str:
|
||||
"""Format text as yellow."""
|
||||
return self._format_text(content, Color.YELLOW)
|
||||
|
||||
def _format_bold_yellow(self, content: str) -> str:
|
||||
"""Format text as bold yellow."""
|
||||
return self._format_text(content, Color.YELLOW, bold=True)
|
||||
def _print_bold_yellow(self, content):
|
||||
print("\033[1m\033[93m {}\033[00m".format(content))
|
||||
|
||||
57
tests/memory/custom_key_memory_test.py
Normal file
57
tests/memory/custom_key_memory_test.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def short_term_memory():
|
||||
"""Fixture to create a ShortTermMemory instance"""
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Search relevant data and provide results",
|
||||
backstory="You are a researcher at a leading tech think tank.",
|
||||
tools=[],
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Perform a search on specific topics.",
|
||||
expected_output="A list of relevant URLs based on the search query.",
|
||||
agent=agent,
|
||||
)
|
||||
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
|
||||
|
||||
|
||||
def test_save_with_custom_key(short_term_memory):
|
||||
"""Test that save method correctly passes custom_key to storage"""
|
||||
with patch.object(short_term_memory.storage, 'save') as mock_save:
|
||||
short_term_memory.save(
|
||||
value="Test data",
|
||||
metadata={"task": "test_task"},
|
||||
agent="test_agent",
|
||||
custom_key="user123",
|
||||
)
|
||||
|
||||
called_args = mock_save.call_args[0]
|
||||
called_kwargs = mock_save.call_args[1]
|
||||
|
||||
assert "custom_key" in called_args[1]
|
||||
assert called_args[1]["custom_key"] == "user123"
|
||||
|
||||
|
||||
def test_search_with_custom_key(short_term_memory):
|
||||
"""Test that search method correctly passes custom_key to storage"""
|
||||
expected_results = [{"context": "Test data", "metadata": {"custom_key": "user123"}, "score": 0.95}]
|
||||
|
||||
with patch.object(short_term_memory.storage, 'search', return_value=expected_results) as mock_search:
|
||||
results = short_term_memory.search("test query", custom_key="user123")
|
||||
|
||||
mock_search.assert_called_once()
|
||||
filter_arg = mock_search.call_args[1].get('filter')
|
||||
assert filter_arg == {"custom_key": {"$eq": "user123"}}
|
||||
assert results == expected_results
|
||||
@@ -1,92 +0,0 @@
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
import asyncio
|
||||
import pytest
|
||||
from io import StringIO
|
||||
|
||||
try:
|
||||
import fastapi
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
try:
|
||||
from httpx import AsyncClient
|
||||
ASYNC_CLIENT_AVAILABLE = True
|
||||
except ImportError:
|
||||
ASYNC_CLIENT_AVAILABLE = False
|
||||
FASTAPI_AVAILABLE = True
|
||||
except ImportError:
|
||||
FASTAPI_AVAILABLE = False
|
||||
ASYNC_CLIENT_AVAILABLE = False
|
||||
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
@unittest.skipIf(not FASTAPI_AVAILABLE, "FastAPI not installed")
|
||||
class TestFastAPILogger(unittest.TestCase):
|
||||
"""Test suite for Logger class in FastAPI context."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment before each test."""
|
||||
if not FASTAPI_AVAILABLE:
|
||||
self.skipTest("FastAPI not installed")
|
||||
|
||||
self.app = FastAPI()
|
||||
self.logger = Logger(verbose=True)
|
||||
|
||||
@self.app.get("/")
|
||||
async def root():
|
||||
self.logger.log("info", "This is a test log message from FastAPI")
|
||||
return {"message": "Hello World"}
|
||||
|
||||
@self.app.get("/error")
|
||||
async def error_route():
|
||||
self.logger.log("error", "This is an error log message from FastAPI")
|
||||
return {"error": "Test error"}
|
||||
|
||||
self.client = TestClient(self.app)
|
||||
|
||||
self.output = StringIO()
|
||||
self.old_stdout = sys.stdout
|
||||
sys.stdout = self.output
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment after each test."""
|
||||
sys.stdout = self.old_stdout
|
||||
|
||||
def test_logger_in_fastapi_context(self):
|
||||
"""Test that logger works in FastAPI context."""
|
||||
response = self.client.get("/")
|
||||
|
||||
output = self.output.getvalue()
|
||||
self.assertIn("[INFO]: This is a test log message from FastAPI", output)
|
||||
self.assertIn("\n", output)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.json(), {"message": "Hello World"})
|
||||
|
||||
@pytest.mark.parametrize("route,log_level,expected_message", [
|
||||
("/", "info", "This is a test log message from FastAPI"),
|
||||
("/error", "error", "This is an error log message from FastAPI")
|
||||
])
|
||||
def test_multiple_routes(self, route, log_level, expected_message):
|
||||
"""Test logging from different routes with different log levels."""
|
||||
response = self.client.get(route)
|
||||
|
||||
output = self.output.getvalue()
|
||||
self.assertIn(f"[{log_level.upper()}]: {expected_message}", output)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@unittest.skipIf(not ASYNC_CLIENT_AVAILABLE, "AsyncClient not available")
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_logger_in_fastapi(self):
|
||||
"""Test logger in async context using AsyncClient."""
|
||||
self.output = StringIO()
|
||||
sys.stdout = self.output
|
||||
|
||||
async with AsyncClient(app=self.app, base_url="http://test") as ac:
|
||||
response = await ac.get("/")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
output = self.output.getvalue()
|
||||
self.assertIn("[INFO]: This is a test log message from FastAPI", output)
|
||||
@@ -1,88 +0,0 @@
|
||||
import sys
|
||||
import unittest
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
from io import StringIO
|
||||
import pytest
|
||||
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
class TestLogger(unittest.TestCase):
|
||||
"""Test suite for the Logger class."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test environment before each test."""
|
||||
self.logger = Logger(verbose=True)
|
||||
self.output = StringIO()
|
||||
self.old_stdout = sys.stdout
|
||||
sys.stdout = self.output
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test environment after each test."""
|
||||
sys.stdout = self.old_stdout
|
||||
|
||||
def test_log_in_sync_context(self):
|
||||
"""Test logging in a regular synchronous context."""
|
||||
self.logger.log("info", "Test message")
|
||||
output = self.output.getvalue()
|
||||
self.assertIn("[INFO]: Test message", output)
|
||||
self.assertIn("\n", output)
|
||||
|
||||
@patch('sys.stdout.flush')
|
||||
def test_stdout_is_flushed(self, mock_flush):
|
||||
"""Test that stdout is properly flushed after writing."""
|
||||
self.logger.log("info", "Test message")
|
||||
mock_flush.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize("log_level,message", [
|
||||
("info", "Info message"),
|
||||
("error", "Error message"),
|
||||
("warning", "Warning message"),
|
||||
("debug", "Debug message")
|
||||
])
|
||||
def test_multiple_log_levels(self, log_level, message):
|
||||
"""Test logging with different log levels."""
|
||||
self.logger.log(log_level, message)
|
||||
output = self.output.getvalue()
|
||||
self.assertIn(f"[{log_level.upper()}]: {message}", output)
|
||||
|
||||
def test_thread_safety(self):
|
||||
"""Test that logger is thread-safe."""
|
||||
messages = []
|
||||
for i in range(10):
|
||||
messages.append(f"Message {i}")
|
||||
|
||||
threads = []
|
||||
for message in messages:
|
||||
thread = threading.Thread(
|
||||
target=lambda msg: self.logger.log("info", msg),
|
||||
args=(message,)
|
||||
)
|
||||
threads.append(thread)
|
||||
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
output = self.output.getvalue()
|
||||
for message in messages:
|
||||
self.assertIn(message, output)
|
||||
|
||||
|
||||
class TestFastAPICompatibility(unittest.TestCase):
|
||||
"""Test compatibility with FastAPI."""
|
||||
|
||||
def test_import_in_fastapi(self):
|
||||
"""Test that logger can be imported in a FastAPI context."""
|
||||
try:
|
||||
import fastapi
|
||||
from crewai.utilities.logger import Logger
|
||||
logger = Logger(verbose=True)
|
||||
self.assertTrue(True)
|
||||
except ImportError:
|
||||
self.skipTest("FastAPI not installed")
|
||||
except Exception as e:
|
||||
self.fail(f"Unexpected error: {e}")
|
||||
Reference in New Issue
Block a user