Files
crewAI/src/crewai/memory/storage/rag_storage.py
2025-02-09 19:43:24 +00:00

231 lines
8.1 KiB
Python

import contextlib
import io
import logging
import os
import shutil
import uuid
from typing import Any, Dict, List, Optional
from chromadb.api import ClientAPI, Collection
from chromadb.api.types import Documents, Embeddings, Metadatas
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
from crewai.utilities import EmbeddingConfigurator
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
from crewai.utilities.paths import db_storage_path
from crewai.utilities.exceptions.embedding_exceptions import (
EmbeddingConfigurationError,
EmbeddingInitializationError
)
@contextlib.contextmanager
def suppress_logging(
    logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
    level=logging.ERROR,
):
    """Temporarily quiet a noisy logger and discard stdout/stderr output.

    While the context is active:

    * the named logger's level is set to *level*;
    * anything written to ``sys.stdout`` / ``sys.stderr`` is captured into a
      throw-away buffer;
    * a ``UserWarning`` *raised as an exception* inside the block is swallowed.

    The logger's original level is restored on exit, including when the body
    raises.

    Args:
        logger_name: Name of the logger to quiet. Defaults to ChromaDB's
            persistent HNSW segment logger.
        level: Level applied to the logger for the duration of the block.

    Yields:
        None.
    """
    logger = logging.getLogger(logger_name)
    # Save the logger's *own* level (NOTSET means "inherit from parent"), not
    # getEffectiveLevel(). Restoring the effective level would pin an explicit
    # level on the logger, so it would stop following its parents afterwards.
    original_level = logger.level
    logger.setLevel(level)
    try:
        # NOTE: contextlib.suppress only catches UserWarning when it is raised as
        # an exception; warnings emitted via warnings.warn() are not affected.
        with (
            contextlib.redirect_stdout(io.StringIO()),
            contextlib.redirect_stderr(io.StringIO()),
            contextlib.suppress(UserWarning),
        ):
            yield
    finally:
        # Always restore, even if the body raised.
        logger.setLevel(original_level)
