mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-22 15:28:30 +00:00
Compare commits
1 Commits
bugfix/add
...
devin/1744
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4748597667 |
@@ -1,3 +1,5 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||||
from crewai.memory.memory import Memory
|
from crewai.memory.memory import Memory
|
||||||
from crewai.memory.storage.rag_storage import RAGStorage
|
from crewai.memory.storage.rag_storage import RAGStorage
|
||||||
@@ -38,7 +40,7 @@ class EntityMemory(Memory):
|
|||||||
)
|
)
|
||||||
super().__init__(storage)
|
super().__init__(storage)
|
||||||
|
|
||||||
def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
def save(self, item: EntityMemoryItem, custom_key: Optional[str] = None) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||||
"""Saves an entity item into the SQLite storage."""
|
"""Saves an entity item into the SQLite storage."""
|
||||||
if self.memory_provider == "mem0":
|
if self.memory_provider == "mem0":
|
||||||
data = f"""
|
data = f"""
|
||||||
@@ -49,7 +51,7 @@ class EntityMemory(Memory):
|
|||||||
"""
|
"""
|
||||||
else:
|
else:
|
||||||
data = f"{item.name}({item.type}): {item.description}"
|
data = f"{item.name}({item.type}): {item.description}"
|
||||||
super().save(data, item.metadata)
|
super().save(data, item.metadata, custom_key=custom_key)
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||||
from crewai.memory.memory import Memory
|
from crewai.memory.memory import Memory
|
||||||
@@ -19,9 +19,12 @@ class LongTermMemory(Memory):
|
|||||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||||
super().__init__(storage)
|
super().__init__(storage)
|
||||||
|
|
||||||
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
def save(self, item: LongTermMemoryItem, custom_key: Optional[str] = None) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"
|
||||||
metadata = item.metadata
|
metadata = item.metadata
|
||||||
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
|
metadata.update({"agent": item.agent, "expected_output": item.expected_output})
|
||||||
|
if custom_key:
|
||||||
|
metadata.update({"custom_key": custom_key})
|
||||||
|
|
||||||
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
|
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
|
||||||
task_description=item.task,
|
task_description=item.task,
|
||||||
score=metadata["quality"],
|
score=metadata["quality"],
|
||||||
@@ -29,8 +32,8 @@ class LongTermMemory(Memory):
|
|||||||
datetime=item.datetime,
|
datetime=item.datetime,
|
||||||
)
|
)
|
||||||
|
|
||||||
def search(self, task: str, latest_n: int = 3) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
def search(self, task: str, latest_n: int = 3, custom_key: Optional[str] = None) -> List[Dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||||
return self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
return self.storage.load(task, latest_n, custom_key) # type: ignore # BUG?: "Storage" has no attribute "load"
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
self.storage.reset()
|
self.storage.reset()
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ from crewai.memory.storage.rag_storage import RAGStorage
|
|||||||
|
|
||||||
class Memory:
|
class Memory:
|
||||||
"""
|
"""
|
||||||
Base class for memory, now supporting agent tags and generic metadata.
|
Base class for memory, now supporting agent tags, generic metadata, and custom keys.
|
||||||
|
|
||||||
|
Custom keys allow scoping memories to specific entities (users, accounts, sessions),
|
||||||
|
retrieving memories contextually, and preventing data leakage across logical boundaries.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, storage: RAGStorage):
|
def __init__(self, storage: RAGStorage):
|
||||||
@@ -16,10 +19,13 @@ class Memory:
|
|||||||
value: Any,
|
value: Any,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
agent: Optional[str] = None,
|
agent: Optional[str] = None,
|
||||||
|
custom_key: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
metadata = metadata or {}
|
metadata = metadata or {}
|
||||||
if agent:
|
if agent:
|
||||||
metadata["agent"] = agent
|
metadata["agent"] = agent
|
||||||
|
if custom_key:
|
||||||
|
metadata["custom_key"] = custom_key
|
||||||
|
|
||||||
self.storage.save(value, metadata)
|
self.storage.save(value, metadata)
|
||||||
|
|
||||||
@@ -28,7 +34,12 @@ class Memory:
|
|||||||
query: str,
|
query: str,
|
||||||
limit: int = 3,
|
limit: int = 3,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.35,
|
||||||
|
custom_key: Optional[str] = None,
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
|
filter_dict = None
|
||||||
|
if custom_key:
|
||||||
|
filter_dict = {"custom_key": {"$eq": custom_key}}
|
||||||
|
|
||||||
return self.storage.search(
|
return self.storage.search(
|
||||||
query=query, limit=limit, score_threshold=score_threshold
|
query=query, limit=limit, score_threshold=score_threshold, filter=filter_dict
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -46,22 +46,31 @@ class ShortTermMemory(Memory):
|
|||||||
value: Any,
|
value: Any,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
agent: Optional[str] = None,
|
agent: Optional[str] = None,
|
||||||
|
custom_key: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
|
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
|
||||||
if self.memory_provider == "mem0":
|
if self.memory_provider == "mem0":
|
||||||
item.data = f"Remember the following insights from Agent run: {item.data}"
|
item.data = f"Remember the following insights from Agent run: {item.data}"
|
||||||
|
|
||||||
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
|
super().save(value=item.data, metadata=item.metadata, agent=item.agent, custom_key=custom_key)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
limit: int = 3,
|
limit: int = 3,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.35,
|
||||||
|
custom_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
filter_dict = None
|
||||||
|
if custom_key:
|
||||||
|
filter_dict = {"custom_key": {"$eq": custom_key}}
|
||||||
|
|
||||||
return self.storage.search(
|
return self.storage.search(
|
||||||
query=query, limit=limit, score_threshold=score_threshold
|
query=query,
|
||||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
limit=limit,
|
||||||
|
score_threshold=score_threshold,
|
||||||
|
filter=filter_dict
|
||||||
|
)
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -70,22 +70,31 @@ class LTMSQLiteStorage:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def load(
|
def load(
|
||||||
self, task_description: str, latest_n: int
|
self, task_description: str, latest_n: int, custom_key: Optional[str] = None
|
||||||
) -> Optional[List[Dict[str, Any]]]:
|
) -> Optional[List[Dict[str, Any]]]:
|
||||||
"""Queries the LTM table by task description with error handling."""
|
"""Queries the LTM table by task description with error handling."""
|
||||||
try:
|
try:
|
||||||
with sqlite3.connect(self.db_path) as conn:
|
with sqlite3.connect(self.db_path) as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
|
||||||
f"""
|
query = """
|
||||||
SELECT metadata, datetime, score
|
SELECT metadata, datetime, score
|
||||||
FROM long_term_memories
|
FROM long_term_memories
|
||||||
WHERE task_description = ?
|
WHERE task_description = ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
params = [task_description]
|
||||||
|
|
||||||
|
if custom_key:
|
||||||
|
query += " AND json_extract(metadata, '$.custom_key') = ?"
|
||||||
|
params.append(custom_key)
|
||||||
|
|
||||||
|
query += f"""
|
||||||
ORDER BY datetime DESC, score ASC
|
ORDER BY datetime DESC, score ASC
|
||||||
LIMIT {latest_n}
|
LIMIT {latest_n}
|
||||||
""", # nosec
|
"""
|
||||||
(task_description,),
|
|
||||||
)
|
cursor.execute(query, params)
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
if rows:
|
if rows:
|
||||||
return [
|
return [
|
||||||
|
|||||||
@@ -120,7 +120,11 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with suppress_logging():
|
with suppress_logging():
|
||||||
response = self.collection.query(query_texts=query, n_results=limit)
|
response = self.collection.query(
|
||||||
|
query_texts=query,
|
||||||
|
n_results=limit,
|
||||||
|
where=filter
|
||||||
|
)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for i in range(len(response["ids"][0])):
|
for i in range(len(response["ids"][0])):
|
||||||
|
|||||||
@@ -26,20 +26,27 @@ class UserMemory(Memory):
|
|||||||
value,
|
value,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
agent: Optional[str] = None,
|
agent: Optional[str] = None,
|
||||||
|
custom_key: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO: Change this function since we want to take care of the case where we save memories for the usr
|
# TODO: Change this function since we want to take care of the case where we save memories for the usr
|
||||||
data = f"Remember the details about the user: {value}"
|
data = f"Remember the details about the user: {value}"
|
||||||
super().save(data, metadata)
|
super().save(data, metadata, custom_key=custom_key)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
limit: int = 3,
|
limit: int = 3,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.35,
|
||||||
|
custom_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
filter_dict = None
|
||||||
|
if custom_key:
|
||||||
|
filter_dict = {"custom_key": {"$eq": custom_key}}
|
||||||
|
|
||||||
results = self.storage.search(
|
results = self.storage.search(
|
||||||
query=query,
|
query=query,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
|
filter=filter_dict,
|
||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
|
|||||||
57
tests/memory/custom_key_memory_test.py
Normal file
57
tests/memory/custom_key_memory_test.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||||
|
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||||
|
from crewai.agent import Agent
|
||||||
|
from crewai.crew import Crew
|
||||||
|
from crewai.task import Task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def short_term_memory():
|
||||||
|
"""Fixture to create a ShortTermMemory instance"""
|
||||||
|
agent = Agent(
|
||||||
|
role="Researcher",
|
||||||
|
goal="Search relevant data and provide results",
|
||||||
|
backstory="You are a researcher at a leading tech think tank.",
|
||||||
|
tools=[],
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Perform a search on specific topics.",
|
||||||
|
expected_output="A list of relevant URLs based on the search query.",
|
||||||
|
agent=agent,
|
||||||
|
)
|
||||||
|
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_with_custom_key(short_term_memory):
|
||||||
|
"""Test that save method correctly passes custom_key to storage"""
|
||||||
|
with patch.object(short_term_memory.storage, 'save') as mock_save:
|
||||||
|
short_term_memory.save(
|
||||||
|
value="Test data",
|
||||||
|
metadata={"task": "test_task"},
|
||||||
|
agent="test_agent",
|
||||||
|
custom_key="user123",
|
||||||
|
)
|
||||||
|
|
||||||
|
called_args = mock_save.call_args[0]
|
||||||
|
called_kwargs = mock_save.call_args[1]
|
||||||
|
|
||||||
|
assert "custom_key" in called_args[1]
|
||||||
|
assert called_args[1]["custom_key"] == "user123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_with_custom_key(short_term_memory):
|
||||||
|
"""Test that search method correctly passes custom_key to storage"""
|
||||||
|
expected_results = [{"context": "Test data", "metadata": {"custom_key": "user123"}, "score": 0.95}]
|
||||||
|
|
||||||
|
with patch.object(short_term_memory.storage, 'search', return_value=expected_results) as mock_search:
|
||||||
|
results = short_term_memory.search("test query", custom_key="user123")
|
||||||
|
|
||||||
|
mock_search.assert_called_once()
|
||||||
|
filter_arg = mock_search.call_args[1].get('filter')
|
||||||
|
assert filter_arg == {"custom_key": {"$eq": "user123"}}
|
||||||
|
assert results == expected_results
|
||||||
Reference in New Issue
Block a user