mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Improve handling of optional configurations in memory and storage
- Initialize contextual_memory in src/crewai/agent.py and src/crewai/crew.py - Make UserMemory optional and add checks in src/crewai/memory/contextual/contextual_memory.py - Add crew checks in src/crewai/memory/entity/entity_memory.py and src/crewai/memory/short_term/short_term_memory.py - Allow optional storage_path in src/crewai/memory/storage/base_rag_storage.py - Update storage classes to accept optional db_path in: src/crewai/memory/storage/kickoff_task_outputs_storage.py, src/crewai/memory/storage/ltm_sqlite_storage.py, and src/crewai/memory/storage/mem0_storage.py - Modify src/crewai/memory/storage/rag_storage.py to use storage_path - Enhance src/crewai/utilities/embedding_configurator.py to handle missing providers
This commit is contained in:
@@ -294,14 +294,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
if self.crew and self.crew.memory:
|
||||
contextual_memory = ContextualMemory(
|
||||
self.crew.memory_config,
|
||||
self.crew._short_term_memory,
|
||||
self.crew._long_term_memory,
|
||||
self.crew._entity_memory,
|
||||
self.crew._user_memory,
|
||||
)
|
||||
memory = contextual_memory.build_context_for_task(task, context)
|
||||
memory = self.crew.contextual_memory.build_context_for_task(task, context)
|
||||
if memory.strip() != "":
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
@@ -278,6 +279,13 @@ class Crew(BaseModel):
|
||||
)
|
||||
else:
|
||||
self._user_memory = None
|
||||
self.contextual_memory = ContextualMemory(
|
||||
memory_config=self.memory_config,
|
||||
stm=self._short_term_memory,
|
||||
ltm=self._long_term_memory,
|
||||
em=self._entity_memory,
|
||||
um=self._user_memory,
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
||||
@@ -10,7 +10,7 @@ class ContextualMemory:
|
||||
stm: ShortTermMemory,
|
||||
ltm: LongTermMemory,
|
||||
em: EntityMemory,
|
||||
um: UserMemory,
|
||||
um: Optional[UserMemory],
|
||||
):
|
||||
if memory_config is not None:
|
||||
self.memory_provider = memory_config.get("provider")
|
||||
@@ -94,6 +94,8 @@ class ContextualMemory:
|
||||
Returns:
|
||||
str: Formatted user memories as bullet points, or an empty string if none found.
|
||||
"""
|
||||
if not self.um:
|
||||
return ""
|
||||
user_memories = self.um.search(query)
|
||||
if not user_memories:
|
||||
return ""
|
||||
|
||||
@@ -11,7 +11,7 @@ class EntityMemory(Memory):
|
||||
"""
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
if hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
self.memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
self.memory_provider = None
|
||||
|
||||
@@ -15,7 +15,7 @@ class ShortTermMemory(Memory):
|
||||
"""
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
if hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||
self.memory_provider = crew.memory_config.get("provider")
|
||||
else:
|
||||
self.memory_provider = None
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
class BaseRAGStorage(ABC):
|
||||
"""
|
||||
@@ -12,11 +15,13 @@ class BaseRAGStorage(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
type: str,
|
||||
storage_path: Optional[Path] = None,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: Optional[Any] = None,
|
||||
crew: Any = None,
|
||||
):
|
||||
self.type = type
|
||||
self.storage_path = storage_path if storage_path else db_storage_path()
|
||||
self.allow_reset = allow_reset
|
||||
self.embedder_config = embedder_config
|
||||
self.crew = crew
|
||||
|
||||
@@ -13,10 +13,12 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
An updated SQLite storage class for kickoff task outputs storage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, db_path: str = f"{db_storage_path()}/latest_kickoff_task_outputs.db"
|
||||
) -> None:
|
||||
self.db_path = db_path
|
||||
def __init__(self, db_path: Optional[str] = None) -> None:
|
||||
self.db_path = (
|
||||
db_path
|
||||
if db_path
|
||||
else f"{db_storage_path()}/latest_kickoff_task_outputs.db"
|
||||
)
|
||||
self._printer: Printer = Printer()
|
||||
self._initialize_db()
|
||||
|
||||
|
||||
@@ -11,10 +11,10 @@ class LTMSQLiteStorage:
|
||||
An updated SQLite storage class for LTM data storage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, db_path: str = f"{db_storage_path()}/long_term_memory_storage.db"
|
||||
) -> None:
|
||||
self.db_path = db_path
|
||||
def __init__(self, db_path: Optional[str] = None) -> None:
|
||||
self.db_path = (
|
||||
db_path if db_path else f"{db_storage_path()}/latest_long_term_memories.db"
|
||||
)
|
||||
self._printer: Printer = Printer()
|
||||
self._initialize_db()
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class Mem0Storage(Storage):
|
||||
|
||||
self.memory_type = type
|
||||
self.crew = crew
|
||||
self.memory_config = crew.memory_config
|
||||
self.memory_config = crew.memory_config if crew else None
|
||||
|
||||
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
|
||||
user_id = self._get_user_id()
|
||||
@@ -27,9 +27,10 @@ class Mem0Storage(Storage):
|
||||
raise ValueError("User ID is required for user memory type")
|
||||
|
||||
# API key in memory config overrides the environment variable
|
||||
mem0_api_key = self.memory_config.get("config", {}).get("api_key") or os.getenv(
|
||||
"MEM0_API_KEY"
|
||||
)
|
||||
if self.memory_config and self.memory_config.get("config"):
|
||||
mem0_api_key = self.memory_config.get("config").get("api_key")
|
||||
else:
|
||||
mem0_api_key = os.getenv("MEM0_API_KEY")
|
||||
self.memory = MemoryClient(api_key=mem0_api_key)
|
||||
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
|
||||
@@ -11,7 +11,6 @@ from chromadb.api import ClientAPI
|
||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.utilities import EmbeddingConfigurator
|
||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -40,9 +39,15 @@ class RAGStorage(BaseRAGStorage):
|
||||
app: ClientAPI | None = None
|
||||
|
||||
def __init__(
|
||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
||||
self,
|
||||
type,
|
||||
storage_path=None,
|
||||
allow_reset=True,
|
||||
embedder_config=None,
|
||||
crew=None,
|
||||
path=None,
|
||||
):
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
super().__init__(type, storage_path, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
@@ -90,7 +95,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
"""
|
||||
Ensures file name does not exceed max allowed by OS
|
||||
"""
|
||||
base_path = f"{db_storage_path()}/{type}"
|
||||
base_path = f"{self.storage_path}/{type}"
|
||||
|
||||
if len(file_name) > MAX_FILE_NAME_LENGTH:
|
||||
logging.warning(
|
||||
@@ -152,7 +157,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
try:
|
||||
if self.app:
|
||||
self.app.reset()
|
||||
shutil.rmtree(f"{db_storage_path()}/{self.type}")
|
||||
shutil.rmtree(f"{self.storage_path}/{self.type}")
|
||||
self.app = None
|
||||
self.collection = None
|
||||
except Exception as e:
|
||||
|
||||
@@ -27,7 +27,7 @@ class EmbeddingConfigurator:
|
||||
if embedder_config is None:
|
||||
return self._create_default_embedding_function()
|
||||
|
||||
provider = embedder_config.get("provider")
|
||||
provider = embedder_config.get("provider", "")
|
||||
config = embedder_config.get("config", {})
|
||||
model_name = config.get("model")
|
||||
|
||||
@@ -38,12 +38,13 @@ class EmbeddingConfigurator:
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||
|
||||
if provider not in self.embedding_functions:
|
||||
embedding_function = self.embedding_functions.get(provider, None)
|
||||
if not embedding_function:
|
||||
raise Exception(
|
||||
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
||||
)
|
||||
|
||||
return self.embedding_functions[provider](config, model_name)
|
||||
return embedding_function(config, model_name)
|
||||
|
||||
@staticmethod
|
||||
def _create_default_embedding_function():
|
||||
|
||||
Reference in New Issue
Block a user