feat: chromadb generic client (#3374)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

Add ChromaDB client implementation with async support

- Implement core collection operations (create, get_or_create, delete)
- Add search functionality with cosine similarity scoring
- Include both sync and async method variants
- Add type safety with NamedTuples and TypeGuards
- Extract utility functions to separate modules
- Default to cosine distance metric for text similarity
- Add comprehensive test coverage

TODO:
- Score calculations for the l2 and ip distance metrics are not settled yet (only cosine is converted to a score)
This commit is contained in:
Greyson LaLonde
2025-08-21 18:18:46 -04:00
committed by GitHub
parent 1217935b31
commit 842bed4e9c
7 changed files with 1411 additions and 0 deletions

View File

View File

@@ -0,0 +1,556 @@
"""ChromaDB client implementation."""
from typing import Any
from chromadb.api.types import (
Embeddable,
EmbeddingFunction as ChromaEmbeddingFunction,
QueryResult,
)
from typing_extensions import Unpack
from crewai.rag.chromadb.types import (
ChromaDBClientType,
ChromaDBCollectionCreateParams,
ChromaDBCollectionSearchParams,
)
from crewai.rag.chromadb.utils import (
_extract_search_params,
_is_async_client,
_is_sync_client,
_prepare_documents_for_chromadb,
_process_query_results,
)
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionParams,
BaseCollectionAddParams,
)
from crewai.rag.types import SearchResult
class ChromaDBClient(BaseClient):
"""ChromaDB implementation of the BaseClient protocol.
Provides vector database operations for ChromaDB, supporting both
synchronous and asynchronous clients.
Attributes:
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
embedding_function: Function to generate embeddings for documents.
"""
client: ChromaDBClientType
embedding_function: ChromaEmbeddingFunction[Embeddable]
def create_collection(
    self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
) -> None:
    """Create a collection through the synchronous ChromaDB client.

    The collection metadata receives an ``"hnsw:space"`` entry of
    ``"cosine"`` unless the caller already supplied one, and the client's
    own embedding function is used when none is passed in.

    Keyword Args:
        collection_name: Name of the collection to create (required).
        configuration: Optional collection configuration, forwarded as-is.
        metadata: Optional metadata dictionary for the collection.
        embedding_function: Optional embedding function; defaults to
            ``self.embedding_function``.
        data_loader: Optional data loader, forwarded as-is.
        get_or_create: Forwarded to the client; defaults to False.

    Raises:
        TypeError: If ``self.client`` is not a synchronous ClientAPI.
        KeyError: If ``collection_name`` is missing.

    Errors raised by the underlying ChromaDB client propagate unchanged.
    """
    if not _is_sync_client(self.client):
        raise TypeError(
            "Synchronous method create_collection() requires a ClientAPI. "
            "Use acreate_collection() for AsyncClientAPI."
        )

    # NOTE(review): a caller-supplied metadata dict is updated in place,
    # so the caller observes the added "hnsw:space" key.
    collection_metadata = kwargs.get("metadata", {})
    collection_metadata.setdefault("hnsw:space", "cosine")

    name = kwargs["collection_name"]
    configuration = kwargs.get("configuration")
    embedder = kwargs.get("embedding_function", self.embedding_function)
    loader = kwargs.get("data_loader")
    reuse_existing = kwargs.get("get_or_create", False)

    self.client.create_collection(
        name=name,
        configuration=configuration,
        metadata=collection_metadata,
        embedding_function=embedder,
        data_loader=loader,
        get_or_create=reuse_existing,
    )
async def acreate_collection(
    self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
) -> None:
    """Create a collection through the asynchronous ChromaDB client.

    Mirrors ``create_collection``: the metadata receives an
    ``"hnsw:space"`` entry of ``"cosine"`` unless one is already present,
    and ``self.embedding_function`` is used when no embedding function is
    passed in.

    Keyword Args:
        collection_name: Name of the collection to create (required).
        configuration: Optional collection configuration, forwarded as-is.
        metadata: Optional metadata dictionary for the collection.
        embedding_function: Optional embedding function; defaults to
            ``self.embedding_function``.
        data_loader: Optional data loader, forwarded as-is.
        get_or_create: Forwarded to the client; defaults to False.

    Raises:
        TypeError: If ``self.client`` is not an AsyncClientAPI.
        KeyError: If ``collection_name`` is missing.

    Errors raised by the underlying ChromaDB client propagate unchanged.
    """
    if not _is_async_client(self.client):
        raise TypeError(
            "Asynchronous method acreate_collection() requires an AsyncClientAPI. "
            "Use create_collection() for ClientAPI."
        )

    # NOTE(review): a caller-supplied metadata dict is updated in place,
    # so the caller observes the added "hnsw:space" key.
    collection_metadata = kwargs.get("metadata", {})
    collection_metadata.setdefault("hnsw:space", "cosine")

    name = kwargs["collection_name"]
    configuration = kwargs.get("configuration")
    embedder = kwargs.get("embedding_function", self.embedding_function)
    loader = kwargs.get("data_loader")
    reuse_existing = kwargs.get("get_or_create", False)

    await self.client.create_collection(
        name=name,
        configuration=configuration,
        metadata=collection_metadata,
        embedding_function=embedder,
        data_loader=loader,
        get_or_create=reuse_existing,
    )
def get_or_create_collection(
    self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
) -> Any:
    """Fetch a collection, creating it first if needed (synchronous).

    Delegates to the client's ``get_or_create_collection``. The metadata
    receives an ``"hnsw:space"`` entry of ``"cosine"`` unless one is
    already present. A ``get_or_create`` keyword, if given, is not
    forwarded.

    Keyword Args:
        collection_name: Name of the collection (required).
        configuration: Optional collection configuration, forwarded as-is.
        metadata: Optional metadata dictionary for the collection.
        embedding_function: Optional embedding function; defaults to
            ``self.embedding_function``.
        data_loader: Optional data loader, forwarded as-is.

    Returns:
        Whatever the client's ``get_or_create_collection`` returns.

    Raises:
        TypeError: If ``self.client`` is not a synchronous ClientAPI.
        KeyError: If ``collection_name`` is missing.
    """
    if not _is_sync_client(self.client):
        raise TypeError(
            "Synchronous method get_or_create_collection() requires a ClientAPI. "
            "Use aget_or_create_collection() for AsyncClientAPI."
        )

    # NOTE(review): a caller-supplied metadata dict is updated in place.
    collection_metadata = kwargs.get("metadata", {})
    collection_metadata.setdefault("hnsw:space", "cosine")

    name = kwargs["collection_name"]
    configuration = kwargs.get("configuration")
    embedder = kwargs.get("embedding_function", self.embedding_function)
    loader = kwargs.get("data_loader")

    collection = self.client.get_or_create_collection(
        name=name,
        configuration=configuration,
        metadata=collection_metadata,
        embedding_function=embedder,
        data_loader=loader,
    )
    return collection