class RAGStorage(BaseRAGStorage):
    """RAG-based Storage implementation using ChromaDB for vector storage and retrieval.

    This class extends BaseRAGStorage to handle embeddings for memory entries,
    improving search efficiency through vector similarity.

    Attributes:
        app: ChromaDB client instance (``None`` before initialization and after
            ``reset``)
        collection: ChromaDB collection for storing embeddings (``None`` before
            initialization and after ``reset``)
        type: Type of memory storage; also used as the collection name
        allow_reset: Whether memory reset is allowed
        path: Custom storage path for the database
        agents: Underscore-joined, sanitized roles of the crew's agents
        storage_file_name: Default on-disk location used when ``path`` is not set
    """

    app: ClientAPI | None = None
    collection: Any = None

    def __init__(
        self,
        type: str,
        allow_reset: bool = True,
        embedder_config: Dict[str, Any] | None = None,
        crew: Any = None,
        path: str | None = None,
    ):
        """Create the storage and eagerly connect to ChromaDB.

        Args:
            type: Memory type. Used as the collection name and as a
                sub-directory of the default storage location.
            allow_reset: Forwarded to ChromaDB ``Settings(allow_reset=...)``.
                Also decides whether ``reset`` failures raise or are only logged.
            embedder_config: Raw embedder configuration. It is passed to the base
                class and later resolved through ``EmbeddingConfigurator``.
            crew: Optional crew whose agents' roles make up the storage file name.
            path: Optional directory that overrides the default storage location.
        """
        super().__init__(type, allow_reset, embedder_config, crew)
        roles = (
            [self._sanitize_role(agent.role) for agent in crew.agents] if crew else []
        )
        self.agents = "_".join(roles)
        self.storage_file_name = self._build_storage_file_name(type, self.agents)
        self.type = type
        self.allow_reset = allow_reset
        self.path = path
        self._initialize_app()

    def _set_embedder_config(self):
        """Replace ``self.embedder_config`` with the embedding function built from it."""
        configurator = EmbeddingConfigurator()
        self.embedder_config = configurator.configure_embedder(self.embedder_config)

    def _ensure_initialized(self) -> None:
        """(Re)initialize the client and collection if either is missing.

        ``app`` and ``collection`` are class attributes that default to ``None``,
        so ``hasattr`` is always true for them. Checking for ``None`` is what
        detects the state left behind by ``reset()``.
        """
        if self.app is None or self.collection is None:
            self._initialize_app()

    def _initialize_app(self) -> None:
        """Initialize the ChromaDB client and collection.

        Raises:
            RuntimeError: If ChromaDB client or collection initialization fails.
            EmbeddingConfigurationError / EmbeddingInitializationError: May
                propagate from ``EmbeddingConfigurator`` while resolving the
                embedder configuration.
        """
        import chromadb
        from chromadb.config import Settings

        # configure_embedder is designed to take the user's raw config. When
        # re-initializing after reset(), self.embedder_config already holds the
        # built embedding function. Nothing here shows that the configurator
        # accepts its own output, so the configuration step runs only once.
        if not getattr(self, "_embedder_configured", False):
            self._set_embedder_config()
            self._embedder_configured = True

        try:
            self.app = chromadb.PersistentClient(
                path=self.path if self.path else self.storage_file_name,
                settings=Settings(allow_reset=self.allow_reset),
            )
            if not self.app:
                raise RuntimeError("Failed to initialize ChromaDB client")
            try:
                self.collection = self.app.get_collection(
                    name=self.type, embedding_function=self.embedder_config
                )
            except Exception:
                # The collection could not be opened (e.g. it does not exist yet):
                # create it.
                self.collection = self.app.create_collection(
                    name=self.type, embedding_function=self.embedder_config
                )
        except Exception as e:
            raise RuntimeError(f"Failed to initialize ChromaDB: {str(e)}") from e

    def _sanitize_role(self, role: str) -> str:
        """Make an agent role usable as part of a directory name.

        Removes newlines and replaces spaces and ``/`` with ``_``.
        """
        return role.replace("\n", "").replace(" ", "_").replace("/", "_")

    def _build_storage_file_name(self, type: str, file_name: str) -> str:
        """Return ``<db_storage_path>/<type>/<file_name>``.

        *file_name* is trimmed (with a warning) to ``MAX_FILE_NAME_LENGTH``
        characters so it does not exceed the OS limit.
        """
        base_path = f"{db_storage_path()}/{type}"
        if len(file_name) > MAX_FILE_NAME_LENGTH:
            logging.warning(
                f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
            )
            file_name = file_name[:MAX_FILE_NAME_LENGTH]
        return f"{base_path}/{file_name}"

    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
        """Save a value with metadata to the memory storage.

        Re-initializes the client and collection first if they were cleared,
        e.g. by ``reset()``.

        Args:
            value: The text content to store
            metadata: Additional metadata for the stored content

        Raises:
            EmbeddingInitializationError: If embedding generation fails
            RuntimeError: If lazy re-initialization of ChromaDB fails
        """
        self._ensure_initialized()
        try:
            self._generate_embedding(value, metadata)
        except EmbeddingInitializationError:
            # Already carries context from _generate_embedding; don't wrap it twice.
            raise
        except Exception as e:
            raise EmbeddingInitializationError(self.type, str(e)) from e

    def search(
        self,
        query: str,
        limit: int = 3,
        filter: Optional[dict] = None,
        score_threshold: float = 0.35,
    ) -> List[Dict[str, Any]]:
        """Search for similar content in memory.

        Args:
            query: The search query text
            limit: Maximum number of results to request from ChromaDB
            filter: Currently unused; accepted for interface compatibility
            score_threshold: Results whose ``score`` (the ChromaDB distance) is
                below this value are dropped

        Returns:
            List of dicts with ``id``, ``metadata``, ``context`` and ``score``
            keys. Returns an empty list if the query fails (the error is logged).
        """
        self._ensure_initialized()
        try:
            with suppress_logging():
                response = self.collection.query(query_texts=query, n_results=limit)

            results = []
            # Index [0]: only one query text is sent, so only the first
            # per-query result list is used.
            for doc_id, metadata, document, distance in zip(
                response["ids"][0],
                response["metadatas"][0],
                response["documents"][0],
                response["distances"][0],
            ):
                # NOTE(review): "score" is a ChromaDB distance (lower usually
                # means more similar), yet results are kept when it is
                # >= score_threshold. Closer matches may therefore be dropped.
                # Behavior is preserved here; confirm the intended semantics.
                if distance >= score_threshold:
                    results.append(
                        {
                            "id": doc_id,
                            "metadata": metadata,
                            "context": document,
                            "score": distance,
                        }
                    )
            return results
        except Exception as e:
            # Best-effort: a failed search degrades to "no results".
            logging.error(f"Error during {self.type} search: {str(e)}")
            return []

    def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
        """Add *text* as a new document to the collection under a fresh UUID.

        No embedding is passed to ``collection.add``. Presumably the collection's
        embedding function (``self.embedder_config``) computes it.

        Args:
            text: The text to store and embed
            metadata: Optional metadata stored with the document (``{}`` if None)

        Returns:
            None

        Raises:
            EmbeddingInitializationError: If adding the document fails
        """
        self._ensure_initialized()
        try:
            self.collection.add(
                documents=[text],
                metadatas=[metadata or {}],
                ids=[str(uuid.uuid4())],
            )
        except Exception as e:
            raise EmbeddingInitializationError(
                self.type, f"Failed to generate embedding: {str(e)}"
            ) from e
        return None

    def reset(self) -> None:
        """Reset the memory storage by clearing the database and removing files.

        Calls ``app.reset()``, deletes the on-disk directory and clears ``app``
        and ``collection``. The next ``save``/``search`` re-initializes them
        lazily.

        Raises:
            RuntimeError: If the reset fails and ``allow_reset`` is False.
                Otherwise the failure is logged. An "attempt to write a readonly
                database" error is always ignored.
        """
        try:
            if self.app:
                self.app.reset()
            # NOTE(review): Without a custom path this removes
            # <db_storage_path>/<type>. That directory is the parent of this
            # crew's storage_file_name, so it also holds other crews' storage of
            # this type. With a custom path it removes <path>/<type>, which is
            # not the directory the client was opened on (self.path). Confirm
            # both are intended.
            storage_path = self.path if self.path else db_storage_path()
            db_dir = os.path.join(storage_path, self.type)
            if os.path.exists(db_dir):
                shutil.rmtree(db_dir)
            self.app = None
            self.collection = None
        except Exception as e:
            if "attempt to write a readonly database" in str(e):
                # Ignore this specific error as it's expected in some environments
                return
            if not self.allow_reset:
                raise RuntimeError(f"Failed to reset {self.type} memory: {str(e)}") from e
            logging.error(f"Error during {self.type} memory reset: {str(e)}")