"""Utility functions for ChromaDB client implementation."""
|
|
|
|
import hashlib
|
|
import json
|
|
from collections.abc import Mapping
|
|
from typing import Literal, TypeGuard, cast
|
|
|
|
from chromadb.api import AsyncClientAPI, ClientAPI
|
|
from chromadb.api.models.AsyncCollection import AsyncCollection
|
|
from chromadb.api.models.Collection import Collection
|
|
from chromadb.api.types import (
|
|
Include,
|
|
QueryResult,
|
|
)
|
|
|
|
from crewai.rag.chromadb.constants import (
|
|
DEFAULT_COLLECTION,
|
|
INVALID_CHARS_PATTERN,
|
|
IPV4_PATTERN,
|
|
MAX_COLLECTION_LENGTH,
|
|
MIN_COLLECTION_LENGTH,
|
|
)
|
|
from crewai.rag.chromadb.types import (
|
|
ChromaDBClientType,
|
|
ChromaDBCollectionSearchParams,
|
|
ExtractedSearchParams,
|
|
PreparedDocuments,
|
|
)
|
|
from crewai.rag.types import BaseRecord, SearchResult
|
|
|
|
|
|
def _is_sync_client(client: ChromaDBClientType) -> TypeGuard[ClientAPI]:
|
|
"""Type guard to check if the client is a synchronous ClientAPI.
|
|
|
|
Args:
|
|
client: The client to check.
|
|
|
|
Returns:
|
|
True if the client is a ClientAPI, False otherwise.
|
|
"""
|
|
return isinstance(client, ClientAPI)
|
|
|
|
|
|
def _is_async_client(client: ChromaDBClientType) -> TypeGuard[AsyncClientAPI]:
|
|
"""Type guard to check if the client is an asynchronous AsyncClientAPI.
|
|
|
|
Args:
|
|
client: The client to check.
|
|
|
|
Returns:
|
|
True if the client is an AsyncClientAPI, False otherwise.
|
|
"""
|
|
return isinstance(client, AsyncClientAPI)
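
# Example (illustrative, not part of the module): the guards let callers
# branch on the union type while keeping static narrowing intact, e.g.
#
#     def _collection_count(client: ChromaDBClientType) -> int:
#         if _is_sync_client(client):
#             return client.count_collections()  # narrowed to ClientAPI
#         raise TypeError("async clients must use the await-based path")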


def _prepare_documents_for_chromadb(
    documents: list[BaseRecord],
) -> PreparedDocuments:
    """Prepare documents for ChromaDB by extracting IDs, texts, and metadata.

    Args:
        documents: List of BaseRecord documents to prepare.

    Returns:
        PreparedDocuments with ids, texts, and metadatas ready for ChromaDB.
    """
    ids: list[str] = []
    texts: list[str] = []
    metadatas: list[Mapping[str, str | int | float | bool]] = []

    for doc in documents:
        if "doc_id" in doc:
            ids.append(doc["doc_id"])
        else:
            content_for_hash = doc["content"]
            metadata = doc.get("metadata")
            if metadata:
                metadata_str = json.dumps(metadata, sort_keys=True)
                content_for_hash = f"{content_for_hash}|{metadata_str}"

            content_hash = hashlib.blake2b(
                content_for_hash.encode(), digest_size=32
            ).hexdigest()
            ids.append(content_hash)

        texts.append(doc["content"])
        metadata = doc.get("metadata")
        if metadata:
            if isinstance(metadata, list):
                metadatas.append(metadata[0] if metadata and metadata[0] else {})
            else:
                metadatas.append(metadata)
        else:
            metadatas.append({})

    return PreparedDocuments(ids, texts, metadatas)
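
# Example (illustrative): a record without an explicit "doc_id" receives a
# deterministic ID, the blake2b hash of its content joined with the
# sorted-key JSON of its metadata (when present), so re-adding the same
# record yields the same ID:
#
#     prepared = _prepare_documents_for_chromadb(
#         [
#             {"content": "hello", "metadata": {"source": "a"}},
#             {"doc_id": "doc-1", "content": "world"},
#         ]
#     )
#     # prepared.ids[1] == "doc-1"; prepared.ids[0] is a 64-char hex digest.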


def _create_batch_slice(
    prepared: PreparedDocuments, start_index: int, batch_size: int
) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]:
    """Create a batch slice from prepared documents.

    Args:
        prepared: PreparedDocuments containing ids, texts, and metadatas.
        start_index: Starting index for the batch.
        batch_size: Size of the batch.

    Returns:
        Tuple of (batch_ids, batch_texts, batch_metadatas).
    """
    batch_end = min(start_index + batch_size, len(prepared.ids))
    batch_ids = prepared.ids[start_index:batch_end]
    batch_texts = prepared.texts[start_index:batch_end]
    batch_metadatas = (
        prepared.metadatas[start_index:batch_end] if prepared.metadatas else None
    )

    if batch_metadatas and not any(m for m in batch_metadatas):
        batch_metadatas = None

    return batch_ids, batch_texts, batch_metadatas
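
# Example (illustrative): slicing 250 prepared documents with batch_size=100
# yields the ranges [0:100], [100:200], and [200:250]; a slice whose
# metadata entries are all empty is handed to ChromaDB as None instead:
#
#     ids, texts, metas = _create_batch_slice(prepared, 200, 100)
#     # len(ids) == len(texts) == 50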


def _extract_search_params(
    kwargs: ChromaDBCollectionSearchParams,
) -> ExtractedSearchParams:
    """Extract search parameters from kwargs.

    Args:
        kwargs: Keyword arguments containing search parameters.

    Returns:
        ExtractedSearchParams with all extracted parameters.
    """
    return ExtractedSearchParams(
        collection_name=kwargs["collection_name"],
        query=kwargs["query"],
        limit=kwargs.get("limit", 10),
        metadata_filter=kwargs.get("metadata_filter"),
        score_threshold=kwargs.get("score_threshold"),
        where=kwargs.get("where"),
        where_document=kwargs.get("where_document"),
        include=cast(
            Include,
            kwargs.get(
                "include",
                ["metadatas", "documents", "distances"],
            ),
        ),
    )


def _convert_distance_to_score(
    distance: float,
    distance_metric: Literal["l2", "cosine", "ip"],
) -> float:
    """Convert ChromaDB distance to similarity score.

    Notes:
        Assumes all embeddings are unit-normalized for now, including custom embeddings.

    Args:
        distance: The distance value from ChromaDB.
        distance_metric: The distance metric used ("l2", "cosine", or "ip").

    Returns:
        Similarity score in range [0, 1] where 1 is most similar.
    """
    if distance_metric == "cosine":
        score = 1.0 - 0.5 * distance
        return max(0.0, min(1.0, score))
    if distance_metric == "l2":
        score = 1.0 / (1.0 + distance)
        return max(0.0, min(1.0, score))
    raise ValueError(f"Unsupported distance metric: {distance_metric}")
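
# Worked examples (illustrative): for unit-normalized embeddings cosine
# distance falls in [0, 2], so the linear map 1 - d/2 sends identical
# vectors (d=0) to 1.0 and opposite vectors (d=2) to 0.0; L2 distance falls
# in [0, inf), so the decaying map 1/(1 + d) is used instead:
#
#     _convert_distance_to_score(0.5, "cosine")  # -> 0.75
#     _convert_distance_to_score(1.0, "l2")      # -> 0.5
#     _convert_distance_to_score(0.1, "ip")      # raises ValueError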


def _convert_chromadb_results_to_search_results(
    results: QueryResult,
    include: Include,
    distance_metric: Literal["l2", "cosine", "ip"],
    score_threshold: float | None = None,
) -> list[SearchResult]:
    """Convert ChromaDB query results to SearchResult format.

    Args:
        results: ChromaDB query results.
        include: List of fields that were included in the query.
        distance_metric: The distance metric used by the collection.
        score_threshold: Optional minimum similarity score (0-1) for results.

    Returns:
        List of SearchResult dicts containing id, content, metadata, and score.
    """
    search_results: list[SearchResult] = []

    include_strings = list(include) if include else []

    ids = results["ids"][0] if results.get("ids") else []

    documents_list = results.get("documents")
    documents = (
        documents_list[0] if documents_list and "documents" in include_strings else []
    )

    metadatas_list = results.get("metadatas")
    metadatas = (
        metadatas_list[0] if metadatas_list and "metadatas" in include_strings else []
    )

    distances_list = results.get("distances")
    distances = (
        distances_list[0] if distances_list and "distances" in include_strings else []
    )

    for i, doc_id in enumerate(ids):
        if not distances or i >= len(distances):
            continue

        distance = distances[i]
        score = _convert_distance_to_score(
            distance=distance, distance_metric=distance_metric
        )

        if score_threshold and score < score_threshold:
            continue

        result: SearchResult = {
            "id": doc_id,
            "content": documents[i] if documents and i < len(documents) else "",
            "metadata": dict(metadatas[i])
            if metadatas and i < len(metadatas) and metadatas[i] is not None
            else {},
            "score": score,
        }
        search_results.append(result)

    return search_results
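
# Example (illustrative): ChromaDB nests one list of hits per query text, so
# a single-query result such as
#
#     {"ids": [["a"]], "documents": [["hello"]], "metadatas": [[None]],
#      "distances": [[0.5]]}
#
# queried with include=["documents", "metadatas", "distances"] against a
# cosine collection converts to:
#
#     [{"id": "a", "content": "hello", "metadata": {}, "score": 0.75}]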


def _process_query_results(
    collection: Collection | AsyncCollection,
    results: QueryResult,
    params: ExtractedSearchParams,
) -> list[SearchResult]:
    """Process ChromaDB query results and convert to SearchResult format.

    Args:
        collection: The ChromaDB collection (sync or async) that was queried.
        results: Raw query results from ChromaDB.
        params: The search parameters used for the query.

    Returns:
        List of SearchResult dicts containing id, content, metadata, and score.
    """
    distance_metric = cast(
        Literal["l2", "cosine", "ip"],
        collection.metadata.get("hnsw:space", "l2") if collection.metadata else "l2",
    )

    return _convert_chromadb_results_to_search_results(
        results=results,
        include=params.include,
        distance_metric=distance_metric,
        score_threshold=params.score_threshold,
    )
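
# Note (illustrative): the metric comes from the collection's "hnsw:space"
# metadata entry, which ChromaDB records at creation time, e.g.
#
#     collection = client.get_or_create_collection(
#         name="docs", metadata={"hnsw:space": "cosine"}
#     )
#
# Collections created without explicit metadata fall back to "l2" here.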


def _is_ipv4_pattern(name: str) -> bool:
    """Check if a string matches an IPv4 address pattern.

    Args:
        name: The string to check.

    Returns:
        True if the string matches an IPv4 pattern, False otherwise.
    """
    return bool(IPV4_PATTERN.match(name))


def _sanitize_collection_name(
    name: str | None, max_collection_length: int = MAX_COLLECTION_LENGTH
) -> str:
    """Sanitize a collection name to meet ChromaDB requirements.

    Requirements:
        1. 3-63 characters long
        2. Starts and ends with an alphanumeric character
        3. Contains only alphanumeric characters, underscores, or hyphens
        4. No consecutive periods
        5. Not a valid IPv4 address

    Args:
        name: The original collection name to sanitize.
        max_collection_length: Maximum allowed length for the collection name.

    Returns:
        A sanitized collection name that meets ChromaDB requirements.
    """
    if not name:
        return DEFAULT_COLLECTION

    if _is_ipv4_pattern(name):
        name = f"ip_{name}"

    sanitized = INVALID_CHARS_PATTERN.sub("_", name)

    if not sanitized[0].isalnum():
        sanitized = "a" + sanitized

    if not sanitized[-1].isalnum():
        sanitized = sanitized[:-1] + "z"

    if len(sanitized) < MIN_COLLECTION_LENGTH:
        sanitized += "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
    if len(sanitized) > max_collection_length:
        sanitized = sanitized[:max_collection_length]
        if not sanitized[-1].isalnum():
            sanitized = sanitized[:-1] + "z"

    return sanitized
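
# Examples (illustrative, assuming INVALID_CHARS_PATTERN replaces characters
# outside [a-zA-Z0-9_-] and MIN_COLLECTION_LENGTH is 3): invalid names are
# repaired rather than rejected:
#
#     _sanitize_collection_name(None)           # -> DEFAULT_COLLECTION
#     _sanitize_collection_name("my crew!")     # -> "my_crewz"
#     _sanitize_collection_name("192.168.0.1")  # -> "ip_192_168_0_1"
#     _sanitize_collection_name("ab")           # -> "abx"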