diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 999d1d800..891a03775 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -294,14 +294,7 @@ class Agent(BaseAgent): ) if self.crew and self.crew.memory: - contextual_memory = ContextualMemory( - self.crew.memory_config, - self.crew._short_term_memory, - self.crew._long_term_memory, - self.crew._entity_memory, - self.crew._user_memory, - ) - memory = contextual_memory.build_context_for_task(task, context) + memory = self.crew.contextual_memory.build_context_for_task(task, context) if memory.strip() != "": task_prompt += self.i18n.slice("memory").format(memory=memory) diff --git a/src/crewai/crew.py b/src/crewai/crew.py index d488783ea..3f0a00323 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -25,6 +25,7 @@ from crewai.crews.crew_output import CrewOutput from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.llm import LLM +from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.memory.entity.entity_memory import EntityMemory from crewai.memory.long_term.long_term_memory import LongTermMemory from crewai.memory.short_term.short_term_memory import ShortTermMemory @@ -278,6 +279,13 @@ class Crew(BaseModel): ) else: self._user_memory = None + self.contextual_memory = ContextualMemory( + memory_config=self.memory_config, + stm=self._short_term_memory, + ltm=self._long_term_memory, + em=self._entity_memory, + um=self._user_memory, + ) return self @model_validator(mode="after") diff --git a/src/crewai/memory/contextual/contextual_memory.py b/src/crewai/memory/contextual/contextual_memory.py index cdb9cf836..b7baaa92c 100644 --- a/src/crewai/memory/contextual/contextual_memory.py +++ b/src/crewai/memory/contextual/contextual_memory.py @@ -10,7 +10,7 @@ class ContextualMemory: stm: ShortTermMemory, ltm: LongTermMemory, em: EntityMemory, - um: UserMemory, + um: Optional[UserMemory], ): if memory_config is not None: self.memory_provider = memory_config.get("provider") @@ -94,6 +94,8 @@ class ContextualMemory: Returns: str: Formatted user memories as bullet points, or an empty string if none found. """ + if not self.um: + return "" user_memories = self.um.search(query) if not user_memories: return "" diff --git a/src/crewai/memory/entity/entity_memory.py b/src/crewai/memory/entity/entity_memory.py index 67c72e927..fd4150d2c 100644 --- a/src/crewai/memory/entity/entity_memory.py +++ b/src/crewai/memory/entity/entity_memory.py @@ -11,7 +11,7 @@ class EntityMemory(Memory): """ def __init__(self, crew=None, embedder_config=None, storage=None, path=None): - if hasattr(crew, "memory_config") and crew.memory_config is not None: + if crew and hasattr(crew, "memory_config") and crew.memory_config is not None: self.memory_provider = crew.memory_config.get("provider") else: self.memory_provider = None diff --git a/src/crewai/memory/short_term/short_term_memory.py b/src/crewai/memory/short_term/short_term_memory.py index 4e5fbbb77..dde83df80 100644 --- a/src/crewai/memory/short_term/short_term_memory.py +++ b/src/crewai/memory/short_term/short_term_memory.py @@ -15,7 +15,7 @@ class ShortTermMemory(Memory): """ def __init__(self, crew=None, embedder_config=None, storage=None, path=None): - if hasattr(crew, "memory_config") and crew.memory_config is not None: + if crew and hasattr(crew, "memory_config") and crew.memory_config is not None: self.memory_provider = crew.memory_config.get("provider") else: self.memory_provider = None diff --git a/src/crewai/memory/storage/base_rag_storage.py b/src/crewai/memory/storage/base_rag_storage.py index 10b82ebff..9c26bf293 100644 --- a/src/crewai/memory/storage/base_rag_storage.py +++ b/src/crewai/memory/storage/base_rag_storage.py @@ -1,6 +1,9 @@ from abc import ABC, abstractmethod +from pathlib import Path from typing import Any, Dict, List, Optional +from crewai.utilities.paths import db_storage_path + class BaseRAGStorage(ABC): """ @@ -12,11 +15,13 @@ class BaseRAGStorage(ABC): def __init__( self, type: str, + storage_path: Optional[Path] = None, allow_reset: bool = True, embedder_config: Optional[Any] = None, crew: Any = None, ): self.type = type + self.storage_path = storage_path if storage_path else db_storage_path() self.allow_reset = allow_reset self.embedder_config = embedder_config self.crew = crew diff --git a/src/crewai/memory/storage/kickoff_task_outputs_storage.py b/src/crewai/memory/storage/kickoff_task_outputs_storage.py index 26905191c..00e949d39 100644 --- a/src/crewai/memory/storage/kickoff_task_outputs_storage.py +++ b/src/crewai/memory/storage/kickoff_task_outputs_storage.py @@ -13,10 +13,12 @@ class KickoffTaskOutputsSQLiteStorage: An updated SQLite storage class for kickoff task outputs storage. """ - def __init__( - self, db_path: str = f"{db_storage_path()}/latest_kickoff_task_outputs.db" - ) -> None: - self.db_path = db_path + def __init__(self, db_path: Optional[str] = None) -> None: + self.db_path = ( + db_path + if db_path + else f"{db_storage_path()}/latest_kickoff_task_outputs.db" + ) self._printer: Printer = Printer() self._initialize_db() diff --git a/src/crewai/memory/storage/ltm_sqlite_storage.py b/src/crewai/memory/storage/ltm_sqlite_storage.py index 93d993ee6..97ccfa14b 100644 --- a/src/crewai/memory/storage/ltm_sqlite_storage.py +++ b/src/crewai/memory/storage/ltm_sqlite_storage.py @@ -11,10 +11,10 @@ class LTMSQLiteStorage: An updated SQLite storage class for LTM data storage. """ - def __init__( - self, db_path: str = f"{db_storage_path()}/long_term_memory_storage.db" - ) -> None: - self.db_path = db_path + def __init__(self, db_path: Optional[str] = None) -> None: + self.db_path = ( + db_path if db_path else f"{db_storage_path()}/latest_long_term_memories.db" + ) self._printer: Printer = Printer() self._initialize_db() diff --git a/src/crewai/memory/storage/mem0_storage.py b/src/crewai/memory/storage/mem0_storage.py index e4e84fab4..75152ded9 100644 --- a/src/crewai/memory/storage/mem0_storage.py +++ b/src/crewai/memory/storage/mem0_storage.py @@ -19,7 +19,7 @@ class Mem0Storage(Storage): self.memory_type = type self.crew = crew - self.memory_config = crew.memory_config + self.memory_config = crew.memory_config if crew else None # User ID is required for user memory type "user" since it's used as a unique identifier for the user. user_id = self._get_user_id() @@ -27,9 +27,10 @@ class Mem0Storage(Storage): raise ValueError("User ID is required for user memory type") # API key in memory config overrides the environment variable - mem0_api_key = self.memory_config.get("config", {}).get("api_key") or os.getenv( - "MEM0_API_KEY" - ) + if self.memory_config and self.memory_config.get("config"): + mem0_api_key = self.memory_config.get("config").get("api_key") + else: + mem0_api_key = os.getenv("MEM0_API_KEY") self.memory = MemoryClient(api_key=mem0_api_key) def _sanitize_role(self, role: str) -> str: diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index fd4c77838..ee3abc438 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -11,7 +11,6 @@ from chromadb.api import ClientAPI from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.utilities import EmbeddingConfigurator from crewai.utilities.constants import MAX_FILE_NAME_LENGTH -from crewai.utilities.paths import db_storage_path @contextlib.contextmanager @@ -40,9 +39,15 @@ class RAGStorage(BaseRAGStorage): app: ClientAPI | None = None def __init__( - self, type, allow_reset=True, embedder_config=None, crew=None, path=None + self, + type, + storage_path=None, + allow_reset=True, + embedder_config=None, + crew=None, + path=None, ): - super().__init__(type, allow_reset, embedder_config, crew) + super().__init__(type, storage_path, allow_reset, embedder_config, crew) agents = crew.agents if crew else [] agents = [self._sanitize_role(agent.role) for agent in agents] agents = "_".join(agents) @@ -90,7 +95,7 @@ class RAGStorage(BaseRAGStorage): """ Ensures file name does not exceed max allowed by OS """ - base_path = f"{db_storage_path()}/{type}" + base_path = f"{self.storage_path}/{type}" if len(file_name) > MAX_FILE_NAME_LENGTH: logging.warning( @@ -152,7 +157,7 @@ class RAGStorage(BaseRAGStorage): try: if self.app: self.app.reset() - shutil.rmtree(f"{db_storage_path()}/{self.type}") + shutil.rmtree(f"{self.storage_path}/{self.type}") self.app = None self.collection = None except Exception as e: diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py index 44e832ec2..9b6c7f142 100644 --- a/src/crewai/utilities/embedding_configurator.py +++ b/src/crewai/utilities/embedding_configurator.py @@ -27,7 +27,7 @@ class EmbeddingConfigurator: if embedder_config is None: return self._create_default_embedding_function() - provider = embedder_config.get("provider") + provider = embedder_config.get("provider", "") config = embedder_config.get("config", {}) model_name = config.get("model") @@ -38,12 +38,13 @@ class EmbeddingConfigurator: except Exception as e: raise ValueError(f"Invalid custom embedding function: {str(e)}") - if provider not in self.embedding_functions: + embedding_function = self.embedding_functions.get(provider, None) + if not embedding_function: raise Exception( f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}" ) - return self.embedding_functions[provider](config, model_name) + return embedding_function(config, model_name) @staticmethod def _create_default_embedding_function():