mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
231 lines
8.1 KiB
Python
231 lines
8.1 KiB
Python
import contextlib
|
|
import io
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from chromadb.api import ClientAPI, Collection
|
|
from chromadb.api.types import Documents, Embeddings, Metadatas
|
|
|
|
from crewai.memory.storage.base_rag_storage import BaseRAGStorage
|
|
from crewai.utilities import EmbeddingConfigurator
|
|
from crewai.utilities.constants import MAX_FILE_NAME_LENGTH
|
|
from crewai.utilities.paths import db_storage_path
|
|
from crewai.utilities.exceptions.embedding_exceptions import (
|
|
EmbeddingConfigurationError,
|
|
EmbeddingInitializationError
|
|
)
|
|
|
|
|
|
@contextlib.contextmanager
def suppress_logging(
    logger_name="chromadb.segment.impl.vector.local_persistent_hnsw",
    level=logging.ERROR,
):
    """Temporarily quiet a logger and swallow stdout/stderr output.

    While the context is active the named logger is set to ``level``, anything
    written to ``sys.stdout`` / ``sys.stderr`` is redirected into throw-away
    in-memory buffers, and a ``UserWarning`` raised as an exception inside the
    block is suppressed.

    Args:
        logger_name: Name of the logger whose level is raised for the duration
            of the block.
        level: Logging level applied to that logger inside the block.

    Yields:
        None. The context manager is used purely for its side effects.
    """
    logger = logging.getLogger(logger_name)
    # Remember the logger's *own* level (possibly NOTSET) rather than its
    # effective level, so a logger that inherited its level from a parent goes
    # back to inheriting afterwards instead of being pinned to a fixed value.
    original_level = logger.level
    logger.setLevel(level)
    try:
        with (
            contextlib.redirect_stdout(io.StringIO()),
            contextlib.redirect_stderr(io.StringIO()),
            # Only affects UserWarning propagating as an exception; it does not
            # install a warnings filter.
            contextlib.suppress(UserWarning),
        ):
            yield
    finally:
        # Restore the level even when the wrapped block raises.
        logger.setLevel(original_level)
