mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
WIP fixed mypy src types (#1036)
This commit is contained in:
@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
|
|
||||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
|
||||||
from crewai.utilities.converter import ConverterError
|
from crewai.utilities.converter import ConverterError
|
||||||
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
|
||||||
from crewai.utilities import I18N
|
from crewai.utilities import I18N
|
||||||
@@ -39,18 +38,17 @@ class CrewAgentExecutorMixin:
|
|||||||
and "Action: Delegate work to coworker" not in output.log
|
and "Action: Delegate work to coworker" not in output.log
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
memory = ShortTermMemoryItem(
|
|
||||||
data=output.log,
|
|
||||||
agent=self.crew_agent.role,
|
|
||||||
metadata={
|
|
||||||
"observation": self.task.description,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if (
|
if (
|
||||||
hasattr(self.crew, "_short_term_memory")
|
hasattr(self.crew, "_short_term_memory")
|
||||||
and self.crew._short_term_memory
|
and self.crew._short_term_memory
|
||||||
):
|
):
|
||||||
self.crew._short_term_memory.save(memory)
|
self.crew._short_term_memory.save(
|
||||||
|
value=output.log,
|
||||||
|
metadata={
|
||||||
|
"observation": self.task.description,
|
||||||
|
},
|
||||||
|
agent=self.crew_agent.role,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Failed to add to short term memory: {e}")
|
print(f"Failed to add to short term memory: {e}")
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from typing import Any, Dict, Optional
|
||||||
from crewai.memory.memory import Memory
|
from crewai.memory.memory import Memory
|
||||||
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
|
||||||
from crewai.memory.storage.rag_storage import RAGStorage
|
from crewai.memory.storage.rag_storage import RAGStorage
|
||||||
@@ -18,7 +19,14 @@ class ShortTermMemory(Memory):
|
|||||||
)
|
)
|
||||||
super().__init__(storage)
|
super().__init__(storage)
|
||||||
|
|
||||||
def save(self, item: ShortTermMemoryItem) -> None:
|
def save(
|
||||||
|
self,
|
||||||
|
value: Any,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
agent: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
|
||||||
|
|
||||||
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
|
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
|
||||||
|
|
||||||
def search(self, query: str, score_threshold: float = 0.35):
|
def search(self, query: str, score_threshold: float = 0.35):
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ from typing import Any, Dict, Optional
|
|||||||
|
|
||||||
class ShortTermMemoryItem:
|
class ShortTermMemoryItem:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, data: Any, agent: str, metadata: Optional[Dict[str, Any]] = None
|
self,
|
||||||
|
data: Any,
|
||||||
|
agent: Optional[str] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
self.data = data
|
self.data = data
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import Any, Dict
|
|||||||
class Storage:
|
class Storage:
|
||||||
"""Abstract base class defining the storage interface"""
|
"""Abstract base class defining the storage interface"""
|
||||||
|
|
||||||
def save(self, key: str, value: Any, metadata: Dict[str, Any]) -> None:
|
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def search(self, key: str) -> Dict[str, Any]: # type: ignore
|
def search(self, key: str) -> Dict[str, Any]: # type: ignore
|
||||||
|
|||||||
@@ -23,10 +23,7 @@ def short_term_memory():
|
|||||||
expected_output="A list of relevant URLs based on the search query.",
|
expected_output="A list of relevant URLs based on the search query.",
|
||||||
agent=agent,
|
agent=agent,
|
||||||
)
|
)
|
||||||
return ShortTermMemory(crew=Crew(
|
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
|
||||||
agents=[agent],
|
|
||||||
tasks=[task]
|
|
||||||
))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
@@ -38,7 +35,11 @@ def test_save_and_search(short_term_memory):
|
|||||||
agent="test_agent",
|
agent="test_agent",
|
||||||
metadata={"task": "test_task"},
|
metadata={"task": "test_task"},
|
||||||
)
|
)
|
||||||
short_term_memory.save(memory)
|
short_term_memory.save(
|
||||||
|
value=memory.data,
|
||||||
|
metadata=memory.metadata,
|
||||||
|
agent=memory.agent,
|
||||||
|
)
|
||||||
|
|
||||||
find = short_term_memory.search("test value", score_threshold=0.01)[0]
|
find = short_term_memory.search("test value", score_threshold=0.01)[0]
|
||||||
assert find["context"] == memory.data, "Data value mismatch."
|
assert find["context"] == memory.data, "Data value mismatch."
|
||||||
|
|||||||
Reference in New Issue
Block a user