mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Fix issue #2242: Improve memory retrieval to prioritize recent conversation context
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -56,6 +56,11 @@ class ShortTermMemory(Memory):
|
|||||||
if self._memory_provider == "mem0":
|
if self._memory_provider == "mem0":
|
||||||
item.data = f"Remember the following insights from Agent run: {item.data}"
|
item.data = f"Remember the following insights from Agent run: {item.data}"
|
||||||
|
|
||||||
|
# Include timestamp in metadata
|
||||||
|
if item.metadata is None:
|
||||||
|
item.metadata = {}
|
||||||
|
item.metadata["timestamp"] = item.timestamp.isoformat()
|
||||||
|
|
||||||
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
|
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from datetime import datetime
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
@@ -7,7 +8,9 @@ class ShortTermMemoryItem:
|
|||||||
data: Any,
|
data: Any,
|
||||||
agent: Optional[str] = None,
|
agent: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
timestamp: Optional[datetime] = None,
|
||||||
):
|
):
|
||||||
self.data = data
|
self.data = data
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.metadata = metadata if metadata is not None else {}
|
self.metadata = metadata if metadata is not None else {}
|
||||||
|
self.timestamp = timestamp if timestamp is not None else datetime.now()
|
||||||
|
|||||||
@@ -114,13 +114,14 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
limit: int = 3,
|
limit: int = 3,
|
||||||
filter: Optional[dict] = None,
|
filter: Optional[dict] = None,
|
||||||
score_threshold: float = 0.35,
|
score_threshold: float = 0.35,
|
||||||
|
recency_weight: float = 0.3,
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
if not hasattr(self, "app"):
|
if not hasattr(self, "app"):
|
||||||
self._initialize_app()
|
self._initialize_app()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with suppress_logging():
|
with suppress_logging():
|
||||||
response = self.collection.query(query_texts=query, n_results=limit)
|
response = self.collection.query(query_texts=query, n_results=limit * 2) # Get more results to allow for recency filtering
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for i in range(len(response["ids"][0])):
|
for i in range(len(response["ids"][0])):
|
||||||
@@ -130,10 +131,27 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
"context": response["documents"][0][i],
|
"context": response["documents"][0][i],
|
||||||
"score": response["distances"][0][i],
|
"score": response["distances"][0][i],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Apply recency boost if timestamp exists in metadata
|
||||||
|
if "timestamp" in result["metadata"]:
|
||||||
|
try:
|
||||||
|
from datetime import datetime
|
||||||
|
timestamp = datetime.fromisoformat(result["metadata"]["timestamp"])
|
||||||
|
now = datetime.now()
|
||||||
|
# Calculate recency factor (newer = higher score)
|
||||||
|
time_diff_seconds = (now - timestamp).total_seconds()
|
||||||
|
recency_factor = max(0, 1 - (time_diff_seconds / (24 * 60 * 60))) # Normalize to 1 day
|
||||||
|
# Adjust score with recency factor
|
||||||
|
result["score"] = result["score"] * (1 - recency_weight) + recency_factor * recency_weight
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass # If timestamp parsing fails, use original score
|
||||||
|
|
||||||
if result["score"] >= score_threshold:
|
if result["score"] >= score_threshold:
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
return results
|
# Sort by adjusted score (higher is better)
|
||||||
|
results.sort(key=lambda x: x["score"], reverse=True)
|
||||||
|
return results[:limit] # Return only the requested number of results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error during {self.type} search: {str(e)}")
|
logging.error(f"Error during {self.type} search: {str(e)}")
|
||||||
return []
|
return []
|
||||||
|
|||||||
87
tests/memory/test_memory_topic_changes.py
Normal file
87
tests/memory/test_memory_topic_changes.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
import time
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.agent import Agent
|
||||||
|
from crewai.crew import Crew
|
||||||
|
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||||
|
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||||
|
from crewai.memory.storage.rag_storage import RAGStorage
|
||||||
|
from crewai.task import Task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def short_term_memory():
    """Provide a ShortTermMemory built on a crew with one agent and one task."""
    # A single tutoring agent with no tools is all the crew needs here.
    tutor = Agent(
        role="Tutor",
        goal="Teach programming concepts",
        backstory="You are a programming tutor helping students learn.",
        tools=[],
        verbose=True,
    )

    # The crew is assembled in one expression; the task is owned by the tutor.
    crew = Crew(
        agents=[tutor],
        tasks=[
            Task(
                description="Explain programming concepts to students.",
                expected_output="Clear explanations of programming concepts.",
                agent=tutor,
            )
        ],
    )
    return ShortTermMemory(crew=crew)
|
||||||
|
|
||||||
|
|
||||||
|
def test_memory_prioritizes_recent_topic(short_term_memory):
    """Test that memory retrieval lists the most recent topic first.

    ``RAGStorage.search`` is patched to return two canned results (the newer
    "abstract classes" memory scored above the older "variables" memory), so
    this test checks the order and scores of what ``short_term_memory.search``
    returns. It does not exercise the recency scoring itself, because the
    storage search is replaced by the mock.
    """
    # First topic: Python variables
    topic1_data = "Variables in Python are dynamically typed. You can assign any value to a variable without declaring its type."
    topic1_timestamp = datetime.now() - timedelta(minutes=10)  # Older memory

    # Second topic: Python abstract classes
    topic2_data = "Abstract classes in Python are created using the ABC module. They cannot be instantiated and are used as a blueprint for other classes."
    topic2_timestamp = datetime.now()  # More recent memory

    # Mock search results to simulate what would be returned by RAGStorage
    mock_results = [
        {
            "id": "2",
            "metadata": {
                "agent": "Tutor",
                "topic": "python_abstract_classes",
                "timestamp": topic2_timestamp.isoformat(),
            },
            "context": topic2_data,
            "score": 0.85,  # Higher score due to recency boost
        },
        {
            "id": "1",
            "metadata": {
                "agent": "Tutor",
                "topic": "python_variables",
                "timestamp": topic1_timestamp.isoformat(),
            },
            "context": topic1_data,
            "score": 0.75,  # Lower score due to being older
        },
    ]

    # Mock the search method to return our predefined results
    with patch.object(RAGStorage, 'search', return_value=mock_results):
        # Query that could match both topics but should prioritize the more recent one
        query = "Can you give me another example of that?"

        # Search with recency consideration
        results = short_term_memory.search(query)

    # Verify that the most recent topic (abstract classes) is prioritized
    assert len(results) > 0, "No search results returned"

    # Both canned memories must come back; otherwise the ordering checks
    # below would have nothing to compare and the test would pass vacuously.
    assert len(results) >= 2, "Expected both topics to be returned"

    # The first result should be about abstract classes (the more recent topic)
    assert "abstract classes" in results[0]["context"].lower(), "Recent topic (abstract classes) not prioritized"

    # The older topic is also returned, but with lower priority
    assert "variables" in results[1]["context"].lower(), "Older topic should be second"

    # Verify that the scores reflect the recency prioritization
    assert results[0]["score"] > results[1]["score"], "Recent topic should have higher score"
|
||||||
Reference in New Issue
Block a user