async def aget_or_create_collection(
    self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
) -> Any:
    """Fetch a collection, creating it first if needed (asynchronous).

    Awaits the client's ``get_or_create_collection``. The metadata
    receives an ``"hnsw:space"`` entry of ``"cosine"`` unless one is
    already present. A ``get_or_create`` keyword, if given, is not
    forwarded.

    Keyword Args:
        collection_name: Name of the collection (required).
        configuration: Optional collection configuration, forwarded as-is.
        metadata: Optional metadata dictionary for the collection.
        embedding_function: Optional embedding function; defaults to
            ``self.embedding_function``.
        data_loader: Optional data loader, forwarded as-is.

    Returns:
        Whatever the awaited ``get_or_create_collection`` call returns.

    Raises:
        TypeError: If ``self.client`` is not an AsyncClientAPI.
        KeyError: If ``collection_name`` is missing.
    """
    if not _is_async_client(self.client):
        raise TypeError(
            "Asynchronous method aget_or_create_collection() requires an AsyncClientAPI. "
            "Use get_or_create_collection() for ClientAPI."
        )

    # NOTE(review): a caller-supplied metadata dict is updated in place.
    collection_metadata = kwargs.get("metadata", {})
    collection_metadata.setdefault("hnsw:space", "cosine")

    name = kwargs["collection_name"]
    configuration = kwargs.get("configuration")
    embedder = kwargs.get("embedding_function", self.embedding_function)
    loader = kwargs.get("data_loader")

    collection = await self.client.get_or_create_collection(
        name=name,
        configuration=configuration,
        metadata=collection_metadata,
        embedding_function=embedder,
        data_loader=loader,
    )
    return collection
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
    """Add documents to an existing collection (synchronous).

    The collection is looked up with ``self.embedding_function`` attached,
    the records are split into ids, texts and metadata, and the result is
    passed to the collection's ``add`` method.

    NOTE(review): earlier documentation described this as an upsert, but
    the code calls ``add``, not ``upsert`` - confirm which behaviour is
    intended for documents whose IDs already exist.

    Keyword Args:
        collection_name: Name of the collection to add documents to.
        documents: List of BaseRecord dicts with ``content`` and optional
            ``doc_id`` and ``metadata`` entries.

    Raises:
        TypeError: If ``self.client`` is not a synchronous ClientAPI.
        ValueError: If the documents list is empty.

    Errors raised by the underlying ChromaDB client propagate unchanged.
    """
    if not _is_sync_client(self.client):
        raise TypeError(
            "Synchronous method add_documents() requires a ClientAPI. "
            "Use aadd_documents() for AsyncClientAPI."
        )

    target_name = kwargs["collection_name"]
    records = kwargs["documents"]
    if not records:
        raise ValueError("Documents list cannot be empty")

    # Look the collection up before preparing the batch so a missing
    # collection fails before any document processing happens.
    target = self.client.get_collection(
        name=target_name,
        embedding_function=self.embedding_function,
    )
    batch = _prepare_documents_for_chromadb(records)
    target.add(ids=batch.ids, documents=batch.texts, metadatas=batch.metadatas)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
    """Add documents to an existing collection (asynchronous).

    The collection is looked up with ``self.embedding_function`` attached,
    the records are split into ids, texts and metadata, and the result is
    passed to the collection's awaited ``add`` method.

    NOTE(review): earlier documentation described this as an upsert, but
    the code calls ``add``, not ``upsert`` - confirm which behaviour is
    intended for documents whose IDs already exist.

    Keyword Args:
        collection_name: Name of the collection to add documents to.
        documents: List of BaseRecord dicts with ``content`` and optional
            ``doc_id`` and ``metadata`` entries.

    Raises:
        TypeError: If ``self.client`` is not an AsyncClientAPI.
        ValueError: If the documents list is empty.

    Errors raised by the underlying ChromaDB client propagate unchanged.
    """
    if not _is_async_client(self.client):
        raise TypeError(
            "Asynchronous method aadd_documents() requires an AsyncClientAPI. "
            "Use add_documents() for ClientAPI."
        )

    target_name = kwargs["collection_name"]
    records = kwargs["documents"]
    if not records:
        raise ValueError("Documents list cannot be empty")

    # Look the collection up before preparing the batch so a missing
    # collection fails before any document processing happens.
    target = await self.client.get_collection(
        name=target_name,
        embedding_function=self.embedding_function,
    )
    batch = _prepare_documents_for_chromadb(records)
    await target.add(ids=batch.ids, documents=batch.texts, metadatas=batch.metadatas)
def search(
    self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
) -> list[SearchResult]:
    """Query a collection for documents similar to a text query (synchronous).

    The query text is passed to the collection's ``query`` method and the
    raw result is converted to SearchResult dicts by
    ``_process_query_results``.

    Keyword Args:
        collection_name: Name of the collection to search in.
        query: The text query to search for.
        limit: Maximum number of results to request (default: 10).
        metadata_filter: Metadata filter used only when ``where`` is not
            given.
        score_threshold: Optional minimum similarity score for results.
        where: Optional ChromaDB where clause; takes precedence over
            ``metadata_filter``.
        where_document: Optional ChromaDB document-content filter.
        include: Optional list of fields to include in results.

    Returns:
        List of SearchResult dicts containing id, content, metadata, and score.

    Raises:
        TypeError: If ``self.client`` is not a synchronous ClientAPI.

    Errors raised by the underlying ChromaDB client propagate unchanged.
    """
    if not _is_sync_client(self.client):
        raise TypeError(
            "Synchronous method search() requires a ClientAPI. "
            "Use asearch() for AsyncClientAPI."
        )

    params = _extract_search_params(kwargs)
    collection = self.client.get_collection(
        name=params.collection_name,
        embedding_function=self.embedding_function,
    )

    # An explicit ``where`` clause wins; otherwise fall back to the
    # generic metadata filter (which may itself be None).
    metadata_where = params.where
    if metadata_where is None:
        metadata_where = params.metadata_filter

    raw: QueryResult = collection.query(
        query_texts=[params.query],
        n_results=params.limit,
        where=metadata_where,
        where_document=params.where_document,
        include=params.include,
    )
    return _process_query_results(collection=collection, results=raw, params=params)
