feat(memory): adds support for customizable memory interface (#1339)

* feat(memory): adds support for customizing crew storage

* chore: allow overwriting the crew memory configuration

* docs: update custom storage usage

* fix(lint): use correct syntax

* fix: type check warning

* fix: type check warnings

* fix(test): address agent default failing test

* fix(lint). address type checker error

* Update crew.py

---------

Co-authored-by: João Moura <joaomdmoura@gmail.com>
This commit is contained in:
Ayo Ayibiowu
2024-09-22 22:03:23 +02:00
committed by GitHub
parent e3c7c0185d
commit 91ff331fec
5 changed files with 71 additions and 15 deletions

View File

@@ -110,6 +110,18 @@ class Crew(BaseModel):
default=False,
description="Whether the crew should use memory to store memories of it's execution",
)
short_term_memory: Optional[InstanceOf[ShortTermMemory]] = Field(
default=None,
description="An Instance of the ShortTermMemory to be used by the Crew",
)
long_term_memory: Optional[InstanceOf[LongTermMemory]] = Field(
default=None,
description="An Instance of the LongTermMemory to be used by the Crew",
)
entity_memory: Optional[InstanceOf[EntityMemory]] = Field(
default=None,
description="An Instance of the EntityMemory to be used by the Crew",
)
embedder: Optional[dict] = Field(
default={"provider": "openai"},
description="Configuration for the embedder to be used for the crew.",
@@ -212,11 +224,11 @@ class Crew(BaseModel):
def create_crew_memory(self) -> "Crew":
"""Set private attributes."""
if self.memory:
self._long_term_memory = LongTermMemory()
self._short_term_memory = ShortTermMemory(
self._long_term_memory = self.long_term_memory if self.long_term_memory else LongTermMemory()
self._short_term_memory = self.short_term_memory if self.short_term_memory else ShortTermMemory(
crew=self, embedder_config=self.embedder
)
self._entity_memory = EntityMemory(crew=self, embedder_config=self.embedder)
self._entity_memory = self.entity_memory if self.entity_memory else EntityMemory(crew=self, embedder_config=self.embedder)
return self
@model_validator(mode="after")

View File

@@ -10,12 +10,13 @@ class EntityMemory(Memory):
Inherits from the Memory class.
"""
def __init__(self, crew=None, embedder_config=None):
storage = RAGStorage(
type="entities",
allow_reset=False,
embedder_config=embedder_config,
crew=crew,
def __init__(self, crew=None, embedder_config=None, storage=None):
storage = (
storage
if storage
else RAGStorage(
type="entities", allow_reset=False, embedder_config=embedder_config, crew=crew
)
)
super().__init__(storage)

View File

@@ -14,8 +14,8 @@ class LongTermMemory(Memory):
LongTermMemoryItem instances.
"""
def __init__(self):
storage = LTMSQLiteStorage()
def __init__(self, storage=None):
storage = storage if storage else LTMSQLiteStorage()
super().__init__(storage)
def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory"

View File

@@ -13,9 +13,13 @@ class ShortTermMemory(Memory):
MemoryItem instances.
"""
def __init__(self, crew=None, embedder_config=None):
storage = RAGStorage(
type="short_term", embedder_config=embedder_config, crew=crew
def __init__(self, crew=None, embedder_config=None, storage=None):
storage = (
storage
if storage
else RAGStorage(
type="short_term", embedder_config=embedder_config, crew=crew
)
)
super().__init__(storage)