Compare commits

...

1 Commits

Author SHA1 Message Date
Lorenze Jay
319f0301ef WIP fixed mypy src types 2024-07-30 08:32:59 -07:00
5 changed files with 27 additions and 17 deletions

View File

@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Optional
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.utilities.converter import ConverterError
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities import I18N
@@ -39,18 +38,17 @@ class CrewAgentExecutorMixin:
and "Action: Delegate work to coworker" not in output.log
):
try:
memory = ShortTermMemoryItem(
data=output.log,
agent=self.crew_agent.role,
metadata={
"observation": self.task.description,
},
)
if (
hasattr(self.crew, "_short_term_memory")
and self.crew._short_term_memory
):
self.crew._short_term_memory.save(memory)
self.crew._short_term_memory.save(
value=output.log,
metadata={
"observation": self.task.description,
},
agent=self.crew_agent.role,
)
except Exception as e:
print(f"Failed to add to short term memory: {e}")
pass

View File

@@ -1,3 +1,4 @@
from typing import Any, Dict, Optional
from crewai.memory.memory import Memory
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem
from crewai.memory.storage.rag_storage import RAGStorage
@@ -18,7 +19,14 @@ class ShortTermMemory(Memory):
)
super().__init__(storage)
def save(self, item: ShortTermMemoryItem) -> None:
def save(
self,
value: Any,
metadata: Optional[Dict[str, Any]] = None,
agent: Optional[str] = None,
) -> None:
item = ShortTermMemoryItem(data=value, metadata=metadata, agent=agent)
super().save(value=item.data, metadata=item.metadata, agent=item.agent)
def search(self, query: str, score_threshold: float = 0.35):

View File

@@ -3,7 +3,10 @@ from typing import Any, Dict, Optional
class ShortTermMemoryItem:
def __init__(
self, data: Any, agent: str, metadata: Optional[Dict[str, Any]] = None
self,
data: Any,
agent: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
):
self.data = data
self.agent = agent

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict
class Storage:
"""Abstract base class defining the storage interface"""
def save(self, key: str, value: Any, metadata: Dict[str, Any]) -> None:
def save(self, value: Any, metadata: Dict[str, Any]) -> None:
pass
def search(self, key: str) -> Dict[str, Any]: # type: ignore

View File

@@ -23,10 +23,7 @@ def short_term_memory():
expected_output="A list of relevant URLs based on the search query.",
agent=agent,
)
return ShortTermMemory(crew=Crew(
agents=[agent],
tasks=[task]
))
return ShortTermMemory(crew=Crew(agents=[agent], tasks=[task]))
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -38,7 +35,11 @@ def test_save_and_search(short_term_memory):
agent="test_agent",
metadata={"task": "test_task"},
)
short_term_memory.save(memory)
short_term_memory.save(
value=memory.data,
metadata=memory.metadata,
agent=memory.agent,
)
find = short_term_memory.search("test value", score_threshold=0.01)[0]
assert find["context"] == memory.data, "Data value mismatch."