mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
feat: chromadb generic client (#3374)
Add ChromaDB client implementation with async support: implement core collection operations (create, get_or_create, delete); add search functionality with cosine similarity scoring; include both sync and async method variants; add type safety with NamedTuples and TypeGuards; extract utility functions to separate modules; default to cosine distance metric for text similarity; add comprehensive test coverage. TODO: the l2 and ip score calculations are not yet settled.
This commit is contained in:
0
src/crewai/rag/chromadb/__init__.py
Normal file
0
src/crewai/rag/chromadb/__init__.py
Normal file
556
src/crewai/rag/chromadb/client.py
Normal file
556
src/crewai/rag/chromadb/client.py
Normal file
@@ -0,0 +1,556 @@
|
|||||||
|
"""ChromaDB client implementation."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from chromadb.api.types import (
|
||||||
|
Embeddable,
|
||||||
|
EmbeddingFunction as ChromaEmbeddingFunction,
|
||||||
|
QueryResult,
|
||||||
|
)
|
||||||
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
|
from crewai.rag.chromadb.types import (
|
||||||
|
ChromaDBClientType,
|
||||||
|
ChromaDBCollectionCreateParams,
|
||||||
|
ChromaDBCollectionSearchParams,
|
||||||
|
)
|
||||||
|
from crewai.rag.chromadb.utils import (
|
||||||
|
_extract_search_params,
|
||||||
|
_is_async_client,
|
||||||
|
_is_sync_client,
|
||||||
|
_prepare_documents_for_chromadb,
|
||||||
|
_process_query_results,
|
||||||
|
)
|
||||||
|
from crewai.rag.core.base_client import (
|
||||||
|
BaseClient,
|
||||||
|
BaseCollectionParams,
|
||||||
|
BaseCollectionAddParams,
|
||||||
|
)
|
||||||
|
from crewai.rag.types import SearchResult
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaDBClient(BaseClient):
    """ChromaDB implementation of the BaseClient protocol.

    Provides vector database operations for ChromaDB, supporting both
    synchronous and asynchronous clients. Every operation exists as a
    synchronous method and as an ``a``-prefixed coroutine; each variant
    raises ``TypeError`` when ``client`` is of the other kind.

    Errors raised by the underlying ChromaDB client (missing collection,
    connection problems, ...) are not caught here and propagate unchanged.

    Attributes:
        client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
        embedding_function: Function to generate embeddings for documents.
    """

    client: ChromaDBClientType
    embedding_function: ChromaEmbeddingFunction[Embeddable]

    @staticmethod
    def _metadata_with_default_space(
        kwargs: ChromaDBCollectionCreateParams,
    ) -> dict[str, Any]:
        """Return a private copy of the collection metadata with a metric set.

        The caller's mapping is copied so that injecting the default never
        mutates an object the caller still owns. A missing or ``None``
        ``metadata`` entry is treated as empty. ``hnsw:space`` is set to
        ``"cosine"`` only when the caller did not choose a value.
        """
        metadata: dict[str, Any] = dict(kwargs.get("metadata") or {})
        metadata.setdefault("hnsw:space", "cosine")
        return metadata

    def create_collection(
        self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
    ) -> None:
        """Create a new collection in ChromaDB.

        Uses the client's default embedding function if none is provided and
        defaults the ``hnsw:space`` metadata entry to ``"cosine"``.

        Keyword Args:
            collection_name: Name of the collection to create (required).
            configuration: Optional collection configuration, forwarded as-is.
            metadata: Optional metadata dictionary to attach to the collection.
                It is copied; the caller's dictionary is left untouched.
            embedding_function: Optional custom embedding function. If not
                provided, ``self.embedding_function`` is used.
            data_loader: Optional data loader, forwarded as-is.
            get_or_create: Forwarded to ChromaDB. Defaults to False.

        Raises:
            TypeError: If AsyncClientAPI is used instead of ClientAPI for
                sync operations.
            KeyError: If ``collection_name`` is not supplied.
        """
        if not _is_sync_client(self.client):
            raise TypeError(
                "Synchronous method create_collection() requires a ClientAPI. "
                "Use acreate_collection() for AsyncClientAPI."
            )

        self.client.create_collection(
            name=kwargs["collection_name"],
            configuration=kwargs.get("configuration"),
            metadata=self._metadata_with_default_space(kwargs),
            embedding_function=kwargs.get(
                "embedding_function", self.embedding_function
            ),
            data_loader=kwargs.get("data_loader"),
            get_or_create=kwargs.get("get_or_create", False),
        )

    async def acreate_collection(
        self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
    ) -> None:
        """Create a new collection in ChromaDB asynchronously.

        Uses the client's default embedding function if none is provided and
        defaults the ``hnsw:space`` metadata entry to ``"cosine"``.

        Keyword Args:
            collection_name: Name of the collection to create (required).
            configuration: Optional collection configuration, forwarded as-is.
            metadata: Optional metadata dictionary to attach to the collection.
                It is copied; the caller's dictionary is left untouched.
            embedding_function: Optional custom embedding function. If not
                provided, ``self.embedding_function`` is used.
            data_loader: Optional data loader, forwarded as-is.
            get_or_create: Forwarded to ChromaDB. Defaults to False.

        Raises:
            TypeError: If ClientAPI is used instead of AsyncClientAPI for
                async operations.
            KeyError: If ``collection_name`` is not supplied.
        """
        if not _is_async_client(self.client):
            raise TypeError(
                "Asynchronous method acreate_collection() requires an AsyncClientAPI. "
                "Use create_collection() for ClientAPI."
            )

        await self.client.create_collection(
            name=kwargs["collection_name"],
            configuration=kwargs.get("configuration"),
            metadata=self._metadata_with_default_space(kwargs),
            embedding_function=kwargs.get(
                "embedding_function", self.embedding_function
            ),
            data_loader=kwargs.get("data_loader"),
            get_or_create=kwargs.get("get_or_create", False),
        )

    def get_or_create_collection(
        self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
    ) -> Any:
        """Get an existing collection or create it if it doesn't exist.

        Delegates to the client's ``get_or_create_collection`` with the same
        defaults as ``create_collection`` (client embedding function,
        ``hnsw:space`` = ``"cosine"``).

        Keyword Args:
            collection_name: Name of the collection to get or create (required).
            configuration: Optional collection configuration, forwarded as-is.
            metadata: Optional metadata dictionary to attach to the collection.
                It is copied; the caller's dictionary is left untouched.
            embedding_function: Optional custom embedding function. If not
                provided, ``self.embedding_function`` is used.
            data_loader: Optional data loader, forwarded as-is.

        Returns:
            Whatever the underlying client returns (a ChromaDB collection).

        Raises:
            TypeError: If AsyncClientAPI is used instead of ClientAPI for
                sync operations.
            KeyError: If ``collection_name`` is not supplied.
        """
        if not _is_sync_client(self.client):
            raise TypeError(
                "Synchronous method get_or_create_collection() requires a ClientAPI. "
                "Use aget_or_create_collection() for AsyncClientAPI."
            )

        return self.client.get_or_create_collection(
            name=kwargs["collection_name"],
            configuration=kwargs.get("configuration"),
            metadata=self._metadata_with_default_space(kwargs),
            embedding_function=kwargs.get(
                "embedding_function", self.embedding_function
            ),
            data_loader=kwargs.get("data_loader"),
        )

    async def aget_or_create_collection(
        self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
    ) -> Any:
        """Get an existing collection or create it asynchronously.

        Delegates to the client's ``get_or_create_collection`` with the same
        defaults as ``acreate_collection`` (client embedding function,
        ``hnsw:space`` = ``"cosine"``).

        Keyword Args:
            collection_name: Name of the collection to get or create (required).
            configuration: Optional collection configuration, forwarded as-is.
            metadata: Optional metadata dictionary to attach to the collection.
                It is copied; the caller's dictionary is left untouched.
            embedding_function: Optional custom embedding function. If not
                provided, ``self.embedding_function`` is used.
            data_loader: Optional data loader, forwarded as-is.

        Returns:
            Whatever the underlying client returns (a ChromaDB async collection).

        Raises:
            TypeError: If ClientAPI is used instead of AsyncClientAPI for
                async operations.
            KeyError: If ``collection_name`` is not supplied.
        """
        if not _is_async_client(self.client):
            raise TypeError(
                "Asynchronous method aget_or_create_collection() requires an AsyncClientAPI. "
                "Use get_or_create_collection() for ClientAPI."
            )

        return await self.client.get_or_create_collection(
            name=kwargs["collection_name"],
            configuration=kwargs.get("configuration"),
            metadata=self._metadata_with_default_space(kwargs),
            embedding_function=kwargs.get(
                "embedding_function", self.embedding_function
            ),
            data_loader=kwargs.get("data_loader"),
        )

    def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Add documents to a collection, updating those whose IDs already exist.

        Performs an upsert: the prepared ids, texts and metadatas are passed
        to the collection's ``upsert`` so that re-adding a known ID replaces
        the stored record. The collection is opened with
        ``self.embedding_function``.

        Keyword Args:
            collection_name: The name of the collection to add documents to.
            documents: List of BaseRecord dicts containing:
                - content: The text content (required)
                - doc_id: Optional unique identifier (derived from the
                  content when missing)
                - metadata: Optional metadata dictionary

        Raises:
            TypeError: If AsyncClientAPI is used instead of ClientAPI for
                sync operations.
            ValueError: If the documents list is empty.
        """
        if not _is_sync_client(self.client):
            raise TypeError(
                "Synchronous method add_documents() requires a ClientAPI. "
                "Use aadd_documents() for AsyncClientAPI."
            )

        collection_name = kwargs["collection_name"]
        documents = kwargs["documents"]

        # Reject empty batches before any round trip to the backend.
        if not documents:
            raise ValueError("Documents list cannot be empty")

        collection = self.client.get_collection(
            name=collection_name,
            embedding_function=self.embedding_function,
        )

        prepared = _prepare_documents_for_chromadb(documents)
        # upsert (not add) so the documented "existing IDs are updated"
        # contract actually holds.
        collection.upsert(
            ids=prepared.ids,
            documents=prepared.texts,
            metadatas=prepared.metadatas,
        )

    async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Asynchronously add documents, updating those whose IDs already exist.

        Performs an upsert: the prepared ids, texts and metadatas are passed
        to the collection's ``upsert`` so that re-adding a known ID replaces
        the stored record. The collection is opened with
        ``self.embedding_function``.

        Keyword Args:
            collection_name: The name of the collection to add documents to.
            documents: List of BaseRecord dicts containing:
                - content: The text content (required)
                - doc_id: Optional unique identifier (derived from the
                  content when missing)
                - metadata: Optional metadata dictionary

        Raises:
            TypeError: If ClientAPI is used instead of AsyncClientAPI for
                async operations.
            ValueError: If the documents list is empty.
        """
        if not _is_async_client(self.client):
            raise TypeError(
                "Asynchronous method aadd_documents() requires an AsyncClientAPI. "
                "Use add_documents() for ClientAPI."
            )

        collection_name = kwargs["collection_name"]
        documents = kwargs["documents"]

        # Reject empty batches before any round trip to the backend.
        if not documents:
            raise ValueError("Documents list cannot be empty")

        collection = await self.client.get_collection(
            name=collection_name,
            embedding_function=self.embedding_function,
        )

        prepared = _prepare_documents_for_chromadb(documents)
        # upsert (not add) so the documented "existing IDs are updated"
        # contract actually holds.
        await collection.upsert(
            ids=prepared.ids,
            documents=prepared.texts,
            metadatas=prepared.metadatas,
        )

    def search(
        self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
    ) -> list[SearchResult]:
        """Search for similar documents using a query.

        Queries the named collection with the query text and converts the
        raw result through ``_process_query_results``.

        Keyword Args:
            collection_name: Name of the collection to search in.
            query: The text query to search for.
            limit: Maximum number of results to return (default: 10).
            metadata_filter: Optional filter for metadata fields; used as the
                where clause only when ``where`` is not given.
            score_threshold: Optional minimum similarity score (0-1) for results.
            where: Optional ChromaDB where clause for metadata filtering.
                Takes precedence over ``metadata_filter``.
            where_document: Optional ChromaDB where clause for document
                content filtering.
            include: Optional list of fields to include in results.

        Returns:
            List of SearchResult dicts.

        Raises:
            TypeError: If AsyncClientAPI is used instead of ClientAPI for
                sync operations.
        """
        if not _is_sync_client(self.client):
            raise TypeError(
                "Synchronous method search() requires a ClientAPI. "
                "Use asearch() for AsyncClientAPI."
            )

        params = _extract_search_params(kwargs)

        collection = self.client.get_collection(
            name=params.collection_name,
            embedding_function=self.embedding_function,
        )

        # An explicit ChromaDB where clause wins over the generic filter.
        where = params.where if params.where is not None else params.metadata_filter

        results: QueryResult = collection.query(
            query_texts=[params.query],
            n_results=params.limit,
            where=where,
            where_document=params.where_document,
            include=params.include,
        )

        return _process_query_results(
            collection=collection,
            results=results,
            params=params,
        )

    async def asearch(
        self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
    ) -> list[SearchResult]:
        """Search for similar documents using a query asynchronously.

        Queries the named collection with the query text and converts the
        raw result through ``_process_query_results``.

        Keyword Args:
            collection_name: Name of the collection to search in.
            query: The text query to search for.
            limit: Maximum number of results to return (default: 10).
            metadata_filter: Optional filter for metadata fields; used as the
                where clause only when ``where`` is not given.
            score_threshold: Optional minimum similarity score (0-1) for results.
            where: Optional ChromaDB where clause for metadata filtering.
                Takes precedence over ``metadata_filter``.
            where_document: Optional ChromaDB where clause for document
                content filtering.
            include: Optional list of fields to include in results.

        Returns:
            List of SearchResult dicts.

        Raises:
            TypeError: If ClientAPI is used instead of AsyncClientAPI for
                async operations.
        """
        if not _is_async_client(self.client):
            raise TypeError(
                "Asynchronous method asearch() requires an AsyncClientAPI. "
                "Use search() for ClientAPI."
            )

        params = _extract_search_params(kwargs)

        collection = await self.client.get_collection(
            name=params.collection_name,
            embedding_function=self.embedding_function,
        )

        # An explicit ChromaDB where clause wins over the generic filter.
        where = params.where if params.where is not None else params.metadata_filter

        results: QueryResult = await collection.query(
            query_texts=[params.query],
            n_results=params.limit,
            where=where,
            where_document=params.where_document,
            include=params.include,
        )

        return _process_query_results(
            collection=collection,
            results=results,
            params=params,
        )

    def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Delete a collection and all its data.

        Forwards to the client's ``delete_collection``. This operation
        cannot be undone.

        Keyword Args:
            collection_name: Name of the collection to delete.

        Raises:
            TypeError: If AsyncClientAPI is used instead of ClientAPI for
                sync operations.
            KeyError: If ``collection_name`` is not supplied.
        """
        if not _is_sync_client(self.client):
            raise TypeError(
                "Synchronous method delete_collection() requires a ClientAPI. "
                "Use adelete_collection() for AsyncClientAPI."
            )

        collection_name = kwargs["collection_name"]
        self.client.delete_collection(name=collection_name)

    async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Delete a collection and all its data asynchronously.

        Forwards to the client's ``delete_collection``. This operation
        cannot be undone.

        Keyword Args:
            collection_name: Name of the collection to delete.

        Raises:
            TypeError: If ClientAPI is used instead of AsyncClientAPI for
                async operations.
            KeyError: If ``collection_name`` is not supplied.
        """
        if not _is_async_client(self.client):
            raise TypeError(
                "Asynchronous method adelete_collection() requires an AsyncClientAPI. "
                "Use delete_collection() for ClientAPI."
            )

        collection_name = kwargs["collection_name"]
        await self.client.delete_collection(name=collection_name)

    def reset(self) -> None:
        """Reset the vector database by deleting all collections and data.

        Forwards to the client's ``reset``. This operation cannot be undone;
        use with extreme caution in production environments.

        NOTE(review): ChromaDB appears to refuse ``reset`` unless resets are
        enabled in its settings - confirm against the deployed configuration.

        Raises:
            TypeError: If AsyncClientAPI is used instead of ClientAPI for
                sync operations.
        """
        if not _is_sync_client(self.client):
            raise TypeError(
                "Synchronous method reset() requires a ClientAPI. "
                "Use areset() for AsyncClientAPI."
            )

        self.client.reset()

    async def areset(self) -> None:
        """Reset the vector database asynchronously, deleting all data.

        Forwards to the client's ``reset``. This operation cannot be undone;
        use with extreme caution in production environments.

        NOTE(review): ChromaDB appears to refuse ``reset`` unless resets are
        enabled in its settings - confirm against the deployed configuration.

        Raises:
            TypeError: If ClientAPI is used instead of AsyncClientAPI for
                async operations.
        """
        if not _is_async_client(self.client):
            raise TypeError(
                "Asynchronous method areset() requires an AsyncClientAPI. "
                "Use reset() for ClientAPI."
            )

        await self.client.reset()
|
||||||
85
src/crewai/rag/chromadb/types.py
Normal file
85
src/crewai/rag/chromadb/types.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""Type definitions specific to ChromaDB implementation."""
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
|
from chromadb.api import ClientAPI, AsyncClientAPI
|
||||||
|
from chromadb.api.configuration import CollectionConfigurationInterface
|
||||||
|
from chromadb.api.types import (
|
||||||
|
CollectionMetadata,
|
||||||
|
DataLoader,
|
||||||
|
Embeddable,
|
||||||
|
EmbeddingFunction as ChromaEmbeddingFunction,
|
||||||
|
Include,
|
||||||
|
Loadable,
|
||||||
|
Where,
|
||||||
|
WhereDocument,
|
||||||
|
)
|
||||||
|
|
||||||
|
from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams
|
||||||
|
|
||||||
|
# Union of the synchronous and asynchronous ChromaDB client interfaces.
ChromaDBClientType = ClientAPI | AsyncClientAPI
|
||||||
|
|
||||||
|
|
||||||
|
class PreparedDocuments(NamedTuple):
|
||||||
|
"""Prepared documents ready for ChromaDB insertion.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
ids: List of document IDs
|
||||||
|
texts: List of document texts
|
||||||
|
metadatas: List of document metadata mappings
|
||||||
|
"""
|
||||||
|
|
||||||
|
ids: list[str]
|
||||||
|
texts: list[str]
|
||||||
|
metadatas: list[Mapping[str, str | int | float | bool]]
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractedSearchParams(NamedTuple):
    """Search arguments normalised into a fixed-shape record.

    Attributes:
        collection_name: Collection to query.
        query: Query text.
        limit: Maximum number of results to request.
        metadata_filter: Generic metadata filter, or None.
        score_threshold: Minimum similarity score, or None.
        where: ChromaDB where clause, or None.
        where_document: ChromaDB document-content filter, or None.
        include: Result fields to request from ChromaDB.
    """

    collection_name: str
    query: str
    limit: int
    metadata_filter: dict[str, Any] | None
    score_threshold: float | None
    where: Where | None
    where_document: WhereDocument | None
    include: Include
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
    """Keyword arguments accepted when creating a ChromaDB collection.

    Extends BaseCollectionParams with ChromaDB-specific creation options.
    Every key declared here is optional (``total=False``).
    """

    configuration: CollectionConfigurationInterface
    metadata: CollectionMetadata
    embedding_function: ChromaEmbeddingFunction[Embeddable]
    data_loader: DataLoader[Loadable]
    get_or_create: bool
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaDBCollectionSearchParams(BaseCollectionSearchParams, total=False):
    """Keyword arguments accepted when searching a ChromaDB collection.

    Extends BaseCollectionSearchParams with ChromaDB-specific where clauses
    and the ``include`` selector. Every key declared here is optional
    (``total=False``).
    """

    where: Where
    where_document: WhereDocument
    include: Include
|
||||||
220
src/crewai/rag/chromadb/utils.py
Normal file
220
src/crewai/rag/chromadb/utils.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""Utility functions for ChromaDB client implementation."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Literal, TypeGuard, cast
|
||||||
|
|
||||||
|
from chromadb.api import AsyncClientAPI, ClientAPI
|
||||||
|
from chromadb.api.types import (
|
||||||
|
Include,
|
||||||
|
IncludeEnum,
|
||||||
|
QueryResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||||
|
from chromadb.api.models.Collection import Collection
|
||||||
|
|
||||||
|
from crewai.rag.chromadb.types import (
|
||||||
|
ChromaDBClientType,
|
||||||
|
ChromaDBCollectionSearchParams,
|
||||||
|
ExtractedSearchParams,
|
||||||
|
PreparedDocuments,
|
||||||
|
)
|
||||||
|
from crewai.rag.types import BaseRecord, SearchResult
|
||||||
|
|
||||||
|
|
||||||
|
def _is_sync_client(client: ChromaDBClientType) -> TypeGuard[ClientAPI]:
    """Tell whether ``client`` is a synchronous ChromaDB client.

    Acts as a type guard so that type checkers narrow ``client`` to
    ``ClientAPI`` on the true branch.

    Args:
        client: The client to check.

    Returns:
        True if ``client`` is an instance of ClientAPI, False otherwise.
    """
    is_sync = isinstance(client, ClientAPI)
    return is_sync
|
||||||
|
|
||||||
|
|
||||||
|
def _is_async_client(client: ChromaDBClientType) -> TypeGuard[AsyncClientAPI]:
    """Tell whether ``client`` is an asynchronous ChromaDB client.

    Acts as a type guard so that type checkers narrow ``client`` to
    ``AsyncClientAPI`` on the true branch.

    Args:
        client: The client to check.

    Returns:
        True if ``client`` is an instance of AsyncClientAPI, False otherwise.
    """
    is_async = isinstance(client, AsyncClientAPI)
    return is_async
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_documents_for_chromadb(
    documents: list[BaseRecord],
) -> PreparedDocuments:
    """Split records into the parallel id / text / metadata lists ChromaDB takes.

    For each record:
      * the id is ``doc_id`` when that key is present, otherwise the first
        16 hex characters of the SHA-256 digest of the content;
      * the text is the record's ``content``;
      * the metadata is the record's ``metadata`` (its first element when
        it is a list), or an empty dict when missing or empty.

    Args:
        documents: List of BaseRecord documents to prepare.

    Returns:
        PreparedDocuments with ids, texts, and metadatas in input order.
    """
    doc_ids: list[str] = []
    contents: list[str] = []
    metadata_entries: list[Mapping[str, str | int | float | bool]] = []

    for record in documents:
        if "doc_id" not in record:
            # No caller-supplied id: derive a deterministic one from the text.
            digest = hashlib.sha256(record["content"].encode()).hexdigest()
            doc_ids.append(digest[:16])
        else:
            doc_ids.append(record["doc_id"])

        contents.append(record["content"])

        raw_metadata = record.get("metadata")
        if not raw_metadata:
            metadata_entries.append({})
        elif isinstance(raw_metadata, list):
            # A non-empty list contributes only its first entry.
            metadata_entries.append(raw_metadata[0])
        else:
            metadata_entries.append(raw_metadata)

    return PreparedDocuments(doc_ids, contents, metadata_entries)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_search_params(
    kwargs: ChromaDBCollectionSearchParams,
) -> ExtractedSearchParams:
    """Normalise search keyword arguments into an ExtractedSearchParams record.

    ``collection_name`` and ``query`` are required (a missing key raises
    ``KeyError``). ``limit`` defaults to 10, ``include`` defaults to
    metadatas, documents and distances, and every other option defaults
    to None.

    Args:
        kwargs: Keyword arguments containing search parameters.

    Returns:
        ExtractedSearchParams with all extracted parameters.
    """
    fallback_include = [
        IncludeEnum.metadatas,
        IncludeEnum.documents,
        IncludeEnum.distances,
    ]
    name = kwargs["collection_name"]
    query_text = kwargs["query"]

    return ExtractedSearchParams(
        collection_name=name,
        query=query_text,
        limit=kwargs.get("limit", 10),
        metadata_filter=kwargs.get("metadata_filter"),
        score_threshold=kwargs.get("score_threshold"),
        where=kwargs.get("where"),
        where_document=kwargs.get("where_document"),
        include=kwargs.get("include", fallback_include),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_distance_to_score(
    distance: float,
    distance_metric: Literal["l2", "cosine", "ip"],
) -> float:
    """Map a ChromaDB distance onto a similarity score.

    Notes:
        All embeddings, including custom ones, are assumed to be
        unit-normalized for now. Only the "cosine" metric is handled.

    Args:
        distance: The distance value from ChromaDB.
        distance_metric: The distance metric used ("l2", "cosine", or "ip").

    Returns:
        Similarity score in range [0, 1] where 1 is most similar.

    Raises:
        ValueError: If distance_metric is anything other than "cosine".
    """
    if distance_metric != "cosine":
        raise ValueError(f"Unsupported distance metric: {distance_metric}")

    # Halve the distance and flip it so that a smaller distance scores higher.
    unclamped = 1.0 - 0.5 * distance
    # Clamp into [0, 1]: cap at 1.0 first, then floor at 0.0.
    capped = min(1.0, unclamped)
    return max(0.0, capped)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_chromadb_results_to_search_results(
    results: QueryResult,
    include: Include,
    distance_metric: Literal["l2", "cosine", "ip"],
    score_threshold: float | None = None,
) -> list[SearchResult]:
    """Convert ChromaDB query results to SearchResult format.

    Only the result lists of the first query (index 0 of each field) are read.
    A hit is emitted only when a distance is available for it, because the
    score is derived from the distance; when "distances" was not included in
    the query the returned list is therefore empty.

    Args:
        results: ChromaDB query results.
        include: Fields that were included in the query, given either as enum
            members exposing ``.value`` or as plain strings.
        distance_metric: The distance metric used by the collection.
        score_threshold: Optional minimum similarity score (0-1) for results.

    Returns:
        List of SearchResult dicts containing id, content, metadata, and score.

    Raises:
        ValueError: Propagated from _convert_distance_to_score when the
            distance metric is not supported.
    """
    # Normalize to plain field names so that string entries do not fail on
    # attribute access; enum members contribute their ``.value``.
    include_strings = [getattr(item, "value", item) for item in include]

    def _first_query(field: str) -> list:
        """Return the first query's list for `field`, or [] if absent/not included."""
        column = results.get(field)
        return column[0] if column and field in include_strings else []

    ids = results["ids"][0] if results.get("ids") else []
    documents = _first_query("documents")
    metadatas = _first_query("metadatas")
    distances = _first_query("distances")

    search_results: list[SearchResult] = []

    # zip() stops at the shorter list, which drops ids that have no distance.
    for i, (doc_id, distance) in enumerate(zip(ids, distances)):
        score = _convert_distance_to_score(
            distance=distance, distance_metric=distance_metric
        )

        # Explicit None check: the threshold is optional, not "falsy means off".
        if score_threshold is not None and score < score_threshold:
            continue

        # Individual entries may be missing (None); fall back to empty values
        # instead of raising in dict() or leaking None as content.
        content = documents[i] if i < len(documents) else None
        metadata = metadatas[i] if i < len(metadatas) else None

        result: SearchResult = {
            "id": doc_id,
            "content": content or "",
            "metadata": dict(metadata or {}),
            "score": score,
        }
        search_results.append(result)

    return search_results
|
||||||
|
|
||||||
|
|
||||||
|
def _process_query_results(
    collection: Collection | AsyncCollection,
    results: QueryResult,
    params: ExtractedSearchParams,
) -> list[SearchResult]:
    """Turn raw ChromaDB query results into SearchResult dicts.

    The distance metric is read from the collection metadata key "hnsw:space";
    "l2" is used when the collection has no metadata or the key is absent.

    Args:
        collection: The ChromaDB collection (sync or async) that was queried.
        results: Raw query results from ChromaDB.
        params: The search parameters used for the query.

    Returns:
        List of SearchResult dicts containing id, content, metadata, and score.
    """
    if collection.metadata:
        space = collection.metadata.get("hnsw:space", "l2")
    else:
        space = "l2"

    # cast() only informs the type checker; the value is not validated here.
    metric = cast(Literal["l2", "cosine", "ip"], space)

    return _convert_chromadb_results_to_search_results(
        results=results,
        include=params.include,
        distance_metric=metric,
        score_threshold=params.score_threshold,
    )
|
||||||
0
tests/rag/__init__.py
Normal file
0
tests/rag/__init__.py
Normal file
0
tests/rag/chromadb/__init__.py
Normal file
0
tests/rag/chromadb/__init__.py
Normal file
550
tests/rag/chromadb/test_client.py
Normal file
550
tests/rag/chromadb/test_client.py
Normal file
@@ -0,0 +1,550 @@
|
|||||||
|
"""Tests for ChromaDBClient implementation."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.rag.chromadb.client import ChromaDBClient
|
||||||
|
from crewai.rag.types import BaseRecord
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_chromadb_client():
    """Provide a Mock restricted to the synchronous ChromaDB ClientAPI spec."""
    from chromadb.api import ClientAPI

    sync_api_mock = Mock(spec=ClientAPI)
    return sync_api_mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_async_chromadb_client():
    """Provide a Mock restricted to the asynchronous ChromaDB AsyncClientAPI spec."""
    from chromadb.api import AsyncClientAPI

    async_api_mock = Mock(spec=AsyncClientAPI)
    return async_api_mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def client(mock_chromadb_client) -> ChromaDBClient:
    """Build a ChromaDBClient wired to the mocked sync client and a mock embedder."""
    chroma = ChromaDBClient()
    # Swap in the test doubles after construction.
    chroma.client = mock_chromadb_client
    chroma.embedding_function = Mock()
    return chroma
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def async_client(mock_async_chromadb_client) -> ChromaDBClient:
    """Build a ChromaDBClient wired to the mocked async client and a mock embedder."""
    chroma = ChromaDBClient()
    # Swap in the test doubles after construction.
    chroma.client = mock_async_chromadb_client
    chroma.embedding_function = Mock()
    return chroma
|
||||||
|
|
||||||
|
|
||||||
|
class TestChromaDBClient:
    """Unit tests for ChromaDBClient against mocked sync and async ChromaDB clients."""

    # ------------------------------------------------------------------
    # Shared test data builders (not collected by pytest: no "test" prefix)
    # ------------------------------------------------------------------

    @staticmethod
    def _collection_options() -> dict:
        """Return a fresh value for every optional collection keyword argument."""
        return {
            "configuration": Mock(),
            "metadata": {"key": "value"},
            "embedding_function": Mock(),
            "data_loader": Mock(),
        }

    @staticmethod
    def _single_record() -> list[BaseRecord]:
        """Return one record without an explicit doc_id."""
        return [{"content": "Test document", "metadata": {"source": "test"}}]

    @staticmethod
    def _custom_id_records() -> list[BaseRecord]:
        """Return two records that carry explicit doc_ids."""
        return [
            {
                "doc_id": "custom_id_1",
                "content": "First document",
                "metadata": {"source": "test1"},
            },
            {
                "doc_id": "custom_id_2",
                "content": "Second document",
                "metadata": {"source": "test2"},
            },
        ]

    @staticmethod
    def _query_payload(distances: list[float]) -> dict:
        """Return a fake query result with one hit (doc1, doc2, ...) per distance."""
        positions = range(1, len(distances) + 1)
        return {
            "ids": [[f"doc{n}" for n in positions]],
            "documents": [[f"Document {n}" for n in positions]],
            "metadatas": [[{"source": f"test{n}"} for n in positions]],
            "distances": [list(distances)],
        }

    # ------------------------------------------------------------------
    # create_collection / acreate_collection
    # ------------------------------------------------------------------

    def test_create_collection(self, client, mock_chromadb_client):
        """create_collection forwards its defaults to the underlying client."""
        client.create_collection(collection_name="test_collection")

        expected_kwargs = {
            "name": "test_collection",
            "configuration": None,
            "metadata": {"hnsw:space": "cosine"},
            "embedding_function": client.embedding_function,
            "data_loader": None,
            "get_or_create": False,
        }
        mock_chromadb_client.create_collection.assert_called_once_with(
            **expected_kwargs
        )

    def test_create_collection_with_all_params(self, client, mock_chromadb_client):
        """create_collection passes every optional parameter through unchanged."""
        options = self._collection_options()

        client.create_collection(
            collection_name="test_collection",
            get_or_create=True,
            **options,
        )

        mock_chromadb_client.create_collection.assert_called_once_with(
            name="test_collection",
            get_or_create=True,
            **options,
        )

    @pytest.mark.asyncio
    async def test_acreate_collection(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """acreate_collection forwards its defaults to the underlying client."""
        # The async API must be awaitable, so replace the method with an AsyncMock.
        create_mock = AsyncMock(return_value=None)
        mock_async_chromadb_client.create_collection = create_mock

        await async_client.acreate_collection(collection_name="test_collection")

        expected_kwargs = {
            "name": "test_collection",
            "configuration": None,
            "metadata": {"hnsw:space": "cosine"},
            "embedding_function": async_client.embedding_function,
            "data_loader": None,
            "get_or_create": False,
        }
        mock_async_chromadb_client.create_collection.assert_called_once_with(
            **expected_kwargs
        )

    @pytest.mark.asyncio
    async def test_acreate_collection_with_all_params(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """acreate_collection passes every optional parameter through unchanged."""
        # The async API must be awaitable, so replace the method with an AsyncMock.
        create_mock = AsyncMock(return_value=None)
        mock_async_chromadb_client.create_collection = create_mock
        options = self._collection_options()

        await async_client.acreate_collection(
            collection_name="test_collection",
            get_or_create=True,
            **options,
        )

        mock_async_chromadb_client.create_collection.assert_called_once_with(
            name="test_collection",
            get_or_create=True,
            **options,
        )

    # ------------------------------------------------------------------
    # get_or_create_collection / aget_or_create_collection
    # ------------------------------------------------------------------

    def test_get_or_create_collection(self, client, mock_chromadb_client):
        """get_or_create_collection forwards defaults and returns the collection."""
        expected_collection = Mock()
        mock_chromadb_client.get_or_create_collection.return_value = (
            expected_collection
        )

        result = client.get_or_create_collection(collection_name="test_collection")

        expected_kwargs = {
            "name": "test_collection",
            "configuration": None,
            "metadata": {"hnsw:space": "cosine"},
            "embedding_function": client.embedding_function,
            "data_loader": None,
        }
        mock_chromadb_client.get_or_create_collection.assert_called_once_with(
            **expected_kwargs
        )
        assert result == expected_collection

    def test_get_or_create_collection_with_all_params(
        self, client, mock_chromadb_client
    ):
        """get_or_create_collection passes every optional parameter through."""
        expected_collection = Mock()
        mock_chromadb_client.get_or_create_collection.return_value = (
            expected_collection
        )
        options = self._collection_options()

        result = client.get_or_create_collection(
            collection_name="test_collection",
            **options,
        )

        mock_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            **options,
        )
        assert result == expected_collection

    @pytest.mark.asyncio
    async def test_aget_or_create_collection(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """aget_or_create_collection forwards defaults and returns the collection."""
        expected_collection = Mock()
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=expected_collection
        )

        result = await async_client.aget_or_create_collection(
            collection_name="test_collection"
        )

        expected_kwargs = {
            "name": "test_collection",
            "configuration": None,
            "metadata": {"hnsw:space": "cosine"},
            "embedding_function": async_client.embedding_function,
            "data_loader": None,
        }
        mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
            **expected_kwargs
        )
        assert result == expected_collection

    @pytest.mark.asyncio
    async def test_aget_or_create_collection_with_all_params(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """aget_or_create_collection passes every optional parameter through."""
        expected_collection = Mock()
        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
            return_value=expected_collection
        )
        options = self._collection_options()

        result = await async_client.aget_or_create_collection(
            collection_name="test_collection",
            **options,
        )

        mock_async_chromadb_client.get_or_create_collection.assert_called_once_with(
            name="test_collection",
            **options,
        )
        assert result == expected_collection

    # ------------------------------------------------------------------
    # add_documents / aadd_documents
    # ------------------------------------------------------------------

    def test_add_documents(self, client, mock_chromadb_client) -> None:
        """add_documents looks up the collection and adds the prepared record."""
        target = Mock()
        mock_chromadb_client.get_collection.return_value = target
        records: list[BaseRecord] = self._single_record()

        client.add_documents(collection_name="test_collection", documents=records)

        mock_chromadb_client.get_collection.assert_called_once_with(
            name="test_collection",
            embedding_function=client.embedding_function,
        )

        # Exactly one add() call, carrying one generated id plus text and metadata.
        target.add.assert_called_once()
        added = target.add.call_args.kwargs
        assert len(added["ids"]) == 1
        assert added["documents"] == ["Test document"]
        assert added["metadatas"] == [{"source": "test"}]

    def test_add_documents_with_custom_ids(self, client, mock_chromadb_client) -> None:
        """add_documents keeps caller-supplied document IDs."""
        target = Mock()
        mock_chromadb_client.get_collection.return_value = target
        records: list[BaseRecord] = self._custom_id_records()

        client.add_documents(collection_name="test_collection", documents=records)

        target.add.assert_called_once_with(
            ids=["custom_id_1", "custom_id_2"],
            documents=["First document", "Second document"],
            metadatas=[{"source": "test1"}, {"source": "test2"}],
        )

    def test_add_documents_empty_list_raises_error(
        self, client, mock_chromadb_client
    ) -> None:
        """add_documents rejects an empty documents list with ValueError."""
        no_records: list[BaseRecord] = []
        with pytest.raises(ValueError, match="Documents list cannot be empty"):
            client.add_documents(
                collection_name="test_collection", documents=no_records
            )

    @pytest.mark.asyncio
    async def test_aadd_documents(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """aadd_documents looks up the collection and adds the prepared record."""
        target = AsyncMock()
        mock_async_chromadb_client.get_collection = AsyncMock(return_value=target)
        records: list[BaseRecord] = self._single_record()

        await async_client.aadd_documents(
            collection_name="test_collection", documents=records
        )

        mock_async_chromadb_client.get_collection.assert_called_once_with(
            name="test_collection",
            embedding_function=async_client.embedding_function,
        )

        # Exactly one add() call, carrying one generated id plus text and metadata.
        target.add.assert_called_once()
        added = target.add.call_args.kwargs
        assert len(added["ids"]) == 1
        assert added["documents"] == ["Test document"]
        assert added["metadatas"] == [{"source": "test"}]

    @pytest.mark.asyncio
    async def test_aadd_documents_with_custom_ids(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """aadd_documents keeps caller-supplied document IDs."""
        target = AsyncMock()
        mock_async_chromadb_client.get_collection = AsyncMock(return_value=target)
        records: list[BaseRecord] = self._custom_id_records()

        await async_client.aadd_documents(
            collection_name="test_collection", documents=records
        )

        target.add.assert_called_once_with(
            ids=["custom_id_1", "custom_id_2"],
            documents=["First document", "Second document"],
            metadatas=[{"source": "test1"}, {"source": "test2"}],
        )

    @pytest.mark.asyncio
    async def test_aadd_documents_empty_list_raises_error(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """aadd_documents rejects an empty documents list with ValueError."""
        no_records: list[BaseRecord] = []
        with pytest.raises(ValueError, match="Documents list cannot be empty"):
            await async_client.aadd_documents(
                collection_name="test_collection", documents=no_records
            )

    # ------------------------------------------------------------------
    # search / asearch
    # ------------------------------------------------------------------

    def test_search(self, client, mock_chromadb_client):
        """search queries the collection with defaults and converts the hits."""
        queried = Mock()
        queried.metadata = {"hnsw:space": "cosine"}
        queried.query.return_value = self._query_payload([0.1, 0.3])
        mock_chromadb_client.get_collection.return_value = queried

        results = client.search(collection_name="test_collection", query="test query")

        mock_chromadb_client.get_collection.assert_called_once_with(
            name="test_collection",
            embedding_function=client.embedding_function,
        )
        queried.query.assert_called_once_with(
            query_texts=["test query"],
            n_results=10,
            where=None,
            where_document=None,
            include=["metadatas", "documents", "distances"],
        )

        assert len(results) == 2
        top_hit = results[0]
        assert top_hit["id"] == "doc1"
        assert top_hit["content"] == "Document 1"
        assert top_hit["metadata"] == {"source": "test1"}
        assert top_hit["score"] == 0.95

    def test_search_with_optional_params(self, client, mock_chromadb_client):
        """search honours limit, metadata filter, and score threshold."""
        queried = Mock()
        queried.metadata = {"hnsw:space": "cosine"}
        # The last hit (distance 1.5) is expected to fall below the threshold.
        queried.query.return_value = self._query_payload([0.1, 0.3, 1.5])
        mock_chromadb_client.get_collection.return_value = queried

        results = client.search(
            collection_name="test_collection",
            query="test query",
            limit=5,
            metadata_filter={"source": "test"},
            score_threshold=0.7,
        )

        queried.query.assert_called_once_with(
            query_texts=["test query"],
            n_results=5,
            where={"source": "test"},
            where_document=None,
            include=["metadatas", "documents", "distances"],
        )

        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_asearch(self, async_client, mock_async_chromadb_client) -> None:
        """asearch queries the collection with defaults and converts the hits."""
        queried = AsyncMock()
        queried.metadata = {"hnsw:space": "cosine"}
        queried.query = AsyncMock(return_value=self._query_payload([0.1, 0.3]))
        mock_async_chromadb_client.get_collection = AsyncMock(return_value=queried)

        results = await async_client.asearch(
            collection_name="test_collection", query="test query"
        )

        mock_async_chromadb_client.get_collection.assert_called_once_with(
            name="test_collection",
            embedding_function=async_client.embedding_function,
        )
        queried.query.assert_called_once_with(
            query_texts=["test query"],
            n_results=10,
            where=None,
            where_document=None,
            include=["metadatas", "documents", "distances"],
        )

        assert len(results) == 2
        top_hit = results[0]
        assert top_hit["id"] == "doc1"
        assert top_hit["content"] == "Document 1"
        assert top_hit["metadata"] == {"source": "test1"}
        assert top_hit["score"] == 0.95

    @pytest.mark.asyncio
    async def test_asearch_with_optional_params(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """asearch honours limit, metadata filter, and score threshold."""
        queried = AsyncMock()
        queried.metadata = {"hnsw:space": "cosine"}
        # The last hit (distance 1.5) is expected to fall below the threshold.
        queried.query = AsyncMock(return_value=self._query_payload([0.1, 0.3, 1.5]))
        mock_async_chromadb_client.get_collection = AsyncMock(return_value=queried)

        results = await async_client.asearch(
            collection_name="test_collection",
            query="test query",
            limit=5,
            metadata_filter={"source": "test"},
            score_threshold=0.7,
        )

        queried.query.assert_called_once_with(
            query_texts=["test query"],
            n_results=5,
            where={"source": "test"},
            where_document=None,
            include=["metadatas", "documents", "distances"],
        )

        # Only 2 results should pass the score threshold
        assert len(results) == 2

    # ------------------------------------------------------------------
    # delete_collection / adelete_collection
    # ------------------------------------------------------------------

    def test_delete_collection(self, client, mock_chromadb_client):
        """delete_collection forwards the collection name to the underlying client."""
        client.delete_collection(collection_name="test_collection")

        delete_mock = mock_chromadb_client.delete_collection
        delete_mock.assert_called_once_with(name="test_collection")

    @pytest.mark.asyncio
    async def test_adelete_collection(
        self, async_client, mock_async_chromadb_client
    ) -> None:
        """adelete_collection forwards the collection name to the underlying client."""
        mock_async_chromadb_client.delete_collection = AsyncMock(return_value=None)

        await async_client.adelete_collection(collection_name="test_collection")

        delete_mock = mock_async_chromadb_client.delete_collection
        delete_mock.assert_called_once_with(name="test_collection")

    # ------------------------------------------------------------------
    # reset / areset
    # ------------------------------------------------------------------

    def test_reset(self, client, mock_chromadb_client):
        """reset calls the underlying client's reset without arguments."""
        mock_chromadb_client.reset.return_value = True

        client.reset()

        mock_chromadb_client.reset.assert_called_once_with()

    @pytest.mark.asyncio
    async def test_areset(self, async_client, mock_async_chromadb_client) -> None:
        """areset calls the underlying client's reset without arguments."""
        mock_async_chromadb_client.reset = AsyncMock(return_value=True)

        await async_client.areset()

        mock_async_chromadb_client.reset.assert_called_once_with()
|
||||||
Reference in New Issue
Block a user