From 842bed4e9caf7b9751ba3173d6b3e84d23523242 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 21 Aug 2025 18:18:46 -0400 Subject: [PATCH] feat: chromadb generic client (#3374) Add ChromaDB client implementation with async support - Implement core collection operations (create, get_or_create, delete) - Add search functionality with cosine similarity scoring - Include both sync and async method variants - Add type safety with NamedTuples and TypeGuards - Extract utility functions to separate modules - Default to cosine distance metric for text similarity - Add comprehensive test coverage TODO: - l2, ip score calculations are not settled on --- src/crewai/rag/chromadb/__init__.py | 0 src/crewai/rag/chromadb/client.py | 556 ++++++++++++++++++++++++++++ src/crewai/rag/chromadb/types.py | 85 +++++ src/crewai/rag/chromadb/utils.py | 220 +++++++++++ tests/rag/__init__.py | 0 tests/rag/chromadb/__init__.py | 0 tests/rag/chromadb/test_client.py | 550 +++++++++++++++++++++++++++ 7 files changed, 1411 insertions(+) create mode 100644 src/crewai/rag/chromadb/__init__.py create mode 100644 src/crewai/rag/chromadb/client.py create mode 100644 src/crewai/rag/chromadb/types.py create mode 100644 src/crewai/rag/chromadb/utils.py create mode 100644 tests/rag/__init__.py create mode 100644 tests/rag/chromadb/__init__.py create mode 100644 tests/rag/chromadb/test_client.py diff --git a/src/crewai/rag/chromadb/__init__.py b/src/crewai/rag/chromadb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/crewai/rag/chromadb/client.py b/src/crewai/rag/chromadb/client.py new file mode 100644 index 000000000..ea67cd2fb --- /dev/null +++ b/src/crewai/rag/chromadb/client.py @@ -0,0 +1,556 @@ +"""ChromaDB client implementation.""" + +from typing import Any + +from chromadb.api.types import ( + Embeddable, + EmbeddingFunction as ChromaEmbeddingFunction, + QueryResult, +) +from typing_extensions import Unpack + +from crewai.rag.chromadb.types import ( + ChromaDBClientType, + ChromaDBCollectionCreateParams, + ChromaDBCollectionSearchParams, +) +from crewai.rag.chromadb.utils import ( + _extract_search_params, + _is_async_client, + _is_sync_client, + _prepare_documents_for_chromadb, + _process_query_results, +) +from crewai.rag.core.base_client import ( + BaseClient, + BaseCollectionParams, + BaseCollectionAddParams, +) +from crewai.rag.types import SearchResult + + +class ChromaDBClient(BaseClient): + """ChromaDB implementation of the BaseClient protocol. + + Provides vector database operations for ChromaDB, supporting both + synchronous and asynchronous clients. + + Attributes: + client: ChromaDB client instance (ClientAPI or AsyncClientAPI). + embedding_function: Function to generate embeddings for documents. + """ + + client: ChromaDBClientType + embedding_function: ChromaEmbeddingFunction[Embeddable] + + def create_collection( + self, **kwargs: Unpack[ChromaDBCollectionCreateParams] + ) -> None: + """Create a new collection in ChromaDB. + + Uses the client's default embedding function if none provided. + + Keyword Args: + collection_name: Name of the collection to create. Must be unique. + configuration: Optional collection configuration specifying distance metrics, + HNSW parameters, or other backend-specific settings. + metadata: Optional metadata dictionary to attach to the collection. + embedding_function: Optional custom embedding function. If not provided, + uses the client's default embedding function. + data_loader: Optional data loader for batch loading data into the collection. + get_or_create: If True, returns existing collection if it already exists + instead of raising an error. Defaults to False. + + Raises: + TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations. + ValueError: If collection with the same name already exists and get_or_create + is False. + ConnectionError: If unable to connect to ChromaDB server. + + Example: + >>> client = ChromaDBClient() + >>> client.create_collection( + ... collection_name="documents", + ... metadata={"description": "Product documentation"}, + ... get_or_create=True + ... ) + """ + if not _is_sync_client(self.client): + raise TypeError( + "Synchronous method create_collection() requires a ClientAPI. " + "Use acreate_collection() for AsyncClientAPI." + ) + + metadata = kwargs.get("metadata", {}) + if "hnsw:space" not in metadata: + metadata["hnsw:space"] = "cosine" + + self.client.create_collection( + name=kwargs["collection_name"], + configuration=kwargs.get("configuration"), + metadata=metadata, + embedding_function=kwargs.get( + "embedding_function", self.embedding_function + ), + data_loader=kwargs.get("data_loader"), + get_or_create=kwargs.get("get_or_create", False), + ) + + async def acreate_collection( + self, **kwargs: Unpack[ChromaDBCollectionCreateParams] + ) -> None: + """Create a new collection in ChromaDB asynchronously. + + Creates a new collection with the specified name and optional configuration. + If an embedding function is not provided, uses the client's default embedding function. + + Keyword Args: + collection_name: Name of the collection to create. Must be unique. + configuration: Optional collection configuration specifying distance metrics, + HNSW parameters, or other backend-specific settings. + metadata: Optional metadata dictionary to attach to the collection. + embedding_function: Optional custom embedding function. If not provided, + uses the client's default embedding function. + data_loader: Optional data loader for batch loading data into the collection. + get_or_create: If True, returns existing collection if it already exists + instead of raising an error. Defaults to False. + + Raises: + TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations. + ValueError: If collection with the same name already exists and get_or_create + is False. + ConnectionError: If unable to connect to ChromaDB server. + + Example: + >>> import asyncio + >>> async def main(): + ... client = ChromaDBClient() + ... await client.acreate_collection( + ... collection_name="documents", + ... metadata={"description": "Product documentation"}, + ... get_or_create=True + ... ) + >>> asyncio.run(main()) + """ + if not _is_async_client(self.client): + raise TypeError( + "Asynchronous method acreate_collection() requires an AsyncClientAPI. " + "Use create_collection() for ClientAPI." + ) + + metadata = kwargs.get("metadata", {}) + if "hnsw:space" not in metadata: + metadata["hnsw:space"] = "cosine" + + await self.client.create_collection( + name=kwargs["collection_name"], + configuration=kwargs.get("configuration"), + metadata=metadata, + embedding_function=kwargs.get( + "embedding_function", self.embedding_function + ), + data_loader=kwargs.get("data_loader"), + get_or_create=kwargs.get("get_or_create", False), + ) + + def get_or_create_collection( + self, **kwargs: Unpack[ChromaDBCollectionCreateParams] + ) -> Any: + """Get an existing collection or create it if it doesn't exist. + + Returns existing collection if found, otherwise creates a new one. + + Keyword Args: + collection_name: Name of the collection to get or create. + configuration: Optional collection configuration specifying distance metrics, + HNSW parameters, or other backend-specific settings. + metadata: Optional metadata dictionary to attach to the collection. + embedding_function: Optional custom embedding function. If not provided, + uses the client's default embedding function. + data_loader: Optional data loader for batch loading data into the collection. + + Returns: + A ChromaDB Collection object. + + Raises: + TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations. + ConnectionError: If unable to connect to ChromaDB server. + + Example: + >>> client = ChromaDBClient() + >>> collection = client.get_or_create_collection( + ... collection_name="documents", + ... metadata={"description": "Product documentation"} + ... ) + """ + if not _is_sync_client(self.client): + raise TypeError( + "Synchronous method get_or_create_collection() requires a ClientAPI. " + "Use aget_or_create_collection() for AsyncClientAPI." + ) + + metadata = kwargs.get("metadata", {}) + if "hnsw:space" not in metadata: + metadata["hnsw:space"] = "cosine" + + return self.client.get_or_create_collection( + name=kwargs["collection_name"], + configuration=kwargs.get("configuration"), + metadata=metadata, + embedding_function=kwargs.get( + "embedding_function", self.embedding_function + ), + data_loader=kwargs.get("data_loader"), + ) + + async def aget_or_create_collection( + self, **kwargs: Unpack[ChromaDBCollectionCreateParams] + ) -> Any: + """Get an existing collection or create it if it doesn't exist asynchronously. + + Returns existing collection if found, otherwise creates a new one. + + Keyword Args: + collection_name: Name of the collection to get or create. + configuration: Optional collection configuration specifying distance metrics, + HNSW parameters, or other backend-specific settings. + metadata: Optional metadata dictionary to attach to the collection. + embedding_function: Optional custom embedding function. If not provided, + uses the client's default embedding function. + data_loader: Optional data loader for batch loading data into the collection. + + Returns: + A ChromaDB AsyncCollection object. + + Raises: + TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations. + ConnectionError: If unable to connect to ChromaDB server. + + Example: + >>> import asyncio + >>> async def main(): + ... client = ChromaDBClient() + ... collection = await client.aget_or_create_collection( + ... collection_name="documents", + ... metadata={"description": "Product documentation"} + ... ) + >>> asyncio.run(main()) + """ + if not _is_async_client(self.client): + raise TypeError( + "Asynchronous method aget_or_create_collection() requires an AsyncClientAPI. " + "Use get_or_create_collection() for ClientAPI." + ) + + metadata = kwargs.get("metadata", {}) + if "hnsw:space" not in metadata: + metadata["hnsw:space"] = "cosine" + + return await self.client.get_or_create_collection( + name=kwargs["collection_name"], + configuration=kwargs.get("configuration"), + metadata=metadata, + embedding_function=kwargs.get( + "embedding_function", self.embedding_function + ), + data_loader=kwargs.get("data_loader"), + ) + + def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None: + """Add documents with their embeddings to a collection. + + Performs an upsert operation - documents with existing IDs are updated. + Generates embeddings automatically using the configured embedding function. + + Keyword Args: + collection_name: The name of the collection to add documents to. + documents: List of BaseRecord dicts containing: + - content: The text content (required) + - doc_id: Optional unique identifier (auto-generated if missing) + - metadata: Optional metadata dictionary + + Raises: + TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations. + ValueError: If collection doesn't exist or documents list is empty. + ConnectionError: If unable to connect to ChromaDB server. + """ + if not _is_sync_client(self.client): + raise TypeError( + "Synchronous method add_documents() requires a ClientAPI. " + "Use aadd_documents() for AsyncClientAPI." + ) + + collection_name = kwargs["collection_name"] + documents = kwargs["documents"] + + if not documents: + raise ValueError("Documents list cannot be empty") + + collection = self.client.get_collection( + name=collection_name, + embedding_function=self.embedding_function, + ) + + prepared = _prepare_documents_for_chromadb(documents) + collection.add( + ids=prepared.ids, + documents=prepared.texts, + metadatas=prepared.metadatas, + ) + + async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None: + """Add documents with their embeddings to a collection asynchronously. + + Performs an upsert operation - documents with existing IDs are updated. + Generates embeddings automatically using the configured embedding function. + + Keyword Args: + collection_name: The name of the collection to add documents to. + documents: List of BaseRecord dicts containing: + - content: The text content (required) + - doc_id: Optional unique identifier (auto-generated if missing) + - metadata: Optional metadata dictionary + + Raises: + TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations. + ValueError: If collection doesn't exist or documents list is empty. + ConnectionError: If unable to connect to ChromaDB server. + """ + if not _is_async_client(self.client): + raise TypeError( + "Asynchronous method aadd_documents() requires an AsyncClientAPI. " + "Use add_documents() for ClientAPI." + ) + + collection_name = kwargs["collection_name"] + documents = kwargs["documents"] + + if not documents: + raise ValueError("Documents list cannot be empty") + + collection = await self.client.get_collection( + name=collection_name, + embedding_function=self.embedding_function, + ) + prepared = _prepare_documents_for_chromadb(documents) + await collection.add( + ids=prepared.ids, + documents=prepared.texts, + metadatas=prepared.metadatas, + ) + + def search( + self, **kwargs: Unpack[ChromaDBCollectionSearchParams] + ) -> list[SearchResult]: + """Search for similar documents using a query. + + Performs semantic search to find documents similar to the query text. + Uses the configured embedding function to generate query embeddings. + + Keyword Args: + collection_name: Name of the collection to search in. + query: The text query to search for. + limit: Maximum number of results to return (default: 10). + metadata_filter: Optional filter for metadata fields. + score_threshold: Optional minimum similarity score (0-1) for results. + where: Optional ChromaDB where clause for metadata filtering. + where_document: Optional ChromaDB where clause for document content filtering. + include: Optional list of fields to include in results. + + Returns: + List of SearchResult dicts containing id, content, metadata, and score. + + Raises: + TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations. + ValueError: If collection doesn't exist. + ConnectionError: If unable to connect to ChromaDB server. + """ + if not _is_sync_client(self.client): + raise TypeError( + "Synchronous method search() requires a ClientAPI. " + "Use asearch() for AsyncClientAPI." + ) + + params = _extract_search_params(kwargs) + + collection = self.client.get_collection( + name=params.collection_name, + embedding_function=self.embedding_function, + ) + + where = params.where if params.where is not None else params.metadata_filter + + results: QueryResult = collection.query( + query_texts=[params.query], + n_results=params.limit, + where=where, + where_document=params.where_document, + include=params.include, + ) + + return _process_query_results( + collection=collection, + results=results, + params=params, + ) + + async def asearch( + self, **kwargs: Unpack[ChromaDBCollectionSearchParams] + ) -> list[SearchResult]: + """Search for similar documents using a query asynchronously. + + Performs semantic search to find documents similar to the query text. + Uses the configured embedding function to generate query embeddings. + + Keyword Args: + collection_name: Name of the collection to search in. + query: The text query to search for. + limit: Maximum number of results to return (default: 10). + metadata_filter: Optional filter for metadata fields. + score_threshold: Optional minimum similarity score (0-1) for results. + where: Optional ChromaDB where clause for metadata filtering. + where_document: Optional ChromaDB where clause for document content filtering. + include: Optional list of fields to include in results. + + Returns: + List of SearchResult dicts containing id, content, metadata, and score. + + Raises: + TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations. + ValueError: If collection doesn't exist. + ConnectionError: If unable to connect to ChromaDB server. + """ + if not _is_async_client(self.client): + raise TypeError( + "Asynchronous method asearch() requires an AsyncClientAPI. " + "Use search() for ClientAPI." + ) + + params = _extract_search_params(kwargs) + + collection = await self.client.get_collection( + name=params.collection_name, + embedding_function=self.embedding_function, + ) + + where = params.where if params.where is not None else params.metadata_filter + + results: QueryResult = await collection.query( + query_texts=[params.query], + n_results=params.limit, + where=where, + where_document=params.where_document, + include=params.include, + ) + + return _process_query_results( + collection=collection, + results=results, + params=params, + ) + + def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: + """Delete a collection and all its data. + + Permanently removes a collection and all documents, embeddings, and metadata it contains. + This operation cannot be undone. + + Keyword Args: + collection_name: Name of the collection to delete. + + Raises: + TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations. + ValueError: If collection doesn't exist. + ConnectionError: If unable to connect to ChromaDB server. + + Example: + >>> client = ChromaDBClient() + >>> client.delete_collection(collection_name="old_documents") + """ + if not _is_sync_client(self.client): + raise TypeError( + "Synchronous method delete_collection() requires a ClientAPI. " + "Use adelete_collection() for AsyncClientAPI." + ) + + collection_name = kwargs["collection_name"] + self.client.delete_collection(name=collection_name) + + async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: + """Delete a collection and all its data asynchronously. + + Permanently removes a collection and all documents, embeddings, and metadata it contains. + This operation cannot be undone. + + Keyword Args: + collection_name: Name of the collection to delete. + + Raises: + TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations. + ValueError: If collection doesn't exist. + ConnectionError: If unable to connect to ChromaDB server. + + Example: + >>> import asyncio + >>> async def main(): + ... client = ChromaDBClient() + ... await client.adelete_collection(collection_name="old_documents") + >>> asyncio.run(main()) + """ + if not _is_async_client(self.client): + raise TypeError( + "Asynchronous method adelete_collection() requires an AsyncClientAPI. " + "Use delete_collection() for ClientAPI." + ) + + collection_name = kwargs["collection_name"] + await self.client.delete_collection(name=collection_name) + + def reset(self) -> None: + """Reset the vector database by deleting all collections and data. + + Completely clears the ChromaDB instance, removing all collections, + documents, embeddings, and metadata. This operation cannot be undone. + Use with extreme caution in production environments. + + Raises: + TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations. + ConnectionError: If unable to connect to ChromaDB server. + + Example: + >>> client = ChromaDBClient() + >>> client.reset() # Removes ALL data from ChromaDB + """ + if not _is_sync_client(self.client): + raise TypeError( + "Synchronous method reset() requires a ClientAPI. " + "Use areset() for AsyncClientAPI." + ) + + self.client.reset() + + async def areset(self) -> None: + """Reset the vector database by deleting all collections and data asynchronously. + + Completely clears the ChromaDB instance, removing all collections, + documents, embeddings, and metadata. This operation cannot be undone. + Use with extreme caution in production environments. + + Raises: + TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations. + ConnectionError: If unable to connect to ChromaDB server. + + Example: + >>> import asyncio + >>> async def main(): + ... client = ChromaDBClient() + ... await client.areset() # Removes ALL data from ChromaDB + >>> asyncio.run(main()) + """ + if not _is_async_client(self.client): + raise TypeError( + "Asynchronous method areset() requires an AsyncClientAPI. " + "Use reset() for ClientAPI." + ) + + await self.client.reset() diff --git a/src/crewai/rag/chromadb/types.py b/src/crewai/rag/chromadb/types.py new file mode 100644 index 000000000..54a03df39 --- /dev/null +++ b/src/crewai/rag/chromadb/types.py @@ -0,0 +1,85 @@ +"""Type definitions specific to ChromaDB implementation.""" + +from collections.abc import Mapping +from typing import Any, NamedTuple + +from chromadb.api import ClientAPI, AsyncClientAPI +from chromadb.api.configuration import CollectionConfigurationInterface +from chromadb.api.types import ( + CollectionMetadata, + DataLoader, + Embeddable, + EmbeddingFunction as ChromaEmbeddingFunction, + Include, + Loadable, + Where, + WhereDocument, +) + +from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams + +ChromaDBClientType = ClientAPI | AsyncClientAPI + + +class PreparedDocuments(NamedTuple): + """Prepared documents ready for ChromaDB insertion. + + Attributes: + ids: List of document IDs + texts: List of document texts + metadatas: List of document metadata mappings + """ + + ids: list[str] + texts: list[str] + metadatas: list[Mapping[str, str | int | float | bool]] + + +class ExtractedSearchParams(NamedTuple): + """Extracted search parameters for ChromaDB queries. + + Attributes: + collection_name: Name of the collection to search + query: Search query text + limit: Maximum number of results + metadata_filter: Optional metadata filter + score_threshold: Optional minimum similarity score + where: Optional ChromaDB where clause + where_document: Optional ChromaDB document filter + include: Fields to include in results + """ + + collection_name: str + query: str + limit: int + metadata_filter: dict[str, Any] | None + score_threshold: float | None + where: Where | None + where_document: WhereDocument | None + include: Include + + +class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False): + """Parameters for creating a ChromaDB collection. + + This class extends BaseCollectionParams to include any additional + parameters specific to ChromaDB collection creation. + """ + + configuration: CollectionConfigurationInterface + metadata: CollectionMetadata + embedding_function: ChromaEmbeddingFunction[Embeddable] + data_loader: DataLoader[Loadable] + get_or_create: bool + + +class ChromaDBCollectionSearchParams(BaseCollectionSearchParams, total=False): + """Parameters for searching a ChromaDB collection. + + This class extends BaseCollectionSearchParams to include ChromaDB-specific + search parameters like where clauses and include options. + """ + + where: Where + where_document: WhereDocument + include: Include diff --git a/src/crewai/rag/chromadb/utils.py b/src/crewai/rag/chromadb/utils.py new file mode 100644 index 000000000..f7e5c4ebd --- /dev/null +++ b/src/crewai/rag/chromadb/utils.py @@ -0,0 +1,220 @@ +"""Utility functions for ChromaDB client implementation.""" + +import hashlib +from collections.abc import Mapping +from typing import Literal, TypeGuard, cast + +from chromadb.api import AsyncClientAPI, ClientAPI +from chromadb.api.types import ( + Include, + IncludeEnum, + QueryResult, +) + +from chromadb.api.models.AsyncCollection import AsyncCollection +from chromadb.api.models.Collection import Collection + +from crewai.rag.chromadb.types import ( + ChromaDBClientType, + ChromaDBCollectionSearchParams, + ExtractedSearchParams, + PreparedDocuments, +) +from crewai.rag.types import BaseRecord, SearchResult + + +def _is_sync_client(client: ChromaDBClientType) -> TypeGuard[ClientAPI]: + """Type guard to check if the client is a synchronous ClientAPI. + + Args: + client: The client to check. + + Returns: + True if the client is a ClientAPI, False otherwise. + """ + return isinstance(client, ClientAPI) + + +def _is_async_client(client: ChromaDBClientType) -> TypeGuard[AsyncClientAPI]: + """Type guard to check if the client is an asynchronous AsyncClientAPI. + + Args: + client: The client to check. + + Returns: + True if the client is an AsyncClientAPI, False otherwise. + """ + return isinstance(client, AsyncClientAPI) + + +def _prepare_documents_for_chromadb( + documents: list[BaseRecord], +) -> PreparedDocuments: + """Prepare documents for ChromaDB by extracting IDs, texts, and metadata. + + Args: + documents: List of BaseRecord documents to prepare. + + Returns: + PreparedDocuments with ids, texts, and metadatas ready for ChromaDB. + """ + ids: list[str] = [] + texts: list[str] = [] + metadatas: list[Mapping[str, str | int | float | bool]] = [] + + for doc in documents: + if "doc_id" in doc: + ids.append(doc["doc_id"]) + else: + content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16] + ids.append(content_hash) + + texts.append(doc["content"]) + metadata = doc.get("metadata") + if metadata: + if isinstance(metadata, list): + metadatas.append(metadata[0] if metadata else {}) + else: + metadatas.append(metadata) + else: + metadatas.append({}) + + return PreparedDocuments(ids, texts, metadatas) + + +def _extract_search_params( + kwargs: ChromaDBCollectionSearchParams, +) -> ExtractedSearchParams: + """Extract search parameters from kwargs. + + Args: + kwargs: Keyword arguments containing search parameters. + + Returns: + ExtractedSearchParams with all extracted parameters. + """ + return ExtractedSearchParams( + collection_name=kwargs["collection_name"], + query=kwargs["query"], + limit=kwargs.get("limit", 10), + metadata_filter=kwargs.get("metadata_filter"), + score_threshold=kwargs.get("score_threshold"), + where=kwargs.get("where"), + where_document=kwargs.get("where_document"), + include=kwargs.get( + "include", + [IncludeEnum.metadatas, IncludeEnum.documents, IncludeEnum.distances], + ), + ) + + +def _convert_distance_to_score( + distance: float, + distance_metric: Literal["l2", "cosine", "ip"], +) -> float: + """Convert ChromaDB distance to similarity score. + + Notes: + Assuming all embedding are unit-normalized for now, including custom embeddings. + + Args: + distance: The distance value from ChromaDB. + distance_metric: The distance metric used ("l2", "cosine", or "ip"). + + Returns: + Similarity score in range [0, 1] where 1 is most similar. + """ + if distance_metric == "cosine": + score = 1.0 - 0.5 * distance + return max(0.0, min(1.0, score)) + raise ValueError(f"Unsupported distance metric: {distance_metric}") + + +def _convert_chromadb_results_to_search_results( + results: QueryResult, + include: Include, + distance_metric: Literal["l2", "cosine", "ip"], + score_threshold: float | None = None, +) -> list[SearchResult]: + """Convert ChromaDB query results to SearchResult format. + + Args: + results: ChromaDB query results. + include: List of fields that were included in the query. + distance_metric: The distance metric used by the collection. + score_threshold: Optional minimum similarity score (0-1) for results. + + Returns: + List of SearchResult dicts containing id, content, metadata, and score. + """ + search_results: list[SearchResult] = [] + + include_strings = [item.value for item in include] + + ids = results["ids"][0] if results.get("ids") else [] + + documents_list = results.get("documents") + documents = ( + documents_list[0] if documents_list and "documents" in include_strings else [] + ) + + metadatas_list = results.get("metadatas") + metadatas = ( + metadatas_list[0] if metadatas_list and "metadatas" in include_strings else [] + ) + + distances_list = results.get("distances") + distances = ( + distances_list[0] if distances_list and "distances" in include_strings else [] + ) + + for i, doc_id in enumerate(ids): + if not distances or i >= len(distances): + continue + + distance = distances[i] + score = _convert_distance_to_score( + distance=distance, distance_metric=distance_metric + ) + + if score_threshold and score < score_threshold: + continue + + result: SearchResult = { + "id": doc_id, + "content": documents[i] if documents and i < len(documents) else "", + "metadata": dict(metadatas[i]) if metadatas and i < len(metadatas) else {}, + "score": score, + } + search_results.append(result) + + return search_results + + +def _process_query_results( + collection: Collection | AsyncCollection, + results: QueryResult, + params: ExtractedSearchParams, +) -> list[SearchResult]: + """Process ChromaDB query results and convert to SearchResult format. + + Args: + collection: The ChromaDB collection (sync or async) that was queried. + results: Raw query results from ChromaDB. + params: The search parameters used for the query. + + Returns: + List of SearchResult dicts containing id, content, metadata, and score. + """ + + distance_metric = cast( + Literal["l2", "cosine", "ip"], + collection.metadata.get("hnsw:space", "l2") if collection.metadata else "l2", + ) + + return _convert_chromadb_results_to_search_results( + results=results, + include=params.include, + distance_metric=distance_metric, + score_threshold=params.score_threshold, + ) diff --git a/tests/rag/__init__.py b/tests/rag/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/rag/chromadb/__init__.py b/tests/rag/chromadb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/rag/chromadb/test_client.py b/tests/rag/chromadb/test_client.py new file mode 100644 index 000000000..67d58614e --- /dev/null +++ b/tests/rag/chromadb/test_client.py @@ -0,0 +1,550 @@ +"""Tests for ChromaDBClient implementation.""" + +from unittest.mock import AsyncMock, Mock + +import pytest + +from crewai.rag.chromadb.client import ChromaDBClient +from crewai.rag.types import BaseRecord + + +@pytest.fixture +def mock_chromadb_client(): + """Create a mock ChromaDB client.""" + from chromadb.api import ClientAPI + + return Mock(spec=ClientAPI) + + +@pytest.fixture +def mock_async_chromadb_client(): + """Create a mock async ChromaDB client.""" + from chromadb.api import AsyncClientAPI + + return Mock(spec=AsyncClientAPI) + + +@pytest.fixture +def client(mock_chromadb_client) -> ChromaDBClient: + """Create a ChromaDBClient instance for testing.""" + client = ChromaDBClient() + client.client = mock_chromadb_client + client.embedding_function = Mock() + return client + + +@pytest.fixture +def async_client(mock_async_chromadb_client) -> ChromaDBClient: + """Create a ChromaDBClient instance with async client for testing.""" + client = ChromaDBClient() + client.client = mock_async_chromadb_client + client.embedding_function = Mock() + return client + + +class TestChromaDBClient: + """Test suite for ChromaDBClient.""" + + def test_create_collection(self, client, mock_chromadb_client): + """Test that create_collection calls the underlying client correctly.""" + client.create_collection(collection_name="test_collection") + + mock_chromadb_client.create_collection.assert_called_once_with( + name="test_collection", + configuration=None, + metadata={"hnsw:space": "cosine"}, + embedding_function=client.embedding_function, + data_loader=None, + get_or_create=False, + ) + + def test_create_collection_with_all_params(self, client, mock_chromadb_client): + """Test create_collection with all optional parameters.""" + mock_config = Mock() + mock_metadata = {"key": "value"} + mock_embedding_func = Mock() + mock_data_loader = Mock() + + client.create_collection( + collection_name="test_collection", + configuration=mock_config, + metadata=mock_metadata, + embedding_function=mock_embedding_func, + data_loader=mock_data_loader, + get_or_create=True, + ) + + mock_chromadb_client.create_collection.assert_called_once_with( + name="test_collection", + configuration=mock_config, + metadata=mock_metadata, + embedding_function=mock_embedding_func, + data_loader=mock_data_loader, + get_or_create=True, + ) + + @pytest.mark.asyncio + async def test_acreate_collection( + self, async_client, mock_async_chromadb_client + ) -> None: + """Test that acreate_collection calls the underlying client correctly.""" + # Make the mock's create_collection an AsyncMock + mock_async_chromadb_client.create_collection = AsyncMock(return_value=None) + + await async_client.acreate_collection(collection_name="test_collection") + + mock_async_chromadb_client.create_collection.assert_called_once_with( + name="test_collection", + configuration=None, + metadata={"hnsw:space": "cosine"}, + embedding_function=async_client.embedding_function, + data_loader=None, + get_or_create=False, + ) + + @pytest.mark.asyncio + async def test_acreate_collection_with_all_params( + self, async_client, mock_async_chromadb_client + ) -> None: + """Test acreate_collection with all optional parameters.""" + # Make the mock's create_collection an AsyncMock + mock_async_chromadb_client.create_collection = AsyncMock(return_value=None) + + mock_config = Mock() + mock_metadata = {"key": "value"} + mock_embedding_func = Mock() + mock_data_loader = Mock() + + await async_client.acreate_collection( + collection_name="test_collection", + configuration=mock_config, + metadata=mock_metadata, + embedding_function=mock_embedding_func, + data_loader=mock_data_loader, + get_or_create=True, + ) + + mock_async_chromadb_client.create_collection.assert_called_once_with( + name="test_collection", + configuration=mock_config, + metadata=mock_metadata, + embedding_function=mock_embedding_func, + data_loader=mock_data_loader, + get_or_create=True, + ) + + def test_get_or_create_collection(self, client, mock_chromadb_client): + """Test that get_or_create_collection calls the underlying client correctly.""" + mock_collection = Mock() + mock_chromadb_client.get_or_create_collection.return_value = mock_collection + + result = client.get_or_create_collection(collection_name="test_collection") + + mock_chromadb_client.get_or_create_collection.assert_called_once_with( + name="test_collection", + configuration=None, + metadata={"hnsw:space": "cosine"}, + embedding_function=client.embedding_function, + data_loader=None, + ) + assert result == mock_collection + + def test_get_or_create_collection_with_all_params( + self, client, mock_chromadb_client + ): + """Test get_or_create_collection with all optional parameters.""" + mock_collection = Mock() + mock_chromadb_client.get_or_create_collection.return_value = mock_collection + mock_config = Mock() + mock_metadata = {"key": "value"} + mock_embedding_func = Mock() + mock_data_loader = Mock() + + result = client.get_or_create_collection( + collection_name="test_collection", + configuration=mock_config, + metadata=mock_metadata, + embedding_function=mock_embedding_func, + data_loader=mock_data_loader, + ) + + mock_chromadb_client.get_or_create_collection.assert_called_once_with( + name="test_collection", + configuration=mock_config, + metadata=mock_metadata, + embedding_function=mock_embedding_func, + data_loader=mock_data_loader, + ) + assert result == mock_collection + + @pytest.mark.asyncio + async def test_aget_or_create_collection( + self, async_client, mock_async_chromadb_client + ) -> None: + """Test that aget_or_create_collection calls the underlying client correctly.""" + mock_collection = Mock() + mock_async_chromadb_client.get_or_create_collection = AsyncMock( + return_value=mock_collection + ) + + result = await async_client.aget_or_create_collection( + collection_name="test_collection" + ) + + mock_async_chromadb_client.get_or_create_collection.assert_called_once_with( + name="test_collection", + configuration=None, + metadata={"hnsw:space": "cosine"}, + embedding_function=async_client.embedding_function, + data_loader=None, + ) + assert result == mock_collection + + @pytest.mark.asyncio + async def test_aget_or_create_collection_with_all_params( + self, async_client, mock_async_chromadb_client + ) -> None: + """Test aget_or_create_collection with all optional parameters.""" + mock_collection = Mock() + mock_async_chromadb_client.get_or_create_collection = AsyncMock( + return_value=mock_collection + ) + mock_config = Mock() + mock_metadata = {"key": "value"} + mock_embedding_func = Mock() + mock_data_loader = Mock() + + result = await async_client.aget_or_create_collection( + collection_name="test_collection", + configuration=mock_config, + metadata=mock_metadata, + embedding_function=mock_embedding_func, + data_loader=mock_data_loader, + ) + + mock_async_chromadb_client.get_or_create_collection.assert_called_once_with( + name="test_collection", + configuration=mock_config, + metadata=mock_metadata, + embedding_function=mock_embedding_func, + data_loader=mock_data_loader, + ) + assert result == mock_collection + + def test_add_documents(self, client, mock_chromadb_client) -> None: + """Test that add_documents adds documents to collection.""" + mock_collection = Mock() + mock_chromadb_client.get_collection.return_value = mock_collection + + documents: list[BaseRecord] = [ + { + "content": "Test document", + "metadata": {"source": "test"}, + } + ] + + client.add_documents(collection_name="test_collection", documents=documents) + + mock_chromadb_client.get_collection.assert_called_once_with( + name="test_collection", + embedding_function=client.embedding_function, + ) + + # Verify documents were added to collection + mock_collection.add.assert_called_once() + call_args = mock_collection.add.call_args + assert len(call_args.kwargs["ids"]) == 1 + assert call_args.kwargs["documents"] == ["Test document"] + assert call_args.kwargs["metadatas"] == [{"source": "test"}] + + def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None: + """Test add_documents with custom document IDs.""" + mock_collection = Mock() + mock_chromadb_client.get_collection.return_value = mock_collection + + documents: list[BaseRecord] = [ + { + "doc_id": "custom_id_1", + "content": "First document", + "metadata": {"source": "test1"}, + }, + { + "doc_id": "custom_id_2", + "content": "Second document", + "metadata": {"source": "test2"}, + }, + ] + + client.add_documents(collection_name="test_collection", documents=documents) + + mock_collection.add.assert_called_once_with( + ids=["custom_id_1", "custom_id_2"], + documents=["First document", "Second document"], + metadatas=[{"source": "test1"}, {"source": "test2"}], + ) + + def test_add_documents_empty_list_raises_error( + self, client, mock_chromadb_client + ) -> None: + """Test that add_documents raises error for empty documents list.""" + with pytest.raises(ValueError, match="Documents list cannot be empty"): + client.add_documents(collection_name="test_collection", documents=[]) + + @pytest.mark.asyncio + async def test_aadd_documents( + self, async_client, mock_async_chromadb_client + ) -> None: + """Test that aadd_documents adds documents to collection asynchronously.""" + mock_collection = AsyncMock() + mock_async_chromadb_client.get_collection = AsyncMock( + return_value=mock_collection + ) + + documents: list[BaseRecord] = [ + { + "content": "Test document", + "metadata": {"source": "test"}, + } + ] + + await async_client.aadd_documents( + collection_name="test_collection", documents=documents + ) + + mock_async_chromadb_client.get_collection.assert_called_once_with( + name="test_collection", + embedding_function=async_client.embedding_function, + ) + + # Verify documents were added to collection + mock_collection.add.assert_called_once() + call_args = mock_collection.add.call_args + assert len(call_args.kwargs["ids"]) == 1 + assert call_args.kwargs["documents"] == ["Test document"] + assert call_args.kwargs["metadatas"] == [{"source": "test"}] + + @pytest.mark.asyncio + async def test_aadd_documents_with_custom_ids( + self, async_client, mock_async_chromadb_client + ) -> None: + """Test aadd_documents with custom document IDs.""" + mock_collection = AsyncMock() + mock_async_chromadb_client.get_collection = AsyncMock( + return_value=mock_collection + ) + + documents: list[BaseRecord] = [ + { + "doc_id": "custom_id_1", + "content": "First document", + "metadata": {"source": "test1"}, + }, + { + "doc_id": "custom_id_2", + "content": "Second document", + "metadata": {"source": "test2"}, + }, + ] + + await async_client.aadd_documents( + collection_name="test_collection", documents=documents + ) + + mock_collection.add.assert_called_once_with( + ids=["custom_id_1", "custom_id_2"], + documents=["First document", "Second document"], + metadatas=[{"source": "test1"}, {"source": "test2"}], + ) + + @pytest.mark.asyncio + async def test_aadd_documents_empty_list_raises_error( + self, async_client, mock_async_chromadb_client + ) -> None: + """Test that aadd_documents raises error for empty documents list.""" + with pytest.raises(ValueError, match="Documents list cannot be empty"): + await async_client.aadd_documents( + collection_name="test_collection", documents=[] + ) + + def test_search(self, client, mock_chromadb_client): + """Test that search queries the collection correctly.""" + mock_collection = Mock() + mock_collection.metadata = {"hnsw:space": "cosine"} + mock_chromadb_client.get_collection.return_value = mock_collection + mock_collection.query.return_value = { + "ids": [["doc1", "doc2"]], + "documents": [["Document 1", "Document 2"]], + "metadatas": [[{"source": "test1"}, {"source": "test2"}]], + "distances": [[0.1, 0.3]], + } + + results = client.search(collection_name="test_collection", query="test query") + + mock_chromadb_client.get_collection.assert_called_once_with( + name="test_collection", + embedding_function=client.embedding_function, + ) + mock_collection.query.assert_called_once_with( + query_texts=["test query"], + n_results=10, + where=None, + where_document=None, + include=["metadatas", "documents", "distances"], + ) + + assert len(results) == 2 + assert results[0]["id"] == "doc1" + assert results[0]["content"] == "Document 1" + assert results[0]["metadata"] == {"source": "test1"} + assert results[0]["score"] == 0.95 + + def test_search_with_optional_params(self, client, mock_chromadb_client): + """Test search with optional parameters.""" + mock_collection = Mock() + mock_collection.metadata = {"hnsw:space": "cosine"} + mock_chromadb_client.get_collection.return_value = mock_collection + mock_collection.query.return_value = { + "ids": [["doc1", "doc2", "doc3"]], + "documents": [["Document 1", "Document 2", "Document 3"]], + "metadatas": [ + [{"source": "test1"}, {"source": "test2"}, {"source": "test3"}] + ], + "distances": [[0.1, 0.3, 1.5]], # Last one will be filtered by threshold + } + + results = client.search( + collection_name="test_collection", + query="test query", + limit=5, + metadata_filter={"source": "test"}, + score_threshold=0.7, + ) + + mock_collection.query.assert_called_once_with( + query_texts=["test query"], + n_results=5, + where={"source": "test"}, + where_document=None, + include=["metadatas", "documents", "distances"], + ) + + assert len(results) == 2 + + @pytest.mark.asyncio + async def test_asearch(self, async_client, mock_async_chromadb_client) -> None: + """Test that asearch queries the collection correctly.""" + mock_collection = AsyncMock() + mock_collection.metadata = {"hnsw:space": "cosine"} + mock_async_chromadb_client.get_collection = AsyncMock( + return_value=mock_collection + ) + mock_collection.query = AsyncMock( + return_value={ + "ids": [["doc1", "doc2"]], + "documents": [["Document 1", "Document 2"]], + "metadatas": [[{"source": "test1"}, {"source": "test2"}]], + "distances": [[0.1, 0.3]], + } + ) + + results = await async_client.asearch( + collection_name="test_collection", query="test query" + ) + + mock_async_chromadb_client.get_collection.assert_called_once_with( + name="test_collection", + embedding_function=async_client.embedding_function, + ) + mock_collection.query.assert_called_once_with( + query_texts=["test query"], + n_results=10, + where=None, + where_document=None, + include=["metadatas", "documents", "distances"], + ) + + assert len(results) == 2 + assert results[0]["id"] == "doc1" + assert results[0]["content"] == "Document 1" + assert results[0]["metadata"] == {"source": "test1"} + assert results[0]["score"] == 0.95 + + @pytest.mark.asyncio + async def test_asearch_with_optional_params( + self, async_client, mock_async_chromadb_client + ) -> None: + """Test asearch with optional parameters.""" + mock_collection = AsyncMock() + mock_collection.metadata = {"hnsw:space": "cosine"} + mock_async_chromadb_client.get_collection = AsyncMock( + return_value=mock_collection + ) + mock_collection.query = AsyncMock( + return_value={ + "ids": [["doc1", "doc2", "doc3"]], + "documents": [["Document 1", "Document 2", "Document 3"]], + "metadatas": [ + [{"source": "test1"}, {"source": "test2"}, {"source": "test3"}] + ], + "distances": [ + [0.1, 0.3, 1.5] + ], # Last one will be filtered by threshold + } + ) + + results = await async_client.asearch( + collection_name="test_collection", + query="test query", + limit=5, + metadata_filter={"source": "test"}, + score_threshold=0.7, + ) + + mock_collection.query.assert_called_once_with( + query_texts=["test query"], + n_results=5, + where={"source": "test"}, + where_document=None, + include=["metadatas", "documents", "distances"], + ) + + # Only 2 results should pass the score threshold + assert len(results) == 2 + + def test_delete_collection(self, client, mock_chromadb_client): + """Test that delete_collection calls the underlying client correctly.""" + client.delete_collection(collection_name="test_collection") + + mock_chromadb_client.delete_collection.assert_called_once_with( + name="test_collection" + ) + + @pytest.mark.asyncio + async def test_adelete_collection( + self, async_client, mock_async_chromadb_client + ) -> None: + """Test that adelete_collection calls the underlying client correctly.""" + mock_async_chromadb_client.delete_collection = AsyncMock(return_value=None) + + await async_client.adelete_collection(collection_name="test_collection") + + mock_async_chromadb_client.delete_collection.assert_called_once_with( + name="test_collection" + ) + + def test_reset(self, client, mock_chromadb_client): + """Test that reset calls the underlying client correctly.""" + mock_chromadb_client.reset.return_value = True + + client.reset() + + mock_chromadb_client.reset.assert_called_once_with() + + @pytest.mark.asyncio + async def test_areset(self, async_client, mock_async_chromadb_client) -> None: + """Test that areset calls the underlying client correctly.""" + mock_async_chromadb_client.reset = AsyncMock(return_value=True) + + await async_client.areset() + + mock_async_chromadb_client.reset.assert_called_once_with()