Compare commits

..

1 Commits

Author SHA1 Message Date
Devin AI
4748597667 Add support for memory distinguished by custom key (resolves #2584)
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-11 07:56:30 +00:00
10 changed files with 121 additions and 109 deletions

View File

@@ -112,8 +112,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
try:
while not isinstance(formatted_answer, AgentFinish):
if not self.request_within_rpm_limit or self.request_within_rpm_limit():
self._check_context_length_before_call()
answer = self.llm.call(
self.messages,
callbacks=self.callbacks,
@@ -329,19 +327,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
]
def _check_context_length_before_call(self) -> None:
total_chars = sum(len(msg.get("content", "")) for msg in self.messages)
estimated_tokens = total_chars // 4
context_window_size = self.llm.get_context_window_size()
if estimated_tokens > context_window_size:
self._printer.print(
content=f"Estimated token count ({estimated_tokens}) exceeds context window ({context_window_size}). Handling proactively.",
color="yellow",
)
self._handle_context_length()
def _handle_context_length(self) -> None:
if self.respect_context_window:
self._printer.print(

View File

@@ -1,3 +1,5 @@
from typing import Optional
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.memory import Memory
from crewai.memory.storage.rag_storage import RAGStorage
@@ -38,7 +40,7 @@ class EntityMemory(Memory):
)
super().__init__(storage)
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
def save(self, item: EntityMemoryItem, custom_key: Optional[str] = None) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
"""Saves an entity item into the SQLite storage."""
if self.memory_provider == "mem0":
data = f"""
@@ -49,7 +51,7 @@ class EntityMemory(Memory):
"""
else:
data = f"{item.name}({item.type}): {item.description}"
super().save(data, item.metadata)
super().save(data, item.metadata, custom_key=custom_key)
def reset(self) -> None:
try:

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.memory import Memory
@@ -19,9 +19,12 @@ class LongTermMemory(Memory):
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage)
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
def save(self, item: LongTermMemoryItem, custom_key: Optional[str] = None) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
metadata = item.metadata
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
if custom_key:
metadata.update({"custom_key": custom_key})
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
task_description=item.task,
score=metadata["quality"],
@@ -29,8 +32,8 @@ class LongTermMemory(Memory):
datetime=item.datetime,
)
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
def search(self, task: str, latest_n: int = 3, custom_key: Optional[str] = None) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
return self.storage.load(task, latest_n, custom_key) # type: ignore # BUG?: "Storage" has no attribute "load"
def reset(self) -> None:
self.storage.reset()

View File

@@ -5,7 +5,10 @@ from crewai.memory.storage.rag_storage import RAGStorage
class Memory:
"""
Base class for memory, now supporting agent tags and generic metadata.
Base class for memory, now supporting agent tags, generic metadata, and custom keys.
Custom keys allow scoping memories to specific entities (users, accounts, sessions),
retrieving memories contextually, and preventing data leakage across logical boundaries.
"""
def __init__(self, storage: RAGStorage):
@@ -16,10 +19,13 @@ class Memory:
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
custom_key: Optional[str] = None,
) -> None:
metadata = metadata or {}
if agent:
metadata["agent"] = agent
if custom_key:
metadata["custom_key"] = custom_key
self.storage.save(value, metadata)
@@ -28,7 +34,12 @@ class Memory:
query: str,
limit: int = 3,
score_threshold: float = 0.35,
custom_key: Optional[str] = None,
) -> List[Any]:
filter_dict = None
if custom_key:
filter_dict = {"custom_key": {"$eq": custom_key}}
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
query=query, limit=limit, score_threshold=score_threshold, filter=filter_dict
)

View File

@@ -46,22 +46,31 @@ class ShortTermMemory(Memory):
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
custom_key: Optional[str] = None,
) -> None:
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
if self.memory_provider == "mem0":
item.data = f"Remember the following insights from Agent run: {item.data}"
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
super().save(value=item.data, metadata=item.metadata, agent=item.agent, custom_key=custom_key)
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
custom_key: Optional[str] = None,
):
filter_dict = None
if custom_key:
filter_dict = {"custom_key": {"$eq": custom_key}}
return self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
query=query,
limit=limit,
score_threshold=score_threshold,
filter=filter_dict
)
def reset(self) -> None:
try:

View File