async def asearch(
    self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
) -> list[SearchResult]:
    """Query a collection for documents similar to a text query (asynchronous).

    The query text is passed to the collection's awaited ``query`` method
    and the raw result is converted to SearchResult dicts by
    ``_process_query_results``.

    Keyword Args:
        collection_name: Name of the collection to search in.
        query: The text query to search for.
        limit: Maximum number of results to request (default: 10).
        metadata_filter: Metadata filter used only when ``where`` is not
            given.
        score_threshold: Optional minimum similarity score for results.
        where: Optional ChromaDB where clause; takes precedence over
            ``metadata_filter``.
        where_document: Optional ChromaDB document-content filter.
        include: Optional list of fields to include in results.

    Returns:
        List of SearchResult dicts containing id, content, metadata, and score.

    Raises:
        TypeError: If ``self.client`` is not an AsyncClientAPI.

    Errors raised by the underlying ChromaDB client propagate unchanged.
    """
    if not _is_async_client(self.client):
        raise TypeError(
            "Asynchronous method asearch() requires an AsyncClientAPI. "
            "Use search() for ClientAPI."
        )

    params = _extract_search_params(kwargs)
    collection = await self.client.get_collection(
        name=params.collection_name,
        embedding_function=self.embedding_function,
    )

    # An explicit ``where`` clause wins; otherwise fall back to the
    # generic metadata filter (which may itself be None).
    metadata_where = params.where
    if metadata_where is None:
        metadata_where = params.metadata_filter

    raw: QueryResult = await collection.query(
        query_texts=[params.query],
        n_results=params.limit,
        where=metadata_where,
        where_document=params.where_document,
        include=params.include,
    )
    return _process_query_results(collection=collection, results=raw, params=params)
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
    """Delete the named collection via the synchronous client.

    The deletion is delegated entirely to the ChromaDB client's
    ``delete_collection``; treat it as destructive.

    Keyword Args:
        collection_name: Name of the collection to delete.

    Raises:
        TypeError: If ``self.client`` is not a synchronous ClientAPI.
        KeyError: If ``collection_name`` is missing.

    Errors raised by the underlying ChromaDB client propagate unchanged.
    """
    if _is_sync_client(self.client):
        self.client.delete_collection(name=kwargs["collection_name"])
        return
    raise TypeError(
        "Synchronous method delete_collection() requires a ClientAPI. "
        "Use adelete_collection() for AsyncClientAPI."
    )
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
    """Delete the named collection via the asynchronous client.

    The deletion is delegated entirely to the ChromaDB client's awaited
    ``delete_collection``; treat it as destructive.

    Keyword Args:
        collection_name: Name of the collection to delete.

    Raises:
        TypeError: If ``self.client`` is not an AsyncClientAPI.
        KeyError: If ``collection_name`` is missing.

    Errors raised by the underlying ChromaDB client propagate unchanged.
    """
    if _is_async_client(self.client):
        await self.client.delete_collection(name=kwargs["collection_name"])
        return
    raise TypeError(
        "Asynchronous method adelete_collection() requires an AsyncClientAPI. "
        "Use delete_collection() for ClientAPI."
    )
def reset(self) -> None:
    """Reset the ChromaDB instance via the synchronous client.

    Delegates to the client's ``reset()``. The project documentation
    describes this as removing all collections and data, so treat it as
    destructive and irreversible.

    Raises:
        TypeError: If ``self.client`` is not a synchronous ClientAPI.

    Errors raised by the underlying ChromaDB client propagate unchanged.
    """
    if _is_sync_client(self.client):
        self.client.reset()
        return
    raise TypeError(
        "Synchronous method reset() requires a ClientAPI. "
        "Use areset() for AsyncClientAPI."
    )
async def areset(self) -> None:
    """Reset the ChromaDB instance via the asynchronous client.

    Awaits the client's ``reset()``. The project documentation describes
    this as removing all collections and data, so treat it as destructive
    and irreversible.

    Raises:
        TypeError: If ``self.client`` is not an AsyncClientAPI.

    Errors raised by the underlying ChromaDB client propagate unchanged.
    """
    if _is_async_client(self.client):
        await self.client.reset()
        return
    raise TypeError(
        "Asynchronous method areset() requires an AsyncClientAPI. "
        "Use reset() for ClientAPI."
    )

View File

@@ -0,0 +1,85 @@
"""Type definitions specific to ChromaDB implementation."""
from collections.abc import Mapping
from typing import Any, NamedTuple
from chromadb.api import ClientAPI, AsyncClientAPI
from chromadb.api.configuration import CollectionConfigurationInterface
from chromadb.api.types import (
CollectionMetadata,
DataLoader,
Embeddable,
EmbeddingFunction as ChromaEmbeddingFunction,
Include,
Loadable,
Where,
WhereDocument,
)
from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams
ChromaDBClientType = ClientAPI | AsyncClientAPI
class PreparedDocuments(NamedTuple):
    """Document fields split into parallel lists for ChromaDB insertion.

    The three lists are index-aligned: entry ``i`` of each list describes
    the same document.

    Attributes:
        ids: List of document IDs.
        texts: List of document texts.
        metadatas: List of document metadata mappings (an empty mapping
            stands in for a document without metadata).
    """

    ids: list[str]
    texts: list[str]
    metadatas: list[Mapping[str, str | int | float | bool]]
