fix: add type annotations to contextual_memory.py

This commit is contained in:
Greyson LaLonde
2025-09-05 09:57:02 -04:00
parent a414e7f2a7
commit e93d597721

View File

@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Optional from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional
from crewai.memory import ( from crewai.memory import (
EntityMemory, EntityMemory,
@@ -19,8 +21,8 @@ class ContextualMemory:
ltm: LongTermMemory, ltm: LongTermMemory,
em: EntityMemory, em: EntityMemory,
exm: ExternalMemory, exm: ExternalMemory,
agent: Optional["Agent"] = None, agent: Optional[Agent] = None,
task: Optional["Task"] = None, task: Optional[Task] = None,
): ):
self.stm = stm self.stm = stm
self.ltm = ltm self.ltm = ltm
@@ -42,7 +44,7 @@ class ContextualMemory:
self.exm.agent = self.agent self.exm.agent = self.agent
self.exm.task = self.task self.exm.task = self.task
def build_context_for_task(self, task, context) -> str: def build_context_for_task(self, task: Task, context: str) -> str:
""" """
Automatically builds a minimal, highly relevant set of contextual information Automatically builds a minimal, highly relevant set of contextual information
for a given task. for a given task.
@@ -52,14 +54,14 @@ class ContextualMemory:
if query == "": if query == "":
return "" return ""
context = [] context_parts = []
context.append(self._fetch_ltm_context(task.description)) context_parts.append(self._fetch_ltm_context(task.description))
context.append(self._fetch_stm_context(query)) context_parts.append(self._fetch_stm_context(query))
context.append(self._fetch_entity_context(query)) context_parts.append(self._fetch_entity_context(query))
context.append(self._fetch_external_context(query)) context_parts.append(self._fetch_external_context(query))
return "\n".join(filter(None, context)) return "\n".join(filter(None, context_parts))
def _fetch_stm_context(self, query) -> str: def _fetch_stm_context(self, query: str) -> str:
""" """
Fetches recent relevant insights from STM related to the task's description and expected_output, Fetches recent relevant insights from STM related to the task's description and expected_output,
formatted as bullet points. formatted as bullet points.
@@ -74,7 +76,7 @@ class ContextualMemory:
) )
return f"Recent Insights:\n{formatted_results}" if stm_results else "" return f"Recent Insights:\n{formatted_results}" if stm_results else ""
def _fetch_ltm_context(self, task) -> Optional[str]: def _fetch_ltm_context(self, task: str) -> Optional[str]:
""" """
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output, Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
formatted as bullet points. formatted as bullet points.
@@ -90,14 +92,16 @@ class ContextualMemory:
formatted_results = [ formatted_results = [
suggestion suggestion
for result in ltm_results for result in ltm_results
for suggestion in result["metadata"]["suggestions"] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice" for suggestion in result["metadata"]["suggestions"]
] ]
formatted_results = list(dict.fromkeys(formatted_results)) formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]") formatted_results_str = "\n".join(
[f"- {result}" for result in formatted_results]
)
return f"Historical Data:\n{formatted_results}" if ltm_results else "" return f"Historical Data:\n{formatted_results_str}" if ltm_results else ""
def _fetch_entity_context(self, query) -> str: def _fetch_entity_context(self, query: str) -> str:
""" """
Fetches relevant entity information from Entity Memory related to the task's description and expected_output, Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
formatted as bullet points. formatted as bullet points.
@@ -107,7 +111,7 @@ class ContextualMemory:
em_results = self.em.search(query) em_results = self.em.search(query)
formatted_results = "\n".join( formatted_results = "\n".join(
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice" [f"- {result['context']}" for result in em_results]
) )
return f"Entities:\n{formatted_results}" if em_results else "" return f"Entities:\n{formatted_results}" if em_results else ""