mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
fix: add type annotations to contextual_memory.py
This commit is contained in:
@@ -1,4 +1,6 @@
|
|||||||
from typing import TYPE_CHECKING, Optional
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from crewai.memory import (
|
from crewai.memory import (
|
||||||
EntityMemory,
|
EntityMemory,
|
||||||
@@ -19,8 +21,8 @@ class ContextualMemory:
|
|||||||
ltm: LongTermMemory,
|
ltm: LongTermMemory,
|
||||||
em: EntityMemory,
|
em: EntityMemory,
|
||||||
exm: ExternalMemory,
|
exm: ExternalMemory,
|
||||||
agent: Optional["Agent"] = None,
|
agent: Optional[Agent] = None,
|
||||||
task: Optional["Task"] = None,
|
task: Optional[Task] = None,
|
||||||
):
|
):
|
||||||
self.stm = stm
|
self.stm = stm
|
||||||
self.ltm = ltm
|
self.ltm = ltm
|
||||||
@@ -42,7 +44,7 @@ class ContextualMemory:
|
|||||||
self.exm.agent = self.agent
|
self.exm.agent = self.agent
|
||||||
self.exm.task = self.task
|
self.exm.task = self.task
|
||||||
|
|
||||||
def build_context_for_task(self, task, context) -> str:
|
def build_context_for_task(self, task: Task, context: str) -> str:
|
||||||
"""
|
"""
|
||||||
Automatically builds a minimal, highly relevant set of contextual information
|
Automatically builds a minimal, highly relevant set of contextual information
|
||||||
for a given task.
|
for a given task.
|
||||||
@@ -52,14 +54,14 @@ class ContextualMemory:
|
|||||||
if query == "":
|
if query == "":
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
context = []
|
context_parts = []
|
||||||
context.append(self._fetch_ltm_context(task.description))
|
context_parts.append(self._fetch_ltm_context(task.description))
|
||||||
context.append(self._fetch_stm_context(query))
|
context_parts.append(self._fetch_stm_context(query))
|
||||||
context.append(self._fetch_entity_context(query))
|
context_parts.append(self._fetch_entity_context(query))
|
||||||
context.append(self._fetch_external_context(query))
|
context_parts.append(self._fetch_external_context(query))
|
||||||
return "\n".join(filter(None, context))
|
return "\n".join(filter(None, context_parts))
|
||||||
|
|
||||||
def _fetch_stm_context(self, query) -> str:
|
def _fetch_stm_context(self, query: str) -> str:
|
||||||
"""
|
"""
|
||||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||||
formatted as bullet points.
|
formatted as bullet points.
|
||||||
@@ -74,7 +76,7 @@ class ContextualMemory:
|
|||||||
)
|
)
|
||||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||||
|
|
||||||
def _fetch_ltm_context(self, task) -> Optional[str]:
|
def _fetch_ltm_context(self, task: str) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
|
Fetches historical data or insights from LTM that are relevant to the task's description and expected_output,
|
||||||
formatted as bullet points.
|
formatted as bullet points.
|
||||||
@@ -90,14 +92,16 @@ class ContextualMemory:
|
|||||||
formatted_results = [
|
formatted_results = [
|
||||||
suggestion
|
suggestion
|
||||||
for result in ltm_results
|
for result in ltm_results
|
||||||
for suggestion in result["metadata"]["suggestions"] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
for suggestion in result["metadata"]["suggestions"]
|
||||||
]
|
]
|
||||||
formatted_results = list(dict.fromkeys(formatted_results))
|
formatted_results = list(dict.fromkeys(formatted_results))
|
||||||
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
|
formatted_results_str = "\n".join(
|
||||||
|
[f"- {result}" for result in formatted_results]
|
||||||
|
)
|
||||||
|
|
||||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
return f"Historical Data:\n{formatted_results_str}" if ltm_results else ""
|
||||||
|
|
||||||
def _fetch_entity_context(self, query) -> str:
|
def _fetch_entity_context(self, query: str) -> str:
|
||||||
"""
|
"""
|
||||||
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
|
Fetches relevant entity information from Entity Memory related to the task's description and expected_output,
|
||||||
formatted as bullet points.
|
formatted as bullet points.
|
||||||
@@ -107,7 +111,7 @@ class ContextualMemory:
|
|||||||
|
|
||||||
em_results = self.em.search(query)
|
em_results = self.em.search(query)
|
||||||
formatted_results = "\n".join(
|
formatted_results = "\n".join(
|
||||||
[f"- {result['context']}" for result in em_results] # type: ignore # Invalid index type "str" for "str"; expected type "SupportsIndex | slice"
|
[f"- {result['context']}" for result in em_results]
|
||||||
)
|
)
|
||||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user