class ExtractedSearchParams(NamedTuple):
    """Search parameters with defaults resolved, ready for a ChromaDB query.

    Attributes:
        collection_name: Name of the collection to search.
        query: Search query text.
        limit: Maximum number of results.
        metadata_filter: Optional metadata filter, or None.
        score_threshold: Optional minimum similarity score, or None.
        where: Optional ChromaDB where clause, or None.
        where_document: Optional ChromaDB document filter, or None.
        include: Fields to include in results.
    """

    collection_name: str
    query: str
    limit: int
    metadata_filter: dict[str, Any] | None
    score_threshold: float | None
    where: Where | None
    where_document: WhereDocument | None
    include: Include
class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
    """Parameters for creating a ChromaDB collection.

    Extends BaseCollectionParams with ChromaDB-specific creation options.
    Declared with ``total=False``, so every key listed here is optional.
    """

    # Collection configuration object forwarded to ChromaDB.
    configuration: CollectionConfigurationInterface
    # Metadata attached to the collection.
    metadata: CollectionMetadata
    # Embedding function to use instead of the client's default.
    embedding_function: ChromaEmbeddingFunction[Embeddable]
    # Data loader forwarded to ChromaDB.
    data_loader: DataLoader[Loadable]
    # Whether an existing collection may be returned instead of failing.
    get_or_create: bool
class ChromaDBCollectionSearchParams(BaseCollectionSearchParams, total=False):
    """Parameters for searching a ChromaDB collection.

    Extends BaseCollectionSearchParams with ChromaDB-specific options.
    Declared with ``total=False``, so every key listed here is optional.
    """

    # ChromaDB where clause for metadata filtering.
    where: Where
    # ChromaDB where clause for document-content filtering.
    where_document: WhereDocument
    # Fields to include in the query result.
    include: Include

View File

@@ -0,0 +1,220 @@
"""Utility functions for ChromaDB client implementation."""
import hashlib
from collections.abc import Mapping
from typing import Literal, TypeGuard, cast
from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.types import (
Include,
IncludeEnum,
QueryResult,
)
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from crewai.rag.chromadb.types import (
ChromaDBClientType,
ChromaDBCollectionSearchParams,
ExtractedSearchParams,
PreparedDocuments,
)
from crewai.rag.types import BaseRecord, SearchResult
def _is_sync_client(client: ChromaDBClientType) -> TypeGuard[ClientAPI]:
    """Report whether ``client`` is a synchronous ``ClientAPI`` instance.

    Declared as a type guard so static checkers narrow ``client`` to
    ``ClientAPI`` in the branch where this returns True.

    Args:
        client: The ChromaDB client to inspect.

    Returns:
        True when ``client`` is a ``ClientAPI`` instance, otherwise False.
    """
    is_sync = isinstance(client, ClientAPI)
    return is_sync
def _is_async_client(client: ChromaDBClientType) -> TypeGuard[AsyncClientAPI]:
    """Report whether ``client`` is an ``AsyncClientAPI`` instance.

    Declared as a type guard so static checkers narrow ``client`` to
    ``AsyncClientAPI`` in the branch where this returns True.

    Args:
        client: The ChromaDB client to inspect.

    Returns:
        True when ``client`` is an ``AsyncClientAPI`` instance, otherwise False.
    """
    is_async = isinstance(client, AsyncClientAPI)
    return is_async
def _prepare_documents_for_chromadb(
    documents: list[BaseRecord],
) -> PreparedDocuments:
    """Split records into parallel id, text and metadata lists.

    A record without ``doc_id`` gets an ID derived from its content: the
    first 16 hex characters of the SHA-256 digest of the encoded text, so
    identical content yields the same ID. Missing or empty metadata becomes
    an empty dict; when metadata is a non-empty list only its first element
    is used.

    Args:
        documents: List of BaseRecord documents to prepare.

    Returns:
        PreparedDocuments with index-aligned ids, texts, and metadatas.
    """
    ids: list[str] = []
    texts: list[str] = []
    metadatas: list[Mapping[str, str | int | float | bool]] = []

    for record in documents:
        if "doc_id" in record:
            record_id = record["doc_id"]
        else:
            digest = hashlib.sha256(record["content"].encode()).hexdigest()
            record_id = digest[:16]
        ids.append(record_id)
        texts.append(record["content"])

        raw_metadata = record.get("metadata")
        if not raw_metadata:
            entry = {}
        elif isinstance(raw_metadata, list):
            entry = raw_metadata[0]
        else:
            entry = raw_metadata
        metadatas.append(entry)

    return PreparedDocuments(ids, texts, metadatas)
def _extract_search_params(
    kwargs: ChromaDBCollectionSearchParams,
) -> ExtractedSearchParams:
    """Build an ExtractedSearchParams tuple from search keyword arguments.

    ``collection_name`` and ``query`` are required. ``limit`` defaults to
    10, ``include`` defaults to metadatas, documents and distances, and
    every other optional key defaults to None.

    Args:
        kwargs: Keyword arguments containing search parameters.

    Returns:
        ExtractedSearchParams with all parameters resolved.

    Raises:
        KeyError: If ``collection_name`` or ``query`` is missing.
    """
    collection_name = kwargs["collection_name"]
    query = kwargs["query"]
    default_include = [
        IncludeEnum.metadatas,
        IncludeEnum.documents,
        IncludeEnum.distances,
    ]
    return ExtractedSearchParams(
        collection_name=collection_name,
        query=query,
        limit=kwargs.get("limit", 10),
        metadata_filter=kwargs.get("metadata_filter"),
        score_threshold=kwargs.get("score_threshold"),
        where=kwargs.get("where"),
        where_document=kwargs.get("where_document"),
        include=kwargs.get("include", default_include),
    )
