Improve handling of optional configurations in memory and storage

- Initialize contextual_memory in src/crewai/agent.py and src/crewai/crew.py
- Make UserMemory optional and add checks in src/crewai/memory/contextual/contextual_memory.py
- Add crew checks in src/crewai/memory/entity/entity_memory.py and
  src/crewai/memory/short_term/short_term_memory.py
- Allow optional storage_path in src/crewai/memory/storage/base_rag_storage.py
- Update storage classes to accept optional db_path in:
  src/crewai/memory/storage/kickoff_task_outputs_storage.py,
  src/crewai/memory/storage/ltm_sqlite_storage.py, and
  src/crewai/memory/storage/mem0_storage.py
- Modify src/crewai/memory/storage/rag_storage.py to use storage_path
- Enhance src/crewai/utilities/embedding_configurator.py to handle missing providers
This commit is contained in:
Arnaud Gelas
2024-12-15 11:19:31 +01:00
committed by Devin AI
parent 12245d66a7
commit 4274cde583
11 changed files with 48 additions and 31 deletions

View File

@@ -294,14 +294,7 @@ class Agent(BaseAgent):
)
if self.crew and self.crew.memory:
contextual_memory = ContextualMemory(
self.crew.memory_config,
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
self.crew._user_memory,
)
memory = contextual_memory.build_context_for_task(task, context)
memory = self.crew.contextual_memory.build_context_for_task(task, context)
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)

View File

@@ -25,6 +25,7 @@ from crewai.crews.crew_output import CrewOutput
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
@@ -278,6 +279,13 @@ class Crew(BaseModel):
)
else:
self._user_memory = None
self.contextual_memory = ContextualMemory(
memory_config=self.memory_config,
stm=self._short_term_memory,
ltm=self._long_term_memory,
em=self._entity_memory,
um=self._user_memory,
)
return self
@model_validator(mode="after")

View File

@@ -10,7 +10,7 @@ class ContextualMemory:
stm: ShortTermMemory,
ltm: LongTermMemory,
em: EntityMemory,
um: UserMemory,
um: Optional[UserMemory],
):
if memory_config is not None:
self.memory_provider = memory_config.get("provider")
@@ -94,6 +94,8 @@ class ContextualMemory:
Returns:
str: Formatted user memories as bullet points, or an empty string if none found.
"""
if not self.um:
return ""
user_memories = self.um.search(query)
if not user_memories:
return ""

View File

@@ -11,7 +11,7 @@ class EntityMemory(Memory):
"""
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
if hasattr(crew, "memory_config") and crew.memory_config is not None:
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
self.memory_provider = crew.memory_config.get("provider")
else:
self.memory_provider = None

View File

@@ -15,7 +15,7 @@ class ShortTermMemory(Memory):
"""
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
if hasattr(crew, "memory_config") and crew.memory_config is not None:
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
self.memory_provider = crew.memory_config.get("provider")
else:
self.memory_provider = None

View File

@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional
from crewai.utilities.paths import db_storage_path
class BaseRAGStorage(ABC):
"""
@@ -12,11 +15,13 @@ class BaseRAGStorage(ABC):
def __init__(
self,
type: str,
storage_path: Optional[Path] = None,
allow_reset: bool = True,
embedder_config: Optional[Any] = None,
crew: Any = None,
):
self.type = type
self.storage_path = storage_path if storage_path else db_storage_path()
self.allow_reset = allow_reset
self.embedder_config = embedder_config
self.crew = crew

View File

@@ -13,10 +13,12 @@ class KickoffTaskOutputsSQLiteStorage:
An updated SQLite storage class for kickoff task outputs storage.
"""
def __init__(
self, db_path: str = f"{db_storage_path()}/latest_kickoff_task_outputs.db"
) -> None:
self.db_path = db_path
def __init__(self, db_path: Optional[str] = None) -> None:
self.db_path = (
db_path
if db_path
else f"{db_storage_path()}/latest_kickoff_task_outputs.db"
)
self._printer: Printer = Printer()
self._initialize_db()

View File

@@ -11,10 +11,10 @@ class LTMSQLiteStorage:
An updated SQLite storage class for LTM data storage.
"""
def __init__(
self, db_path: str = f"{db_storage_path()}/long_term_memory_storage.db"
) -> None:
self.db_path = db_path
def __init__(self, db_path: Optional[str] = None) -> None:
self.db_path = (
db_path if db_path else f"{db_storage_path()}/latest_long_term_memories.db"
)
self._printer: Printer = Printer()
self._initialize_db()

View File

@@ -19,7 +19,7 @@ class Mem0Storage(Storage):
self.memory_type = type
self.crew = crew
self.memory_config = crew.memory_config
self.memory_config = crew.memory_config if crew else None
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
user_id = self._get_user_id()
@@ -27,9 +27,10 @@ class Mem0Storage(Storage):
raise ValueError("User ID is required for user memory type")
# API key in memory config overrides the environment variable
mem0_api_key = self.memory_config.get("config", {}).get("api_key") or os.getenv(
"MEM0_API_KEY"
)
if self.memory_config and self.memory_config.get("config"):
mem0_api_key = self.memory_config.get("config").get("api_key")
else:
mem0_api_key = os.getenv("MEM0_API_KEY")
self.memory = MemoryClient(api_key=mem0_api_key)
def _sanitize_role(self, role: str) -> str:

View File

@@ -11,7 +11,6 @@ from chromadb.api import ClientAPI
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
@contextlib.contextmanager
@@ -40,9 +39,15 @@ class RAGStorage(BaseRAGStorage):
app: ClientAPI | None = None
def __init__(
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
self,
type,
storage_path=None,
allow_reset=True,
embedder_config=None,
crew=None,
path=None,
):
super().__init__(type, allow_reset, embedder_config, crew)
super().__init__(type, storage_path, allow_reset, embedder_config, crew)
agents = crew.agents if crew else []
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
@@ -90,7 +95,7 @@ class RAGStorage(BaseRAGStorage):
"""
Ensures file name does not exceed max allowed by OS
"""
base_path = f"{db_storage_path()}/{type}"
base_path = f"{self.storage_path}/{type}"
if len(file_name) > MAX_FILE_NAME_LENGTH:
logging.warning(
@@ -152,7 +157,7 @@ class RAGStorage(BaseRAGStorage):
try:
if self.app:
self.app.reset()
shutil.rmtree(f"{db_storage_path()}/{self.type}")
shutil.rmtree(f"{self.storage_path}/{self.type}")
self.app = None
self.collection = None
except Exception as e:

View File

@@ -27,7 +27,7 @@ class EmbeddingConfigurator:
if embedder_config is None:
return self._create_default_embedding_function()
provider = embedder_config.get("provider")
provider = embedder_config.get("provider", "")
config = embedder_config.get("config", {})
model_name = config.get("model")
@@ -38,12 +38,13 @@ class EmbeddingConfigurator:
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}")
if provider not in self.embedding_functions:
embedding_function = self.embedding_functions.get(provider, None)
if not embedding_function:
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
)
return self.embedding_functions[provider](config, model_name)
return embedding_function(config, model_name)
@staticmethod
def _create_default_embedding_function():