@@ -70,22 +70,31 @@ class LTMSQLiteStorage:
)
def load(
self, task_description: str, latest_n: int
self, task_description: str, latest_n: int, custom_key: Optional[str] = None
) -> Optional[List[Dict[str, Any]]]:
"""Queries the LTM table by task description with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
f"""
query = """
SELECT metadata, datetime, score
FROM long_term_memories
WHERE task_description = ?
"""
params = [task_description]
if custom_key:
query += " AND json_extract(metadata, '$.custom_key') = ?"
params.append(custom_key)
query += f"""
ORDER BY datetime DESC, score ASC
LIMIT {latest_n}
""", # nosec
(task_description,),
)
"""
cursor.execute(query, params)
rows = cursor.fetchall()
if rows:
return [

View File

@@ -120,7 +120,11 @@ class RAGStorage(BaseRAGStorage):
try:
with suppress_logging():
response = self.collection.query(query_texts=query, n_results=limit)
response = self.collection.query(
query_texts=query,
n_results=limit,
where=filter
)
results = []
for i in range(len(response["ids"][0])):

View File

@@ -26,20 +26,27 @@ class UserMemory(Memory):
value,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
custom_key: Optional[str] = None,
) -> None:
# TODO: Change this function since we want to take care of the case where we save memories for the usr
data = f"Remember the details about the user: {value}"
super().save(data, metadata)
super().save(data, metadata, custom_key=custom_key)
def search(
self,
query: str,
limit: int = 3,
score_threshold: float = 0.35,
custom_key: Optional[str] = None,
):
filter_dict = None
if custom_key:
filter_dict = {"custom_key": {"$eq": custom_key}}
results = self.storage.search(
query=query,
limit=limit,
score_threshold=score_threshold,
filter=filter_dict,
)
return results

View File

@@ -1625,78 +1625,3 @@ def test_agent_with_knowledge_sources():
# Assert that the agent provides the correct information
assert "red" in result.raw.lower()
def test_proactive_context_length_handling_prevents_empty_response():
"""Test that proactive context length checking prevents empty LLM responses."""
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
sliding_context_window=True,
)
long_input = "This is a very long input that should exceed the context window. " * 1000
with patch.object(agent.llm, 'get_context_window_size', return_value=100):
with patch.object(agent.agent_executor, '_handle_context_length') as mock_handle:
with patch.object(agent.llm, 'call', return_value="Proper response after summarization"):
agent.agent_executor.messages = [
{"role": "user", "content": long_input}
]
task = Task(
description="Process this long input",
expected_output="A response",
agent=agent,
)
result = agent.execute_task(task)
mock_handle.assert_called()
assert result and result.strip() != ""
def test_proactive_context_length_handling_with_no_summarization():
"""Test proactive context length checking when summarization is disabled."""
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
sliding_context_window=False,
)
long_input = "This is a very long input. " * 1000
with patch.object(agent.llm, 'get_context_window_size', return_value=100):
agent.agent_executor.messages = [
{"role": "user", "content": long_input}
]
with pytest.raises(SystemExit):
agent.agent_executor._check_context_length_before_call()
def test_context_length_estimation():
"""Test the token estimation logic."""
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
)
agent.agent_executor.messages = [
{"role": "user", "content": "Short message"},
{"role": "assistant", "content": "Another short message"},
]
with patch.object(agent.llm, 'get_context_window_size', return_value=10):
with patch.object(agent.agent_executor, '_handle_context_length') as mock_handle:
agent.agent_executor._check_context_length_before_call()
mock_handle.assert_not_called()
with patch.object(agent.llm, 'get_context_window_size', return_value=5):
with patch.object(agent.agent_executor, '_handle_context_length') as mock_handle:
agent.agent_executor._check_context_length_before_call()
mock_handle.assert_called()

View File

@@ -0,0 +1,57 @@
import pytest
from unittest.mock import patch, MagicMock
from crewai.memory.short_term.short_term_memory import ShortTermMemory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task
@pytest.fixture
def short_term_memory():
"""Fixture to create a ShortTermMemory instance"""
agent = Agent(
role="Researcher",
goal="Search relevant data and provide results",
backstory="You are a researcher at a leading tech think tank.",
tools=[],
verbose=True,
)
task = Task(
description="Perform a search on specific topics.",
expected_output="A list of relevant URLs based on the search query.",
agent=agent,
)
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
def test_save_with_custom_key(short_term_memory):
"""Test that save method correctly passes custom_key to storage"""
with patch.object(short_term_memory.storage, 'save') as mock_save:
short_term_memory.save(
value="Test data",
metadata={"task": "test_task"},
agent="test_agent",
custom_key="user123",
)
called_args = mock_save.call_args[0]
called_kwargs = mock_save.call_args[1]
assert "custom_key" in called_args[1]
assert called_args[1]["custom_key"] == "user123"
def test_search_with_custom_key(short_term_memory):
"""Test that search method correctly passes custom_key to storage"""
expected_results = [{"context": "Test data", "metadata": {"custom_key": "user123"}, "score": 0.95}]
with patch.object(short_term_memory.storage, 'search', return_value=expected_results) as mock_search:
results = short_term_memory.search("test query", custom_key="user123")
mock_search.assert_called_once()
filter_arg = mock_search.call_args[1].get('filter')
assert filter_arg == {"custom_key": {"$eq": "user123"}}
assert results == expected_results