def _convert_distance_to_score(
distance: float,
distance_metric: Literal["l2", "cosine", "ip"],
) -> float:
"""Convert ChromaDB distance to similarity score.
Notes:
Assuming all embedding are unit-normalized for now, including custom embeddings.
Args:
distance: The distance value from ChromaDB.
distance_metric: The distance metric used ("l2", "cosine", or "ip").
Returns:
Similarity score in range [0, 1] where 1 is most similar.
"""
if distance_metric == "cosine":
score = 1.0 - 0.5 * distance
return max(0.0, min(1.0, score))
raise ValueError(f"Unsupported distance metric: {distance_metric}")
def _convert_chromadb_results_to_search_results(
    results: QueryResult,
    include: Include,
    distance_metric: Literal["l2", "cosine", "ip"],
    score_threshold: float | None = None,
) -> list[SearchResult]:
    """Convert ChromaDB query results to SearchResult format.

    Only the first query's batch of results is read. A result is emitted
    only when a distance is available for it, because the score is derived
    from the distance; results scoring below ``score_threshold`` are
    dropped.

    Args:
        results: ChromaDB query results.
        include: Fields that were included in the query. Entries may be
            enum members exposing ``.value`` or plain strings.
        distance_metric: The distance metric used by the collection.
        score_threshold: Optional minimum similarity score (0-1) for results.

    Returns:
        List of SearchResult dicts containing id, content, metadata, and score.

    Raises:
        ValueError: If ``distance_metric`` is not supported by
            ``_convert_distance_to_score``.
    """
    # Accept both enum members and plain strings; the previous ``item.value``
    # access raised AttributeError for string entries.
    included = {getattr(item, "value", item) for item in include}

    ids = results["ids"][0] if results.get("ids") else []

    documents_list = results.get("documents")
    documents = documents_list[0] if documents_list and "documents" in included else []

    metadatas_list = results.get("metadatas")
    metadatas = metadatas_list[0] if metadatas_list and "metadatas" in included else []

    distances_list = results.get("distances")
    distances = distances_list[0] if distances_list and "distances" in included else []

    search_results: list[SearchResult] = []
    for i, doc_id in enumerate(ids):
        # Without a distance there is nothing to score, so skip the entry.
        if not distances or i >= len(distances):
            continue

        score = _convert_distance_to_score(
            distance=distances[i], distance_metric=distance_metric
        )
        if score_threshold is not None and score < score_threshold:
            continue

        # Individual document/metadata entries can be None (e.g. a record
        # stored without metadata); normalise them instead of crashing in
        # dict(None) or leaking None into the result.
        content = documents[i] if documents and i < len(documents) else None
        metadata = metadatas[i] if metadatas and i < len(metadatas) else None

        result: SearchResult = {
            "id": doc_id,
            "content": content or "",
            "metadata": dict(metadata) if metadata else {},
            "score": score,
        }
        search_results.append(result)

    return search_results
def _process_query_results(
    collection: Collection | AsyncCollection,
    results: QueryResult,
    params: ExtractedSearchParams,
) -> list[SearchResult]:
    """Turn raw ChromaDB query output into SearchResult dicts.

    The distance metric is read from the collection's ``"hnsw:space"``
    metadata entry, falling back to ``"l2"`` when the collection has no
    metadata or no such entry.

    NOTE(review): ``_convert_distance_to_score`` only handles "cosine", so
    the "l2" fallback leads to a ValueError as soon as a result has a
    distance - confirm this is intended for collections created elsewhere.

    Args:
        collection: The ChromaDB collection (sync or async) that was queried.
        results: Raw query results from ChromaDB.
        params: The search parameters used for the query.

    Returns:
        List of SearchResult dicts containing id, content, metadata, and score.
    """
    space = "l2"
    if collection.metadata:
        space = collection.metadata.get("hnsw:space", "l2")
    distance_metric = cast(Literal["l2", "cosine", "ip"], space)

    return _convert_chromadb_results_to_search_results(
        results=results,
        include=params.include,
        distance_metric=distance_metric,
        score_threshold=params.score_threshold,
    )

0
tests/rag/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,550 @@
"""Tests for ChromaDBClient implementation."""
from unittest.mock import AsyncMock, Mock
import pytest
from crewai.rag.chromadb.client import ChromaDBClient
from crewai.rag.types import BaseRecord
@pytest.fixture
def mock_chromadb_client():
    """Create a mock ChromaDB client.

    The mock is built with ``spec=ClientAPI`` so it passes the
    ``isinstance`` check the client uses to detect synchronous clients.
    """
    from chromadb.api import ClientAPI
    return Mock(spec=ClientAPI)
@pytest.fixture
def mock_async_chromadb_client():
    """Create a mock async ChromaDB client.

    The mock is built with ``spec=AsyncClientAPI`` so it passes the
    ``isinstance`` check the client uses to detect asynchronous clients.
    Individual tests replace the methods they await with ``AsyncMock``.
    """
    from chromadb.api import AsyncClientAPI
    return Mock(spec=AsyncClientAPI)
@pytest.fixture
def client(mock_chromadb_client) -> ChromaDBClient:
    """Create a ChromaDBClient instance for testing.

    The instance is wired to the synchronous mock client and a mock
    embedding function.
    """
    client = ChromaDBClient()
    client.client = mock_chromadb_client
    client.embedding_function = Mock()
    return client
@pytest.fixture
def async_client(mock_async_chromadb_client) -> ChromaDBClient:
    """Create a ChromaDBClient instance with async client for testing.

    The instance is wired to the asynchronous mock client and a mock
    embedding function.
    """
    client = ChromaDBClient()
    client.client = mock_async_chromadb_client
    client.embedding_function = Mock()
    return client
class TestChromaDBClient:
"""Test suite for ChromaDBClient."""
def test_create_collection(self, client, mock_chromadb_client):
    """Test that create_collection calls the underlying client correctly."""
    client.create_collection(collection_name="test_collection")
    # With only a name given, the defaults are forwarded: cosine metadata,
    # the client's own embedding function, and get_or_create=False.
    mock_chromadb_client.create_collection.assert_called_once_with(
        name="test_collection",
        configuration=None,
        metadata={"hnsw:space": "cosine"},
        embedding_function=client.embedding_function,
        data_loader=None,
        get_or_create=False,
    )
