fixing long temr memory interpolation

This commit is contained in:
João Moura
2024-04-07 14:55:35 -03:00
parent 755b3934a4
commit e4556040a8
4 changed files with 21 additions and 18 deletions

View File

@@ -81,9 +81,7 @@ class CrewAgentExecutor(AgentExecutor):
datetime=str(time.time()), datetime=str(time.time()),
expected_output=self.task.expected_output, expected_output=self.task.expected_output,
metadata={ metadata={
"suggestions": "\n".join( "suggestions": evaluation.suggestions,
[f"- {s}" for s in evaluation.suggestions]
),
"quality": evaluation.quality, "quality": evaluation.quality,
}, },
) )

View File

@@ -37,13 +37,18 @@ class ContextualMemory:
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output, Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points. formatted as bullet points.
""" """
ltm_results = self.ltm.search(task) ltm_results = self.ltm.search(task, latest_n=2)
if not ltm_results: if not ltm_results:
return None return None
formatted_results = "\n".join(
[f"{result['metadata']['suggestions']}" for result in ltm_results] formatted_results = [
) suggestion
formatted_results = list(set(formatted_results.split('\n'))) for result in ltm_results
for suggestion in result["metadata"]["suggestions"]
]
formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results])
return f"Historical Data:\n{formatted_results}" if ltm_results else "" return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str: def _fetch_entity_context(self, query) -> str:

View File

@@ -28,5 +28,5 @@ class LongTermMemory(Memory):
datetime=item.datetime, datetime=item.datetime,
) )
def search(self, task: str) -> Dict[str, Any]: def search(self, task: str, latest_n: int) -> Dict[str, Any]:
return self.storage.load(task) return self.storage.load(task, latest_n)

View File

@@ -67,19 +67,19 @@ class LTMSQLiteStorage:
color="red", color="red",
) )
def load(self, task_description: str) -> Dict[str, Any]: def load(self, task_description: str, latest_n: int) -> Dict[str, Any]:
"""Queries the LTM table by task description with error handling.""" """Queries the LTM table by task description with error handling."""
try: try:
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
""" f"""
SELECT metadata, datetime, score SELECT metadata, datetime, score
FROM long_term_memories FROM long_term_memories
WHERE task_description = ? WHERE task_description = ?
ORDER BY datetime DESC, score ASC ORDER BY datetime DESC, score ASC
LIMIT 2 LIMIT {latest_n}
""", """,
(task_description,), (task_description,),
) )
rows = cursor.fetchall() rows = cursor.fetchall()