|
|
|
|
|
|
class RAGStorage(BaseRAGStorage):
    """RAG-based Storage implementation using ChromaDB for vector storage and retrieval.

    This class extends BaseRAGStorage to handle embeddings for memory entries,
    improving search efficiency through vector similarity.

    Attributes:
        app: ChromaDB client instance; None after a successful reset()
        collection: ChromaDB collection for storing embeddings; None after a
            successful reset()
        type: Type of memory storage (also used as the collection name)
        allow_reset: Whether memory reset is allowed
        path: Custom storage path for the database
    """

    app: ClientAPI | None = None
    collection: Any = None
    # Set to True once ``embedder_config`` has been replaced by the
    # configurator's result, so later re-initialization does not configure twice.
    _embedder_configured: bool = False

    def __init__(
        self,
        type: str,
        allow_reset: bool = True,
        embedder_config: Dict[str, Any] | None = None,
        crew: Any = None,
        path: str | None = None,
    ):
        """Create the storage and eagerly open the ChromaDB client and collection.

        Args:
            type: Memory type; used as the collection name and in the storage path.
            allow_reset: Passed to the ChromaDB settings and consulted by reset().
            embedder_config: Embedder configuration handed to EmbeddingConfigurator.
            crew: Optional crew; the sanitized roles of ``crew.agents`` are joined
                with "_" to build the storage directory name.
            path: Optional database path that overrides the derived storage path.
        """
        super().__init__(type, allow_reset, embedder_config, crew)
        agents = crew.agents if crew else []
        agents = [self._sanitize_role(agent.role) for agent in agents]
        agents = "_".join(agents)
        self.agents = agents
        self.storage_file_name = self._build_storage_file_name(type, agents)

        self.type = type
        self.allow_reset = allow_reset
        self.path = path
        self._initialize_app()

    def _set_embedder_config(self):
        """Replace ``embedder_config`` with the configurator's result, once.

        The configured result overwrites ``self.embedder_config``, so a second
        call would pass that result back into ``configure_embedder``. Whether
        the configurator accepts its own output is not visible from this file,
        so configuration is performed only on the first call.
        """
        if self._embedder_configured:
            return
        configurator = EmbeddingConfigurator()
        self.embedder_config = configurator.configure_embedder(self.embedder_config)
        self._embedder_configured = True

    def _initialize_app(self) -> None:
        """Initialize the ChromaDB client and collection.

        Raises:
            RuntimeError: If ChromaDB client or collection initialization fails
            EmbeddingConfigurationError: If embedding configuration is invalid
            EmbeddingInitializationError: If embedding function fails to initialize
        """
        import chromadb
        from chromadb.config import Settings

        # Runs outside the try block so configuration errors propagate unwrapped.
        self._set_embedder_config()
        try:
            self.app = chromadb.PersistentClient(
                path=self.path if self.path else self.storage_file_name,
                settings=Settings(allow_reset=self.allow_reset),
            )
            if not self.app:
                raise RuntimeError("Failed to initialize ChromaDB client")

            try:
                self.collection = self.app.get_collection(
                    name=self.type, embedding_function=self.embedder_config
                )
            except Exception:
                # Any failure to fetch the collection falls back to creating it.
                self.collection = self.app.create_collection(
                    name=self.type, embedding_function=self.embedder_config
                )
        except Exception as e:
            raise RuntimeError(f"Failed to initialize ChromaDB: {str(e)}") from e

    def _ensure_initialized(self) -> None:
        """(Re)initialize the client and collection when either one is missing.

        ``app`` and ``collection`` are class-level attributes, so they always
        exist; the meaningful "not initialized" state is ``None``, which is
        what reset() leaves behind.
        """
        if self.app is None or self.collection is None:
            self._initialize_app()

    def _sanitize_role(self, role: str) -> str:
        """
        Sanitizes agent roles to ensure valid directory names.

        Newlines are removed; spaces and forward slashes become underscores.
        """
        return role.replace("\n", "").replace(" ", "_").replace("/", "_")

    def _build_storage_file_name(self, type: str, file_name: str) -> str:
        """
        Ensures file name does not exceed max allowed by OS

        Args:
            type: Memory type, used as a sub-directory of the DB storage path.
            file_name: Directory name; truncated to MAX_FILE_NAME_LENGTH
                characters (with a logged warning) when longer.

        Returns:
            The path ``<db_storage_path()>/<type>/<file_name>``.
        """
        base_path = f"{db_storage_path()}/{type}"

        if len(file_name) > MAX_FILE_NAME_LENGTH:
            logging.warning(
                f"Trimming file name from {len(file_name)} to {MAX_FILE_NAME_LENGTH} characters."
            )
            file_name = file_name[:MAX_FILE_NAME_LENGTH]

        return f"{base_path}/{file_name}"

    def save(self, value: Any, metadata: Dict[str, Any]) -> None:
        """Save a value with metadata to the memory storage.

        Args:
            value: The text content to store
            metadata: Additional metadata for the stored content

        Raises:
            EmbeddingInitializationError: If storing the value fails, including
                a failed re-initialization after reset()
        """
        try:
            # _generate_embedding re-opens the client/collection when needed.
            self._generate_embedding(value, metadata)
        except Exception as e:
            raise EmbeddingInitializationError(self.type, str(e)) from e

    def search(
        self,
        query: str,
        limit: int = 3,
        filter: Optional[dict] = None,
        score_threshold: float = 0.35,
    ) -> List[Dict[str, Any]]:
        """Search for similar content in memory.

        Args:
            query: The search query text
            limit: Maximum number of results requested from the collection
            filter: Optional filter criteria; accepted for interface
                compatibility but not applied to the query
            score_threshold: Results whose score is below this value are dropped

        Returns:
            List of dicts with keys "id", "metadata", "context" and "score".
            "score" is the distance value reported by ChromaDB. An empty list is
            returned (and the error logged) if anything fails.
        """
        try:
            # Inside the try block so a failed re-initialization is logged and
            # reported as "no results" like every other search failure.
            self._ensure_initialized()

            with suppress_logging():
                response = self.collection.query(query_texts=query, n_results=limit)

            # A single query text is sent, so only the first result row is read.
            results = []
            for entry_id, entry_metadata, document, distance in zip(
                response["ids"][0],
                response["metadatas"][0],
                response["documents"][0],
                response["distances"][0],
            ):
                # NOTE(review): the value compared is a distance, where smaller
                # usually means closer - confirm that ">=" is the intended test.
                if distance >= score_threshold:
                    results.append(
                        {
                            "id": entry_id,
                            "metadata": entry_metadata,
                            "context": document,
                            "score": distance,
                        }
                    )

            return results
        except Exception as e:
            logging.error(f"Error during {self.type} search: {str(e)}")
            return []

    def _generate_embedding(self, text: str, metadata: Optional[Dict[str, Any]] = None) -> Any:
        """Store the given text in the collection, which embeds it on insert.

        Args:
            text: The text to store
            metadata: Optional additional metadata to store with the text; an
                empty dict is used when omitted

        Returns:
            Always None; the entry is only stored, under a fresh UUID4 id.

        Raises:
            EmbeddingInitializationError: If (re)initialization or the insert fails
        """
        try:
            self._ensure_initialized()
            self.collection.add(
                documents=[text],
                metadatas=[metadata or {}],
                ids=[str(uuid.uuid4())],
            )
            return None
        except Exception as e:
            raise EmbeddingInitializationError(
                self.type, f"Failed to generate embedding: {str(e)}"
            ) from e

    def reset(self) -> None:
        """Reset the memory storage by clearing the database and removing files.

        When a client is open, it is reset, the ``<storage path>/<type>``
        directory is removed if present, and ``app`` / ``collection`` are set to
        None so the next save() or search() re-initializes them.

        Raises:
            RuntimeError: If memory reset fails and allow_reset is False
        """
        try:
            if self.app:
                self.app.reset()
                storage_path = self.path if self.path else db_storage_path()
                db_dir = os.path.join(storage_path, self.type)
                if os.path.exists(db_dir):
                    shutil.rmtree(db_dir)
                self.app = None
                self.collection = None
        except Exception as e:
            if "attempt to write a readonly database" in str(e):
                # Ignore this specific error as it's expected in some environments
                return
            if not self.allow_reset:
                raise RuntimeError(f"Failed to reset {self.type} memory: {str(e)}") from e
            logging.error(f"Error during {self.type} memory reset: {str(e)}")
|