def test_create_collection_with_all_params(self, client, mock_chromadb_client):
    """Test create_collection with all optional parameters."""
    mock_config = Mock()
    mock_metadata = {"key": "value"}
    mock_embedding_func = Mock()
    mock_data_loader = Mock()
    client.create_collection(
        collection_name="test_collection",
        configuration=mock_config,
        metadata=mock_metadata,
        embedding_function=mock_embedding_func,
        data_loader=mock_data_loader,
        get_or_create=True,
    )
    # NOTE: create_collection adds "hnsw:space" to mock_metadata in place,
    # so this assertion compares against the already-mutated dict.
    mock_chromadb_client.create_collection.assert_called_once_with(
        name="test_collection",
        configuration=mock_config,
        metadata=mock_metadata,
        embedding_function=mock_embedding_func,
        data_loader=mock_data_loader,
        get_or_create=True,
    )
@pytest.mark.asyncio
async def test_acreate_collection(
    self, async_client, mock_async_chromadb_client
) -> None:
    """Test that acreate_collection calls the underlying client correctly."""
    # Make the mock's create_collection an AsyncMock
    mock_async_chromadb_client.create_collection = AsyncMock(return_value=None)
    await async_client.acreate_collection(collection_name="test_collection")
    # With only a name given, the defaults are forwarded: cosine metadata,
    # the client's own embedding function, and get_or_create=False.
    mock_async_chromadb_client.create_collection.assert_called_once_with(
        name="test_collection",
        configuration=None,
        metadata={"hnsw:space": "cosine"},
        embedding_function=async_client.embedding_function,
        data_loader=None,
        get_or_create=False,
    )
@pytest.mark.asyncio
async def test_acreate_collection_with_all_params(
    self, async_client, mock_async_chromadb_client
) -> None:
    """Test acreate_collection with all optional parameters."""
    # Make the mock's create_collection an AsyncMock
    mock_async_chromadb_client.create_collection = AsyncMock(return_value=None)
    mock_config = Mock()
    mock_metadata = {"key": "value"}
    mock_embedding_func = Mock()
    mock_data_loader = Mock()
    await async_client.acreate_collection(
        collection_name="test_collection",
        configuration=mock_config,
        metadata=mock_metadata,
        embedding_function=mock_embedding_func,
        data_loader=mock_data_loader,
        get_or_create=True,
    )
    # NOTE: acreate_collection adds "hnsw:space" to mock_metadata in place,
    # so this assertion compares against the already-mutated dict.
    mock_async_chromadb_client.create_collection.assert_called_once_with(
        name="test_collection",
        configuration=mock_config,
        metadata=mock_metadata,
        embedding_function=mock_embedding_func,
        data_loader=mock_data_loader,
        get_or_create=True,
    )
def test_get_or_create_collection(self, client, mock_chromadb_client):
    """Test that get_or_create_collection calls the underlying client correctly."""
    mock_collection = Mock()
    mock_chromadb_client.get_or_create_collection.return_value = mock_collection
    result = client.get_or_create_collection(collection_name="test_collection")
    # Defaults are forwarded; note there is no get_or_create argument here.
    mock_chromadb_client.get_or_create_collection.assert_called_once_with(
        name="test_collection",
        configuration=None,
        metadata={"hnsw:space": "cosine"},
        embedding_function=client.embedding_function,
        data_loader=None,
    )
    # The collection returned by the underlying client is passed through.
    assert result == mock_collection
def test_get_or_create_collection_with_all_params(
    self, client, mock_chromadb_client
):
    """Test get_or_create_collection with all optional parameters."""
    mock_collection = Mock()
    mock_chromadb_client.get_or_create_collection.return_value = mock_collection
    mock_config = Mock()
    mock_metadata = {"key": "value"}
    mock_embedding_func = Mock()
    mock_data_loader = Mock()
    result = client.get_or_create_collection(
        collection_name="test_collection",
        configuration=mock_config,
        metadata=mock_metadata,
        embedding_function=mock_embedding_func,
        data_loader=mock_data_loader,
    )
    # NOTE: the client adds "hnsw:space" to mock_metadata in place, so this
    # assertion compares against the already-mutated dict.
    mock_chromadb_client.get_or_create_collection.assert_called_once_with(
        name="test_collection",
        configuration=mock_config,
        metadata=mock_metadata,
        embedding_function=mock_embedding_func,
        data_loader=mock_data_loader,
    )
    assert result == mock_collection
@pytest.mark.asyncio
async def test_aget_or_create_collection(
    self, async_client, mock_async_chromadb_client
) -> None:
    """Test that aget_or_create_collection calls the underlying client correctly."""
    mock_collection = Mock()
    # Replace the method with an AsyncMock so it can be awaited.
    mock_async_chromadb_client.get_or_create_collection = AsyncMock(
        return_value=mock_collection
    )
    result = await async_client.aget_or_create_collection(
        collection_name="test_collection"
    )
    # Defaults are forwarded; note there is no get_or_create argument here.
    mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
        name="test_collection",
        configuration=None,
        metadata={"hnsw:space": "cosine"},
        embedding_function=async_client.embedding_function,
        data_loader=None,
    )
    assert result == mock_collection
@pytest.mark.asyncio
async def test_aget_or_create_collection_with_all_params(
    self, async_client, mock_async_chromadb_client
) -> None:
    """Check that aget_or_create_collection passes explicit options through.

    Every optional argument supplied by the caller must arrive at the
    awaited client call as the same object.
    """
    collection = Mock()
    mock_async_chromadb_client.get_or_create_collection = AsyncMock(
        return_value=collection
    )

    # The same objects are used for the call and for the expectation.
    overrides = {
        "configuration": Mock(),
        "metadata": {"key": "value"},
        "embedding_function": Mock(),
        "data_loader": Mock(),
    }

    returned = await async_client.aget_or_create_collection(
        collection_name="test_collection", **overrides
    )

    mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
        name="test_collection", **overrides
    )
    assert returned == collection
def test_add_documents(self, client, mock_chromadb_client) -> None:
    """Check that add_documents looks up the collection and adds one record.

    The record carries no ``doc_id``, so only the number of ids passed to
    ``collection.add`` is checked, alongside the exact documents and
    metadatas.
    """
    collection = Mock()
    mock_chromadb_client.get_collection.return_value = collection
    record: BaseRecord = {
        "content": "Test document",
        "metadata": {"source": "test"},
    }

    client.add_documents(collection_name="test_collection", documents=[record])

    mock_chromadb_client.get_collection.assert_called_once_with(
        name="test_collection",
        embedding_function=client.embedding_function,
    )
    # Inspect the keyword arguments of the single add() call.
    collection.add.assert_called_once()
    add_kwargs = collection.add.call_args.kwargs
    assert len(add_kwargs["ids"]) == 1
    assert add_kwargs["documents"] == ["Test document"]
    assert add_kwargs["metadatas"] == [{"source": "test"}]
