mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Improve handling of optional configurations in memory and storage
- Initialize contextual_memory in src/crewai/agent.py and src/crewai/crew.py - Make UserMemory optional and add checks in src/crewai/memory/contextual/contextual_memory.py - Add crew checks in src/crewai/memory/entity/entity_memory.py and src/crewai/memory/short_term/short_term_memory.py - Allow optional storage_path in src/crewai/memory/storage/base_rag_storage.py - Update storage classes to accept optional db_path in: src/crewai/memory/storage/kickoff_task_outputs_storage.py, src/crewai/memory/storage/ltm_sqlite_storage.py, and src/crewai/memory/storage/mem0_storage.py - Modify src/crewai/memory/storage/rag_storage.py to use storage_path - Enhance src/crewai/utilities/embedding_configurator.py to handle missing providers
This commit is contained in:
@@ -294,14 +294,7 @@ class Agent(BaseAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.crew and self.crew.memory:
|
if self.crew and self.crew.memory:
|
||||||
contextual_memory = ContextualMemory(
|
memory = self.crew.contextual_memory.build_context_for_task(task, context)
|
||||||
self.crew.memory_config,
|
|
||||||
self.crew._short_term_memory,
|
|
||||||
self.crew._long_term_memory,
|
|
||||||
self.crew._entity_memory,
|
|
||||||
self.crew._user_memory,
|
|
||||||
)
|
|
||||||
memory = contextual_memory.build_context_for_task(task, context)
|
|
||||||
if memory.strip() != "":
|
if memory.strip() != "":
|
||||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from crewai.crews.crew_output import CrewOutput
|
|||||||
from crewai.knowledge.knowledge import Knowledge
|
from crewai.knowledge.knowledge import Knowledge
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
from crewai.llm import LLM
|
from crewai.llm import LLM
|
||||||
|
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||||
from crewai.memory.entity.entity_memory import EntityMemory
|
from crewai.memory.entity.entity_memory import EntityMemory
|
||||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||||
@@ -278,6 +279,13 @@ class Crew(BaseModel):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._user_memory = None
|
self._user_memory = None
|
||||||
|
self.contextual_memory = ContextualMemory(
|
||||||
|
memory_config=self.memory_config,
|
||||||
|
stm=self._short_term_memory,
|
||||||
|
ltm=self._long_term_memory,
|
||||||
|
em=self._entity_memory,
|
||||||
|
um=self._user_memory,
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ class ContextualMemory:
|
|||||||
stm: ShortTermMemory,
|
stm: ShortTermMemory,
|
||||||
ltm: LongTermMemory,
|
ltm: LongTermMemory,
|
||||||
em: EntityMemory,
|
em: EntityMemory,
|
||||||
um: UserMemory,
|
um: Optional[UserMemory],
|
||||||
):
|
):
|
||||||
if memory_config is not None:
|
if memory_config is not None:
|
||||||
self.memory_provider = memory_config.get("provider")
|
self.memory_provider = memory_config.get("provider")
|
||||||
@@ -94,6 +94,8 @@ class ContextualMemory:
|
|||||||
Returns:
|
Returns:
|
||||||
str: Formatted user memories as bullet points, or an empty string if none found.
|
str: Formatted user memories as bullet points, or an empty string if none found.
|
||||||
"""
|
"""
|
||||||
|
if not self.um:
|
||||||
|
return ""
|
||||||
user_memories = self.um.search(query)
|
user_memories = self.um.search(query)
|
||||||
if not user_memories:
|
if not user_memories:
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class EntityMemory(Memory):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||||
if hasattr(crew, "memory_config") and crew.memory_config is not None:
|
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||||
self.memory_provider = crew.memory_config.get("provider")
|
self.memory_provider = crew.memory_config.get("provider")
|
||||||
else:
|
else:
|
||||||
self.memory_provider = None
|
self.memory_provider = None
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class ShortTermMemory(Memory):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||||
if hasattr(crew, "memory_config") and crew.memory_config is not None:
|
if crew and hasattr(crew, "memory_config") and crew.memory_config is not None:
|
||||||
self.memory_provider = crew.memory_config.get("provider")
|
self.memory_provider = crew.memory_config.get("provider")
|
||||||
else:
|
else:
|
||||||
self.memory_provider = None
|
self.memory_provider = None
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
|
||||||
|
|
||||||
class BaseRAGStorage(ABC):
|
class BaseRAGStorage(ABC):
|
||||||
"""
|
"""
|
||||||
@@ -12,11 +15,13 @@ class BaseRAGStorage(ABC):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
type: str,
|
type: str,
|
||||||
|
storage_path: Optional[Path] = None,
|
||||||
allow_reset: bool = True,
|
allow_reset: bool = True,
|
||||||
embedder_config: Optional[Any] = None,
|
embedder_config: Optional[Any] = None,
|
||||||
crew: Any = None,
|
crew: Any = None,
|
||||||
):
|
):
|
||||||
self.type = type
|
self.type = type
|
||||||
|
self.storage_path = storage_path if storage_path else db_storage_path()
|
||||||
self.allow_reset = allow_reset
|
self.allow_reset = allow_reset
|
||||||
self.embedder_config = embedder_config
|
self.embedder_config = embedder_config
|
||||||
self.crew = crew
|
self.crew = crew
|
||||||
|
|||||||
@@ -13,10 +13,12 @@ class KickoffTaskOutputsSQLiteStorage:
|
|||||||
An updated SQLite storage class for kickoff task outputs storage.
|
An updated SQLite storage class for kickoff task outputs storage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_path: Optional[str] = None) -> None:
|
||||||
self, db_path: str = f"{db_storage_path()}/latest_kickoff_task_outputs.db"
|
self.db_path = (
|
||||||
) -> None:
|
db_path
|
||||||
self.db_path = db_path
|
if db_path
|
||||||
|
else f"{db_storage_path()}/latest_kickoff_task_outputs.db"
|
||||||
|
)
|
||||||
self._printer: Printer = Printer()
|
self._printer: Printer = Printer()
|
||||||
self._initialize_db()
|
self._initialize_db()
|
||||||
|
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ class LTMSQLiteStorage:
|
|||||||
An updated SQLite storage class for LTM data storage.
|
An updated SQLite storage class for LTM data storage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_path: Optional[str] = None) -> None:
|
||||||
self, db_path: str = f"{db_storage_path()}/long_term_memory_storage.db"
|
self.db_path = (
|
||||||
) -> None:
|
db_path if db_path else f"{db_storage_path()}/latest_long_term_memories.db"
|
||||||
self.db_path = db_path
|
)
|
||||||
self._printer: Printer = Printer()
|
self._printer: Printer = Printer()
|
||||||
self._initialize_db()
|
self._initialize_db()
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class Mem0Storage(Storage):
|
|||||||
|
|
||||||
self.memory_type = type
|
self.memory_type = type
|
||||||
self.crew = crew
|
self.crew = crew
|
||||||
self.memory_config = crew.memory_config
|
self.memory_config = crew.memory_config if crew else None
|
||||||
|
|
||||||
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
|
# User ID is required for user memory type "user" since it's used as a unique identifier for the user.
|
||||||
user_id = self._get_user_id()
|
user_id = self._get_user_id()
|
||||||
@@ -27,9 +27,10 @@ class Mem0Storage(Storage):
|
|||||||
raise ValueError("User ID is required for user memory type")
|
raise ValueError("User ID is required for user memory type")
|
||||||
|
|
||||||
# API key in memory config overrides the environment variable
|
# API key in memory config overrides the environment variable
|
||||||
mem0_api_key = self.memory_config.get("config", {}).get("api_key") or os.getenv(
|
if self.memory_config and self.memory_config.get("config"):
|
||||||
"MEM0_API_KEY"
|
mem0_api_key = self.memory_config.get("config").get("api_key")
|
||||||
)
|
else:
|
||||||
|
mem0_api_key = os.getenv("MEM0_API_KEY")
|
||||||
self.memory = MemoryClient(api_key=mem0_api_key)
|
self.memory = MemoryClient(api_key=mem0_api_key)
|
||||||
|
|
||||||
def _sanitize_role(self, role: str) -> str:
|
def _sanitize_role(self, role: str) -> str:
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from chromadb.api import ClientAPI
|
|||||||
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
||||||
from crewai.utilities import EmbeddingConfigurator
|
from crewai.utilities import EmbeddingConfigurator
|
||||||
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
||||||
from crewai.utilities.paths import db_storage_path
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@@ -40,9 +39,15 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
app: ClientAPI | None = None
|
app: ClientAPI | None = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, type, allow_reset=True, embedder_config=None, crew=None, path=None
|
self,
|
||||||
|
type,
|
||||||
|
storage_path=None,
|
||||||
|
allow_reset=True,
|
||||||
|
embedder_config=None,
|
||||||
|
crew=None,
|
||||||
|
path=None,
|
||||||
):
|
):
|
||||||
super().__init__(type, allow_reset, embedder_config, crew)
|
super().__init__(type, storage_path, allow_reset, embedder_config, crew)
|
||||||
agents = crew.agents if crew else []
|
agents = crew.agents if crew else []
|
||||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||||
agents = "_".join(agents)
|
agents = "_".join(agents)
|
||||||
@@ -90,7 +95,7 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
"""
|
"""
|
||||||
Ensures file name does not exceed max allowed by OS
|
Ensures file name does not exceed max allowed by OS
|
||||||
"""
|
"""
|
||||||
base_path = f"{db_storage_path()}/{type}"
|
base_path = f"{self.storage_path}/{type}"
|
||||||
|
|
||||||
if len(file_name) > MAX_FILE_NAME_LENGTH:
|
if len(file_name) > MAX_FILE_NAME_LENGTH:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
@@ -152,7 +157,7 @@ class RAGStorage(BaseRAGStorage):
|
|||||||
try:
|
try:
|
||||||
if self.app:
|
if self.app:
|
||||||
self.app.reset()
|
self.app.reset()
|
||||||
shutil.rmtree(f"{db_storage_path()}/{self.type}")
|
shutil.rmtree(f"{self.storage_path}/{self.type}")
|
||||||
self.app = None
|
self.app = None
|
||||||
self.collection = None
|
self.collection = None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class EmbeddingConfigurator:
|
|||||||
if embedder_config is None:
|
if embedder_config is None:
|
||||||
return self._create_default_embedding_function()
|
return self._create_default_embedding_function()
|
||||||
|
|
||||||
provider = embedder_config.get("provider")
|
provider = embedder_config.get("provider", "")
|
||||||
config = embedder_config.get("config", {})
|
config = embedder_config.get("config", {})
|
||||||
model_name = config.get("model")
|
model_name = config.get("model")
|
||||||
|
|
||||||
@@ -38,12 +38,13 @@ class EmbeddingConfigurator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
raise ValueError(f"Invalid custom embedding function: {str(e)}")
|
||||||
|
|
||||||
if provider not in self.embedding_functions:
|
embedding_function = self.embedding_functions.get(provider, None)
|
||||||
|
if not embedding_function:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.embedding_functions[provider](config, model_name)
|
return embedding_function(config, model_name)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_default_embedding_function():
|
def _create_default_embedding_function():
|
||||||
|
|||||||
Reference in New Issue
Block a user