Files
crewAI/src/crewai/memory/storage/rag_storage.py
Greyson LaLonde ce5ea9be6f feat: add custom embedding types and migrate providers
- introduce baseembeddingsprovider and helper for embedding functions  
- add core embedding types and migrate providers, factory, and storage modules  
- remove unused type aliases and fix pydantic schema error  
- update providers with env var support and related fixes
2025-09-25 18:28:39 -04:00

205 lines
7.2 KiB
Python

import logging
import traceback
import warnings
from typing import Any, cast
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.factory import build_embedder
from crewai.rag.embeddings.types import ProviderSpec
from crewai.rag.factory import create_client
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.types import BaseRecord
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
class RAGStorage(BaseRAGStorage):
"""
Extends Storage to handle embeddings for memory entries, improving
search efficiency.
"""
def __init__(
self,
type: str,
allow_reset: bool = True,
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
crew: Any = None,
path: str | None = None,
) -> None:
super().__init__(type, allow_reset, embedder_config, crew)
agents = crew.agents if crew else []
agents = [self._sanitize_role(agent.role) for agent in agents]
agents = "_".join(agents)
self.agents = agents
self.storage_file_name = self._build_storage_file_name(type, agents)
self.type = type
self._client: BaseClient | None = None
self.allow_reset = allow_reset
self.path = path
warnings.filterwarnings(
"ignore",
message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)",
)
if self.embedder_config:
embedding_function = build_embedder(self.embedder_config)
try:
_ = embedding_function(["test"])
except Exception as e:
provider = (
self.embedder_config["provider"]
if isinstance(self.embedder_config, dict)
else self.embedder_config.__class__.__name__.replace(
"Provider", ""
).lower()
)
raise ValueError(
f"Failed to initialize embedder. Please check your configuration or connection.\n"
f"Provider: {provider}\n"
f"Error: {e}"
) from e
batch_size = None
if (
isinstance(self.embedder_config, dict)
and "config" in self.embedder_config
):
nested_config = self.embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
),
batch_size=cast(int, batch_size),
)
else:
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
)
self._client = create_client(config)
def _get_client(self) -> BaseClient:
"""Get the appropriate client - instance-specific or global."""
return self._client if self._client else get_rag_client()
def _sanitize_role(self, role: str) -> str:
"""
Sanitizes agent roles to ensure valid directory names.
"""
return role.replace("\n", "").replace(" ", "_").replace("/", "_")
def _build_storage_file_name(self, type: str, file_name: str) -> str:
"""
Ensures file name does not exceed max allowed by OS
"""
base_path = f"{db_storage_path()}/{type}"
if len(file_name) > MAX_FILE_NAME_LENGTH:
logging.warning(
f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
)
file_name = file_name[:MAX_FILE_NAME_LENGTH]
return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: dict[str, Any]) -> None:
try:
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
client.get_or_create_collection(collection_name=collection_name)
document: BaseRecord = {"content": value}
if metadata:
document["metadata"] = metadata
batch_size = None
if (
self.embedder_config
and isinstance(self.embedder_config, dict)
and "config" in self.embedder_config
):
nested_config = self.embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
client.add_documents(
collection_name=collection_name,
documents=[document],
batch_size=cast(int, batch_size),
)
else:
client.add_documents(
collection_name=collection_name, documents=[document]
)
except Exception as e:
logging.error(
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
)
def search(
self,
query: str,
limit: int = 5,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[Any]:
try:
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
return client.search(
collection_name=collection_name,
query=query,
limit=limit,
metadata_filter=filter,
score_threshold=score_threshold,
)
except Exception as e:
logging.error(
f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}"
)
return []
def reset(self) -> None:
try:
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
client.delete_collection(collection_name=collection_name)
except Exception as e:
if "attempt to write a readonly database" in str(
e
) or "does not exist" in str(e):
# Ignore readonly database and collection not found errors (already reset)
pass
else:
raise Exception(
f"An error occurred while resetting the {self.type} memory: {e}"
) from e