fixing long temr memory interpolation

This commit is contained in:
João Moura
2024-04-07 14:55:35 -03:00
parent 755b3934a4
commit e4556040a8
4 changed files with 21 additions and 18 deletions

View File

@@ -81,9 +81,7 @@ class CrewAgentExecutor(AgentExecutor):
datetime=str(time.time()),
expected_output=self.task.expected_output,
metadata={
"suggestions": "\n".join(
[f"- {s}" for s in evaluation.suggestions]
),
"suggestions": evaluation.suggestions,
"quality": evaluation.quality,
},
)

View File

@@ -37,13 +37,18 @@ class ContextualMemory:
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points.
"""
ltm_results = self.ltm.search(task)
ltm_results = self.ltm.search(task, latest_n=2)
if not ltm_results:
return None
formatted_results = "\n".join(
[f"{result['metadata']['suggestions']}" for result in ltm_results]
)
formatted_results = list(set(formatted_results.split('\n')))
formatted_results = [
suggestion
for result in ltm_results
for suggestion in result["metadata"]["suggestions"]
]
formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results])
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str:

View File

@@ -28,5 +28,5 @@ class LongTermMemory(Memory):
datetime=item.datetime,
)
def search(self, task: str) -> Dict[str, Any]:
return self.storage.load(task)
def search(self, task: str, latest_n: int) -> Dict[str, Any]:
return self.storage.load(task, latest_n)

View File

@@ -67,19 +67,19 @@ class LTMSQLiteStorage:
color="red",
)
def load(self, task_description: str) -> Dict[str, Any]:
def load(self, task_description: str, latest_n: int) -> Dict[str, Any]:
"""Queries the LTM table by task description with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT metadata, datetime, score
FROM long_term_memories
WHERE task_description = ?
ORDER BY datetime DESC, score ASC
LIMIT 2
""",
f"""
SELECT metadata, datetime, score
FROM long_term_memories
WHERE task_description = ?
ORDER BY datetime DESC, score ASC
LIMIT {latest_n}
""",
(task_description,),
)
rows = cursor.fetchall()