def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None:
    """Check that caller-supplied ``doc_id`` values are used as ids.

    Two records with explicit ids are added; ``collection.add`` must get
    the ids, contents and metadata as parallel lists in input order.
    """
    collection = Mock()
    mock_chromadb_client.get_collection.return_value = collection
    documents: list[BaseRecord] = [
        {"doc_id": doc_id, "content": content, "metadata": {"source": source}}
        for doc_id, content, source in (
            ("custom_id_1", "First document", "test1"),
            ("custom_id_2", "Second document", "test2"),
        )
    ]

    client.add_documents(collection_name="test_collection", documents=documents)

    expected_add_kwargs = {
        "ids": ["custom_id_1", "custom_id_2"],
        "documents": ["First document", "Second document"],
        "metadatas": [{"source": "test1"}, {"source": "test2"}],
    }
    collection.add.assert_called_once_with(**expected_add_kwargs)
def test_add_documents_empty_list_raises_error(
    self, client, mock_chromadb_client
) -> None:
    """Check that add_documents rejects an empty documents list.

    A ``ValueError`` whose message matches "Documents list cannot be
    empty" is expected.
    """
    no_documents: list = []
    expected_message = "Documents list cannot be empty"

    with pytest.raises(ValueError, match=expected_message):
        client.add_documents(
            collection_name="test_collection", documents=no_documents
        )
@pytest.mark.asyncio
async def test_aadd_documents(
    self, async_client, mock_async_chromadb_client
) -> None:
    """Check that aadd_documents looks up the collection and adds one record.

    The record carries no ``doc_id``, so only the number of ids passed to
    ``collection.add`` is checked, alongside the exact documents and
    metadatas.
    """
    collection = AsyncMock()
    mock_async_chromadb_client.get_collection = AsyncMock(return_value=collection)
    record: BaseRecord = {
        "content": "Test document",
        "metadata": {"source": "test"},
    }

    await async_client.aadd_documents(
        collection_name="test_collection", documents=[record]
    )

    mock_async_chromadb_client.get_collection.assert_called_once_with(
        name="test_collection",
        embedding_function=async_client.embedding_function,
    )
    # Inspect the keyword arguments of the single add() call.
    collection.add.assert_called_once()
    add_kwargs = collection.add.call_args.kwargs
    assert len(add_kwargs["ids"]) == 1
    assert add_kwargs["documents"] == ["Test document"]
    assert add_kwargs["metadatas"] == [{"source": "test"}]
@pytest.mark.asyncio
async def test_aadd_documents_with_custom_ids(
    self, async_client, mock_async_chromadb_client
) -> None:
    """Check that aadd_documents uses caller-supplied ``doc_id`` values.

    Two records with explicit ids are added; ``collection.add`` must get
    the ids, contents and metadata as parallel lists in input order.
    """
    collection = AsyncMock()
    mock_async_chromadb_client.get_collection = AsyncMock(return_value=collection)
    documents: list[BaseRecord] = [
        {"doc_id": doc_id, "content": content, "metadata": {"source": source}}
        for doc_id, content, source in (
            ("custom_id_1", "First document", "test1"),
            ("custom_id_2", "Second document", "test2"),
        )
    ]

    await async_client.aadd_documents(
        collection_name="test_collection", documents=documents
    )

    expected_add_kwargs = {
        "ids": ["custom_id_1", "custom_id_2"],
        "documents": ["First document", "Second document"],
        "metadatas": [{"source": "test1"}, {"source": "test2"}],
    }
    collection.add.assert_called_once_with(**expected_add_kwargs)
@pytest.mark.asyncio
async def test_aadd_documents_empty_list_raises_error(
    self, async_client, mock_async_chromadb_client
) -> None:
    """Check that aadd_documents rejects an empty documents list.

    A ``ValueError`` whose message matches "Documents list cannot be
    empty" is expected when the coroutine is awaited.
    """
    no_documents: list = []
    expected_message = "Documents list cannot be empty"

    with pytest.raises(ValueError, match=expected_message):
        await async_client.aadd_documents(
            collection_name="test_collection", documents=no_documents
        )
def test_search(self, client, mock_chromadb_client) -> None:
    """Test that search queries the collection and maps the results.

    The collection mock reports a cosine ``hnsw:space`` and returns two
    hits. The test verifies that:

    * the collection is fetched by name with the client's embedding
      function,
    * ``query`` receives the default arguments (``n_results=10``, no
      filters, metadatas/documents/distances included), and
    * the first hit is exposed with the expected id, content, metadata
      and score.
    """
    mock_collection = Mock()
    mock_collection.metadata = {"hnsw:space": "cosine"}
    mock_chromadb_client.get_collection.return_value = mock_collection
    mock_collection.query.return_value = {
        "ids": [["doc1", "doc2"]],
        "documents": [["Document 1", "Document 2"]],
        "metadatas": [[{"source": "test1"}, {"source": "test2"}]],
        "distances": [[0.1, 0.3]],
    }

    results = client.search(collection_name="test_collection", query="test query")

    mock_chromadb_client.get_collection.assert_called_once_with(
        name="test_collection",
        embedding_function=client.embedding_function,
    )
    mock_collection.query.assert_called_once_with(
        query_texts=["test query"],
        n_results=10,
        where=None,
        where_document=None,
        include=["metadatas", "documents", "distances"],
    )
    assert len(results) == 2
    assert results[0]["id"] == "doc1"
    assert results[0]["content"] == "Document 1"
    assert results[0]["metadata"] == {"source": "test1"}
    # The score is computed from the 0.1 distance with float arithmetic, so
    # compare with a tolerance instead of exact equality.
    assert results[0]["score"] == pytest.approx(0.95)
