diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index 88d33c09a..67c72e927 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -10,7 +10,7 @@ class EntityMemory(Memory): Inherits from the Memory class. """ - def __init__(self, crew=None, embedder_config=None, storage=None): + def __init__(self, crew=None, embedder_config=None, storage=None, path=None): if hasattr(crew, "memory_config") and crew.memory_config is not None: self.memory_provider = crew.memory_config.get("provider") else: @@ -33,6 +33,7 @@ class EntityMemory(Memory): allow_reset=True, embedder_config=embedder_config, crew=crew, + path=path, ) ) super().__init__(storage) diff --git a/src/crewai/memory/long_term/long_term_memory.py b/src/crewai/memory/long_term/long_term_memory.py index b9c36bdc9..656709ac9 100644 --- a/src/crewai/memory/long_term/long_term_memory.py +++ b/src/crewai/memory/long_term/long_term_memory.py @@ -14,8 +14,9 @@ class LongTermMemory(Memory): LongTermMemoryItem instances. """ - def __init__(self, storage=None): - storage = storage if storage else LTMSQLiteStorage() + def __init__(self, storage=None, path=None): + if not storage: + storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage() super().__init__(storage) def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index 67a568d63..4ade7eb93 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -13,7 +13,7 @@ class ShortTermMemory(Memory): MemoryItem instances. """ - def __init__(self, crew=None, embedder_config=None, storage=None): + def __init__(self, crew=None, embedder_config=None, storage=None, path=None): if hasattr(crew, "memory_config") and crew.memory_config is not None: self.memory_provider = crew.memory_config.get("provider") else: @@ -32,7 +32,7 @@ class ShortTermMemory(Memory): storage if storage else RAGStorage( - type="short_term", embedder_config=embedder_config, crew=crew + type="short_term", embedder_config=embedder_config, crew=crew, path=path ) ) super().__init__(storage) diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index 4023cf558..ded340a19 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -37,7 +37,7 @@ class RAGStorage(BaseRAGStorage): app: ClientAPI | None = None - def __init__(self, type, allow_reset=True, embedder_config=None, crew=None): + def __init__(self, type, allow_reset=True, embedder_config=None, crew=None, path=None): super().__init__(type, allow_reset, embedder_config, crew) agents = crew.agents if crew else [] agents = [self._sanitize_role(agent.role) for agent in agents] @@ -47,6 +47,7 @@ class RAGStorage(BaseRAGStorage): self.type = type self.allow_reset = allow_reset + self.path = path self._initialize_app() def _set_embedder_config(self): @@ -59,7 +60,7 @@ class RAGStorage(BaseRAGStorage): self._set_embedder_config() chroma_client = chromadb.PersistentClient( - path=f"{db_storage_path()}/{self.type}/{self.agents}", + path=self.path if self.path else f"{db_storage_path()}/{self.type}/{self.agents}", settings=Settings(allow_reset=self.allow_reset), )