From ee400973e91c965648efdbbe6dd9d412bd9eaf50 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Fri, 7 Feb 2025 12:27:00 -0500 Subject: [PATCH] fix failing memory tests --- src/crewai/memory/entity/entity_memory.py | 2 +- src/crewai/memory/long_term/long_term_memory.py | 2 +- src/crewai/memory/memory.py | 6 ++++-- src/crewai/memory/short_term/short_term_memory.py | 4 ++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index fd4150d2c..bab5068cb 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -36,7 +36,7 @@ class EntityMemory(Memory): path=path, ) ) - super().__init__(storage) + super().__init__(storage=storage) def save(self, item: EntityMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" """Saves an entity item into the SQLite storage.""" diff --git a/src/crewai/memory/long_term/long_term_memory.py b/src/crewai/memory/long_term/long_term_memory.py index 656709ac9..94aac3a97 100644 --- a/src/crewai/memory/long_term/long_term_memory.py +++ b/src/crewai/memory/long_term/long_term_memory.py @@ -17,7 +17,7 @@ class LongTermMemory(Memory): def __init__(self, storage=None, path=None): if not storage: storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage() - super().__init__(storage) + super().__init__(storage=storage) def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" metadata = item.metadata diff --git a/src/crewai/memory/memory.py b/src/crewai/memory/memory.py index 752a683b4..51a700323 100644 --- a/src/crewai/memory/memory.py +++ b/src/crewai/memory/memory.py @@ -10,8 +10,10 @@ class Memory(BaseModel): Base class for memory, now supporting agent tags and generic metadata. """ - def __init__(self, storage: Union[RAGStorage, Any]): - self.storage = storage + storage: Any + + def __init__(self, storage: Any, **data: Any): + super().__init__(storage=storage, **data) def save( self, diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index 4e5fbbb77..e7fd48140 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -15,7 +15,7 @@ class ShortTermMemory(Memory): """ def __init__(self, crew=None, embedder_config=None, storage=None, path=None): - if hasattr(crew, "memory_config") and crew.memory_config is not None: + if crew and hasattr(crew, "memory_config") and crew.memory_config is not None: self.memory_provider = crew.memory_config.get("provider") else: self.memory_provider = None @@ -39,7 +39,7 @@ class ShortTermMemory(Memory): path=path, ) ) - super().__init__(storage) + super().__init__(storage=storage) def save( self,