Files
crewAI/src/crewai/rag/chromadb/utils.py
Greyson LaLonde ec1eff02a8
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
fix: achieve parity between rag package and current impl (#3418)
- Sanitize ChromaDB collection names and use original dir naming
- Add persistent client with file locking to the ChromaDB factory
- Add upsert support to the ChromaDB client
- Suppress ChromaDB deprecation warnings for `model_fields`
- Extract `suppress_logging` into shared `logger_utils`
- Update tests to reflect upsert behavior
- Docs: add additional note
2025-08-28 11:22:36 -04:00

281 lines
8.2 KiB
Python

"""Utility functions for ChromaDB client implementation."""
import hashlib
from collections.abc import Mapping
from typing import Literal, TypeGuard, cast
from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.types import (
Include,
IncludeEnum,
QueryResult,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from crewai.rag.chromadb.constants import (
DEFAULT_COLLECTION,
INVALID_CHARS_PATTERN,
IPV4_PATTERN,
MAX_COLLECTION_LENGTH,
MIN_COLLECTION_LENGTH,
)
from crewai.rag.chromadb.types import (
ChromaDBClientType,
ChromaDBCollectionSearchParams,
ExtractedSearchParams,
PreparedDocuments,
)
from crewai.rag.types import BaseRecord, SearchResult
def _is_sync_client(client: ChromaDBClientType) -> TypeGuard[ClientAPI]:
"""Type guard to check if the client is a synchronous ClientAPI.
Args:
client: The client to check.
Returns:
True if the client is a ClientAPI, False otherwise.
"""
return isinstance(client, ClientAPI)
def _is_async_client(client: ChromaDBClientType) -> TypeGuard[AsyncClientAPI]:
"""Type guard to check if the client is an asynchronous AsyncClientAPI.
Args:
client: The client to check.
Returns:
True if the client is an AsyncClientAPI, False otherwise.
"""
return isinstance(client, AsyncClientAPI)
def _prepare_documents_for_chromadb(
documents: list[BaseRecord],
) -> PreparedDocuments:
"""Prepare documents for ChromaDB by extracting IDs, texts, and metadata.
Args:
documents: List of BaseRecord documents to prepare.
Returns:
PreparedDocuments with ids, texts, and metadatas ready for ChromaDB.
"""
ids: list[str] = []
texts: list[str] = []
metadatas: list[Mapping[str, str | int | float | bool]] = []
for doc in documents:
if "doc_id" in doc:
ids.append(doc["doc_id"])
else:
content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16]
ids.append(content_hash)
texts.append(doc["content"])
metadata = doc.get("metadata")
if metadata:
if isinstance(metadata, list):
metadatas.append(metadata[0] if metadata else {})
else:
metadatas.append(metadata)
else:
metadatas.append({})
return PreparedDocuments(ids, texts, metadatas)
def _extract_search_params(
kwargs: ChromaDBCollectionSearchParams,
) -> ExtractedSearchParams:
"""Extract search parameters from kwargs.
Args:
kwargs: Keyword arguments containing search parameters.
Returns:
ExtractedSearchParams with all extracted parameters.
"""
return ExtractedSearchParams(
collection_name=kwargs["collection_name"],
query=kwargs["query"],
limit=kwargs.get("limit", 10),
metadata_filter=kwargs.get("metadata_filter"),
score_threshold=kwargs.get("score_threshold"),
where=kwargs.get("where"),
where_document=kwargs.get("where_document"),
include=kwargs.get(
"include",
[IncludeEnum.metadatas, IncludeEnum.documents, IncludeEnum.distances],
),
)
def _convert_distance_to_score(
distance: float,
distance_metric: Literal["l2", "cosine", "ip"],
) -> float:
"""Convert ChromaDB distance to similarity score.
Notes:
Assuming all embedding are unit-normalized for now, including custom embeddings.
Args:
distance: The distance value from ChromaDB.
distance_metric: The distance metric used ("l2", "cosine", or "ip").
Returns:
Similarity score in range [0, 1] where 1 is most similar.
"""
if distance_metric == "cosine":
score = 1.0 - 0.5 * distance
return max(0.0, min(1.0, score))
raise ValueError(f"Unsupported distance metric: {distance_metric}")
def _convert_chromadb_results_to_search_results(
results: QueryResult,
include: Include,
distance_metric: Literal["l2", "cosine", "ip"],
score_threshold: float | None = None,
) -> list[SearchResult]:
"""Convert ChromaDB query results to SearchResult format.
Args:
results: ChromaDB query results.
include: List of fields that were included in the query.
distance_metric: The distance metric used by the collection.
score_threshold: Optional minimum similarity score (0-1) for results.
Returns:
List of SearchResult dicts containing id, content, metadata, and score.
"""
search_results: list[SearchResult] = []
include_strings = [item.value for item in include]
ids = results["ids"][0] if results.get("ids") else []
documents_list = results.get("documents")
documents = (
documents_list[0] if documents_list and "documents" in include_strings else []
)
metadatas_list = results.get("metadatas")
metadatas = (
metadatas_list[0] if metadatas_list and "metadatas" in include_strings else []
)
distances_list = results.get("distances")
distances = (
distances_list[0] if distances_list and "distances" in include_strings else []
)
for i, doc_id in enumerate(ids):
if not distances or i >= len(distances):
continue
distance = distances[i]
score = _convert_distance_to_score(
distance=distance, distance_metric=distance_metric
)
if score_threshold and score < score_threshold:
continue
result: SearchResult = {
"id": doc_id,
"content": documents[i] if documents and i < len(documents) else "",
"metadata": dict(metadatas[i]) if metadatas and i < len(metadatas) else {},
"score": score,
}
search_results.append(result)
return search_results
def _process_query_results(
collection: Collection | AsyncCollection,
results: QueryResult,
params: ExtractedSearchParams,
) -> list[SearchResult]:
"""Process ChromaDB query results and convert to SearchResult format.
Args:
collection: The ChromaDB collection (sync or async) that was queried.
results: Raw query results from ChromaDB.
params: The search parameters used for the query.
Returns:
List of SearchResult dicts containing id, content, metadata, and score.
"""
distance_metric = cast(
Literal["l2", "cosine", "ip"],
collection.metadata.get("hnsw:space", "l2") if collection.metadata else "l2",
)
return _convert_chromadb_results_to_search_results(
results=results,
include=params.include,
distance_metric=distance_metric,
score_threshold=params.score_threshold,
)
def _is_ipv4_pattern(name: str) -> bool:
"""Check if a string matches an IPv4 address pattern.
Args:
name: The string to check
Returns:
True if the string matches an IPv4 pattern, False otherwise
"""
return bool(IPV4_PATTERN.match(name))
def _sanitize_collection_name(
name: str | None, max_collection_length: int = MAX_COLLECTION_LENGTH
) -> str:
"""Sanitize a collection name to meet ChromaDB requirements.
Requirements:
1. 3-63 characters long
2. Starts and ends with alphanumeric character
3. Contains only alphanumeric characters, underscores, or hyphens
4. No consecutive periods
5. Not a valid IPv4 address
Args:
name: The original collection name to sanitize
max_collection_length: Maximum allowed length for the collection name
Returns:
A sanitized collection name that meets ChromaDB requirements
"""
if not name:
return DEFAULT_COLLECTION
if _is_ipv4_pattern(name):
name = f"ip_{name}"
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
if not sanitized[0].isalnum():
sanitized = "a" + sanitized
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
if len(sanitized) < MIN_COLLECTION_LENGTH:
sanitized = sanitized + "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
if len(sanitized) > max_collection_length:
sanitized = sanitized[:max_collection_length]
if not sanitized[-1].isalnum():
sanitized = sanitized[:-1] + "z"
return sanitized