mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
- introduce BaseEmbeddingsProvider and a helper for embedding functions
- add core embedding types and migrate providers, factory, and storage modules
- remove unused type aliases and fix a Pydantic schema error
- update providers with env var support and related fixes
205 lines
7.2 KiB
Python
205 lines
7.2 KiB
Python
import logging
|
|
import traceback
|
|
import warnings
|
|
from typing import Any, cast
|
|
|
|
from crewai.rag.chromadb.config import ChromaDBConfig
|
|
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
|
from crewai.rag.config.utils import get_rag_client
|
|
from crewai.rag.core.base_client import BaseClient
|
|
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
|
from crewai.rag.embeddings.factory import build_embedder
|
|
from crewai.rag.embeddings.types import ProviderSpec
|
|
from crewai.rag.factory import create_client
|
|
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
|
from crewai.rag.types import BaseRecord
|
|
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
|
from crewai.utilities.paths import db_storage_path
|
|
|
|
|
|
class RAGStorage(BaseRAGStorage):
    """
    Extends Storage to handle embeddings for memory entries, improving
    search efficiency.

    Every operation targets a single collection whose name is derived from
    ``type`` and the sanitized roles of the crew's agents. When an
    ``embedder_config`` is supplied the instance creates its own client around
    that embedder; otherwise it uses the client returned by
    ``get_rag_client()``.
    """

    def __init__(
        self,
        type: str,
        allow_reset: bool = True,
        embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
        crew: Any = None,
        path: str | None = None,
    ) -> None:
        """
        Set up the storage and, if an embedder is configured, its own client.

        Args:
            type: Label used in the collection name and the storage file name.
            allow_reset: Stored on the instance and forwarded to the base class.
            embedder_config: Provider spec (dict) or provider instance passed
                to ``build_embedder``. When falsy, no instance-specific client
                is created.
            crew: Optional object exposing ``agents``; each agent's ``role`` is
                sanitized and joined with ``_`` to identify this storage.
            path: Stored on the instance as ``self.path``.

        Raises:
            ValueError: If the built embedding function fails when probed with
                a sample input.
        """
        super().__init__(type, allow_reset, embedder_config, crew)

        crew_agents = crew.agents if crew else []
        agent_roles = "_".join(
            self._sanitize_role(agent.role) for agent in crew_agents
        )
        self.agents = agent_roles
        self.storage_file_name = self._build_storage_file_name(type, agent_roles)

        self.type = type
        self._client: BaseClient | None = None

        self.allow_reset = allow_reset
        self.path = path

        # Silence the 'model_fields' deprecation warning emitted from chromadb
        # modules.
        warnings.filterwarnings(
            "ignore",
            message=r".*'model_fields'.*is deprecated.*",
            module=r"^chromadb(\.|$)",
        )

        if not self.embedder_config:
            # No embedder configured: _get_client() will fall back to the
            # client returned by get_rag_client().
            return

        embedding_function = build_embedder(self.embedder_config)

        # Probe the embedder once so a bad configuration or connection fails
        # here, with a descriptive message, instead of on first use.
        try:
            _ = embedding_function(["test"])
        except Exception as e:
            if isinstance(self.embedder_config, dict):
                provider = self.embedder_config["provider"]
            else:
                # Derive a provider label from the class name,
                # e.g. "FooProvider" -> "foo".
                provider = self.embedder_config.__class__.__name__.replace(
                    "Provider", ""
                ).lower()
            raise ValueError(
                f"Failed to initialize embedder. Please check your configuration or connection.\n"
                f"Provider: {provider}\n"
                f"Error: {e}"
            ) from e

        wrapped_function = cast(ChromaEmbeddingFunctionWrapper, embedding_function)
        batch_size = self._configured_batch_size()
        # Only pass batch_size when the user set one, so ChromaDBConfig keeps
        # its own default otherwise.
        if batch_size is None:
            config = ChromaDBConfig(embedding_function=wrapped_function)
        else:
            config = ChromaDBConfig(
                embedding_function=wrapped_function,
                batch_size=batch_size,
            )
        self._client = create_client(config)

    def _get_client(self) -> BaseClient:
        """Get the appropriate client - instance-specific or global."""
        return self._client if self._client else get_rag_client()

    def _memory_collection_name(self) -> str:
        """Return the collection name shared by save, search and reset."""
        if self.agents:
            return f"memory_{self.type}_{self.agents}"
        return f"memory_{self.type}"

    def _configured_batch_size(self) -> int | None:
        """
        Return ``embedder_config["config"]["batch_size"]`` when present.

        Returns ``None`` when the embedder config is not a dict, has no dict
        under ``"config"``, or that dict has no ``"batch_size"`` entry.
        """
        spec = self.embedder_config
        if not isinstance(spec, dict):
            return None
        nested_config = spec.get("config")
        if not isinstance(nested_config, dict):
            return None
        batch_size = nested_config.get("batch_size")
        if batch_size is None:
            return None
        return cast(int, batch_size)

    def _sanitize_role(self, role: str) -> str:
        """
        Sanitizes agent roles to ensure valid directory names.

        Newlines are removed; spaces and forward slashes become underscores.
        """
        return role.replace("\n", "").replace(" ", "_").replace("/", "_")

    def _build_storage_file_name(self, type: str, file_name: str) -> str:
        """
        Ensures file name does not exceed max allowed by OS

        Args:
            type: Sub-directory name placed under ``db_storage_path()``.
            file_name: Final path component; truncated to
                ``MAX_FILE_NAME_LENGTH`` characters (with a logged warning)
                when longer.

        Returns:
            The string ``"<db_storage_path()>/<type>/<file_name>"``.
        """
        base_path = f"{db_storage_path()}/{type}"

        if len(file_name) > MAX_FILE_NAME_LENGTH:
            logging.warning(
                f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
            )
            file_name = file_name[:MAX_FILE_NAME_LENGTH]

        return f"{base_path}/{file_name}"

    def save(self, value: Any, metadata: dict[str, Any]) -> None:
        """
        Add ``value`` as a single document to this storage's collection.

        The collection is created if needed. ``metadata`` is attached only
        when it is non-empty. Any exception is logged with its traceback and
        not re-raised, so saving is best-effort.
        """
        try:
            client = self._get_client()
            collection_name = self._memory_collection_name()
            client.get_or_create_collection(collection_name=collection_name)

            document: BaseRecord = {"content": value}
            if metadata:
                document["metadata"] = metadata

            batch_size = self._configured_batch_size()
            # Only pass batch_size when the user set one, leaving the client's
            # own default in effect otherwise.
            if batch_size is None:
                client.add_documents(
                    collection_name=collection_name, documents=[document]
                )
            else:
                client.add_documents(
                    collection_name=collection_name,
                    documents=[document],
                    batch_size=batch_size,
                )
        except Exception as e:
            logging.error(
                f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
            )

    def search(
        self,
        query: str,
        limit: int = 5,
        filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        """
        Search this storage's collection.

        Args:
            query: Forwarded to the client as ``query``.
            limit: Forwarded to the client as ``limit``.
            filter: Forwarded to the client as ``metadata_filter``.
            score_threshold: Forwarded to the client as ``score_threshold``.

        Returns:
            Whatever ``client.search`` returns, or an empty list if any
            exception occurs (the error is logged with its traceback).
        """
        try:
            client = self._get_client()
            return client.search(
                collection_name=self._memory_collection_name(),
                query=query,
                limit=limit,
                metadata_filter=filter,
                score_threshold=score_threshold,
            )
        except Exception as e:
            logging.error(
                f"Error during {self.type} search: {e!s}\n{traceback.format_exc()}"
            )
            return []

    def reset(self) -> None:
        """
        Delete this storage's collection.

        Raises:
            Exception: If deletion fails for any reason other than a read-only
                database or a collection that does not exist.
        """
        try:
            client = self._get_client()
            client.delete_collection(
                collection_name=self._memory_collection_name()
            )
        except Exception as e:
            error_text = str(e)
            if (
                "attempt to write a readonly database" in error_text
                or "does not exist" in error_text
            ):
                # Ignore readonly database and collection not found errors (already reset)
                return
            raise Exception(
                f"An error occurred while resetting the {self.type} memory: {e}"
            ) from e
|