Improve handling of optional configurations in memory and storage

- Initialize contextual_memory once in src/crewai/crew.py and reuse it in
  src/crewai/agent.py instead of constructing it while building the task prompt
- Make UserMemory optional and add checks in src/crewai/memory/contextual/contextual_memory.py
- Add crew checks in src/crewai/memory/entity/entity_memory.py and
  src/crewai/memory/short_term/short_term_memory.py
- Allow optional storage_path in src/crewai/memory/storage/base_rag_storage.py
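- Update storage classes to accept an optional db_path in
  src/crewai/memory/storage/kickoff_task_outputs_storage.py and
  src/crewai/memory/storage/ltm_sqlite_storage.py (the default LTM database
  file name changes to latest_long_term_memories.db)
- Handle a missing crew or memory config in
  src/crewai/memory/storage/mem0_storage.py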
- Modify src/crewai/memory/storage/rag_storage.py to use storage_path
- Enhance src/crewai/utilities/embedding_configurator.py to handle missing providers
This commit is contained in:
Arnaud Gelas
2024-12-15 11:19:31 +01:00
committed by Devin AI
parent 12245d66a7
commit 4274cde583
11 changed files with 48 additions and 31 deletions

View File

@@ -294,14 +294,7 @@ class Agent(BaseAgent):
) )
if self.crew and self.crew.memory: if self.crew and self.crew.memory:
contextual_memory = ContextualMemory( memory = self.crew.contextual_memory.build_context_for_task(task, context)
self.crew.memory_config,
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
self.crew._user_memory,
)
memory = contextual_memory.build_context_for_task(task, context)
if memory.strip() != "": if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory) task_prompt += self.i18n.slice("memory").format(memory=memory)

View File

@@ -25,6 +25,7 @@ from crewai.crews.crew_output import CrewOutput
from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM from crewai.llm import LLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.memory.entity.entity_memory import EntityMemory from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory from crewai.memory.short_term.short_term_memory import ShortTermMemory
@@ -278,6 +279,13 @@ class Crew(BaseModel):
) )
else: else:
self._user_memory = None self._user_memory = None
self.contextual_memory = ContextualMemory(
memory_config=self.memory_config,
stm=self._short_term_memory,
ltm=self._long_term_memory,
em=self._entity_memory,
um=self._user_memory,
)
return self return self
@model_validator(mode="after") @model_validator(mode="after")

View File

@@ -10,7 +10,7 @@ class ContextualMemory:
stm: ShortTermMemory, stm: ShortTermMemory,
ltm: LongTermMemory, ltm: LongTermMemory,
em: EntityMemory, em: EntityMemory,
um: UserMemory, um: Optional[UserMemory],
): ):
if memory_config is not None: if memory_config is not None:
self.memory_provider = memory_config.get("provider") self.memory_provider = memory_config.get("provider")
@@ -94,6 +94,8 @@ class ContextualMemory:
Returns: Returns:
str: Formatted user memories as bullet points, or an empty string if none found. str: Formatted user memories as bullet points, or an empty string if none found.
""" """
if not self.um:
return ""
user_memories = self.um.search(query) user_memories = self.um.search(query)
if not user_memories: if not user_memories:
return "" return ""

View File

@@ -11,7 +11,7 @@ class EntityMemory(Memory):
""" """
def __init__(self, crew=None, embedder_config=None, storage=None, path=None): def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
if hasattr(crew, "memory_config") and crew.memory_config is not None: if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
self.memory_provider = crew.memory_config.get("provider") self.memory_provider = crew.memory_config.get("provider")
else: else:
self.memory_provider = None self.memory_provider = None

View File

@@ -15,7 +15,7 @@ class ShortTermMemory(Memory):
""" """
def __init__(self, crew=None, embedder_config=None, storage=None, path=None): def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
if hasattr(crew, "memory_config") and crew.memory_config is not None: if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
self.memory_provider = crew.memory_config.get("provider") self.memory_provider = crew.memory_config.get("provider")
else: else:
self.memory_provider = None self.memory_provider = None

View File

@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from crewai.utilities.paths import db_storage_path
class BaseRAGStorage(ABC): class BaseRAGStorage(ABC):
""" """
@@ -12,11 +15,13 @@ class BaseRAGStorage(ABC):
def __init__( def __init__(
self, self,
type: str, type: str,
storage_path: Optional[Path] = None,
allow_reset: bool = True, allow_reset: bool = True,
embedder_config: Optional[Any] = None, embedder_config: Optional[Any] = None,
crew: Any = None, crew: Any = None,
): ):
self.type = type self.type = type
self.storage_path = storage_path if storage_path else db_storage_path()
self.allow_reset = allow_reset self.allow_reset = allow_reset
self.embedder_config = embedder_config self.embedder_config = embedder_config
self.crew = crew self.crew = crew

View File

@@ -13,10 +13,12 @@ class KickoffTaskOutputsSQLiteStorage:
An updated SQLite storage class for kickoff task outputs storage. An updated SQLite storage class for kickoff task outputs storage.
""" """
def __init__( def __init__(self, db_path: Optional[str] = None) -> None:
self, db_path: str = f"{db_storage_path()}/latest_kickoff_task_outputs.db" self.db_path = (
) -> None: db_path
self.db_path = db_path if db_path
else f"{db_storage_path()}/latest_kickoff_task_outputs.db"
)
self._printer: Printer = Printer() self._printer: Printer = Printer()
self._initialize_db() self._initialize_db()

View File

@@ -11,10 +11,10 @@ class LTMSQLiteStorage:
An updated SQLite storage class for LTM data storage. An updated SQLite storage class for LTM data storage.
""" """
def __init__( def __init__(self, db_path: Optional[str] = None) -> None:
self, db_path: str = f"{db_storage_path()}/long_term_memory_storage.db" self.db_path = (
) -> None: db_path if db_path else f"{db_storage_path()}/latest_long_term_memories.db"
self.db_path = db_path )
self._printer: Printer = Printer() self._printer: Printer = Printer()
self._initialize_db() self._initialize_db()

View File

@@ -19,7 +19,7 @@ class Mem0Storage(Storage):
self.memory_type = type self.memory_type = type
self.crew = crew self.crew = crew
self.memory_config = crew.memory_config self.memory_config = crew.memory_config if crew else None
# User ID is required for user memory type "user" since it's used as a unique identifier for the user. # User ID is required for user memory type "user" since it's used as a unique identifier for the user.
user_id = self._get_user_id() user_id = self._get_user_id()
@@ -27,9 +27,10 @@ class Mem0Storage(Storage):
raise ValueError("User ID is required for user memory type") raise ValueError("User ID is required for user memory type")
# API key in memory config overrides the environment variable # API key in memory config overrides the environment variable
mem0_api_key = self.memory_config.get("config", {}).get("api_key") or os.getenv( if self.memory_config and self.memory_config.get("config"):
"MEM0_API_KEY" mem0_api_key = self.memory_config.get("config").get("api_key")
) else:
mem0_api_key = os.getenv("MEM0_API_KEY")
self.memory = MemoryClient(api_key=mem0_api_key) self.memory = MemoryClient(api_key=mem0_api_key)
def _sanitize_role(self, role: str) -> str: def _sanitize_role(self, role: str) -> str:

View File

@@ -11,7 +11,6 @@ from chromadb.api import ClientAPI
from crewai.memory.storage.base_rag_storage import BaseRAGStorage from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities import EmbeddingConfigurator from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
@contextlib.contextmanager @contextlib.contextmanager
@@ -40,9 +39,15 @@ class RAGStorage(BaseRAGStorage):
app: ClientAPI | None = None app: ClientAPI | None = None
def __init__( def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None self,
type,
storage_path=None,
allow_reset=True,
embedder_config=None,
crew=None,
path=None,
): ):
super().__init__(type, allow_reset, embedder_config, crew) super().__init__(type, storage_path, allow_reset, embedder_config, crew)
agents = crew.agents if crew else [] agents = crew.agents if crew else []
agents = [self._sanitize_role(agent.role) for agent in agents] agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents) agents = "_".join(agents)
@@ -90,7 +95,7 @@ class RAGStorage(BaseRAGStorage):
""" """
Ensures file name does not exceed max allowed by OS Ensures file name does not exceed max allowed by OS
""" """
base_path = f"{db_storage_path()}/{type}" base_path = f"{self.storage_path}/{type}"
if len(file_name) > MAX_FILE_NAME_LENGTH: if len(file_name) > MAX_FILE_NAME_LENGTH:
logging.warning( logging.warning(
@@ -152,7 +157,7 @@ class RAGStorage(BaseRAGStorage):
try: try:
if self.app: if self.app:
self.app.reset() self.app.reset()
shutil.rmtree(f"{db_storage_path()}/{self.type}") shutil.rmtree(f"{self.storage_path}/{self.type}")
self.app = None self.app = None
self.collection = None self.collection = None
except Exception as e: except Exception as e:

View File

@@ -27,7 +27,7 @@ class EmbeddingConfigurator:
if embedder_config is None: if embedder_config is None:
return self._create_default_embedding_function() return self._create_default_embedding_function()
provider = embedder_config.get("provider") provider = embedder_config.get("provider", "")
config = embedder_config.get("config", {}) config = embedder_config.get("config", {})
model_name = config.get("model") model_name = config.get("model")
@@ -38,12 +38,13 @@ class EmbeddingConfigurator:
except Exception as e: except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}") raise ValueError(f"Invalid custom embedding function: {str(e)}")
if provider not in self.embedding_functions: embedding_function = self.embedding_functions.get(provider, None)
if not embedding_function:
raise Exception( raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}" f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
) )
return self.embedding_functions[provider](config, model_name) return embedding_function(config, model_name)
@staticmethod @staticmethod
def _create_default_embedding_function(): def _create_default_embedding_function():