def test_search_with_optional_params(self, client, mock_chromadb_client):
    """Check search with a limit, a metadata filter and a score threshold.

    ``limit`` must arrive at ``query`` as ``n_results`` and
    ``metadata_filter`` as ``where``. Of the three hits returned by the
    collection, only two are expected in the final results.
    """
    collection = Mock()
    collection.metadata = {"hnsw:space": "cosine"}
    mock_chromadb_client.get_collection.return_value = collection
    raw_hits = {
        "ids": [["doc1", "doc2", "doc3"]],
        "documents": [["Document 1", "Document 2", "Document 3"]],
        "metadatas": [
            [{"source": "test1"}, {"source": "test2"}, {"source": "test3"}]
        ],
        # The 1.5-distance hit is expected to be dropped by the threshold.
        "distances": [[0.1, 0.3, 1.5]],
    }
    collection.query.return_value = raw_hits

    results = client.search(
        collection_name="test_collection",
        query="test query",
        limit=5,
        metadata_filter={"source": "test"},
        score_threshold=0.7,
    )

    expected_query_kwargs = {
        "query_texts": ["test query"],
        "n_results": 5,
        "where": {"source": "test"},
        "where_document": None,
        "include": ["metadatas", "documents", "distances"],
    }
    collection.query.assert_called_once_with(**expected_query_kwargs)
    assert len(results) == 2
@pytest.mark.asyncio
async def test_asearch(self, async_client, mock_async_chromadb_client) -> None:
    """Test that asearch queries the collection and maps the results.

    The collection mock reports a cosine ``hnsw:space`` and its awaited
    ``query`` returns two hits. The test verifies that:

    * the collection is fetched by name with the async client's
      embedding function,
    * ``query`` receives the default arguments (``n_results=10``, no
      filters, metadatas/documents/distances included), and
    * the first hit is exposed with the expected id, content, metadata
      and score.
    """
    mock_collection = AsyncMock()
    mock_collection.metadata = {"hnsw:space": "cosine"}
    mock_async_chromadb_client.get_collection = AsyncMock(
        return_value=mock_collection
    )
    mock_collection.query = AsyncMock(
        return_value={
            "ids": [["doc1", "doc2"]],
            "documents": [["Document 1", "Document 2"]],
            "metadatas": [[{"source": "test1"}, {"source": "test2"}]],
            "distances": [[0.1, 0.3]],
        }
    )

    results = await async_client.asearch(
        collection_name="test_collection", query="test query"
    )

    mock_async_chromadb_client.get_collection.assert_called_once_with(
        name="test_collection",
        embedding_function=async_client.embedding_function,
    )
    mock_collection.query.assert_called_once_with(
        query_texts=["test query"],
        n_results=10,
        where=None,
        where_document=None,
        include=["metadatas", "documents", "distances"],
    )
    assert len(results) == 2
    assert results[0]["id"] == "doc1"
    assert results[0]["content"] == "Document 1"
    assert results[0]["metadata"] == {"source": "test1"}
    # The score is computed from the 0.1 distance with float arithmetic, so
    # compare with a tolerance instead of exact equality.
    assert results[0]["score"] == pytest.approx(0.95)
@pytest.mark.asyncio
async def test_asearch_with_optional_params(
    self, async_client, mock_async_chromadb_client
) -> None:
    """Check asearch with a limit, a metadata filter and a score threshold.

    ``limit`` must arrive at ``query`` as ``n_results`` and
    ``metadata_filter`` as ``where``. Of the three hits returned by the
    collection, only two are expected in the final results.
    """
    collection = AsyncMock()
    collection.metadata = {"hnsw:space": "cosine"}
    mock_async_chromadb_client.get_collection = AsyncMock(return_value=collection)
    raw_hits = {
        "ids": [["doc1", "doc2", "doc3"]],
        "documents": [["Document 1", "Document 2", "Document 3"]],
        "metadatas": [
            [{"source": "test1"}, {"source": "test2"}, {"source": "test3"}]
        ],
        # The 1.5-distance hit is expected to be dropped by the threshold.
        "distances": [[0.1, 0.3, 1.5]],
    }
    collection.query = AsyncMock(return_value=raw_hits)

    results = await async_client.asearch(
        collection_name="test_collection",
        query="test query",
        limit=5,
        metadata_filter={"source": "test"},
        score_threshold=0.7,
    )

    expected_query_kwargs = {
        "query_texts": ["test query"],
        "n_results": 5,
        "where": {"source": "test"},
        "where_document": None,
        "include": ["metadatas", "documents", "distances"],
    }
    collection.query.assert_called_once_with(**expected_query_kwargs)
    # Two of the three hits should survive the score threshold.
    assert len(results) == 2
def test_delete_collection(self, client, mock_chromadb_client):
    """Check that delete_collection forwards the collection name.

    The wrapped client's ``delete_collection`` must be called exactly
    once, with the name passed as the ``name`` keyword.
    """
    target = "test_collection"

    client.delete_collection(collection_name=target)

    mock_chromadb_client.delete_collection.assert_called_once_with(name=target)
@pytest.mark.asyncio
async def test_adelete_collection(
    self, async_client, mock_async_chromadb_client
) -> None:
    """Check that adelete_collection forwards the collection name.

    The wrapped client's awaited ``delete_collection`` must be called
    exactly once, with the name passed as the ``name`` keyword.
    """
    target = "test_collection"
    mock_async_chromadb_client.delete_collection = AsyncMock(return_value=None)

    await async_client.adelete_collection(collection_name=target)

    mock_async_chromadb_client.delete_collection.assert_called_once_with(
        name=target
    )
def test_reset(self, client, mock_chromadb_client):
    """Check that reset delegates to the wrapped client.

    The wrapped client's ``reset`` must be called exactly once and with
    no arguments.
    """
    reset_mock = mock_chromadb_client.reset
    reset_mock.return_value = True

    client.reset()

    reset_mock.assert_called_once_with()
@pytest.mark.asyncio
async def test_areset(self, async_client, mock_async_chromadb_client) -> None:
    """Check that areset delegates to the wrapped async client.

    The wrapped client's awaited ``reset`` must be called exactly once
    and with no arguments.
    """
    reset_stub = AsyncMock(return_value=True)
    mock_async_chromadb_client.reset = reset_stub

    await async_client.areset()

    mock_async_chromadb_client.reset.assert_called_once_with()