fix: add batch_size support to prevent embedder token limit errors

- add batch_size field to BaseRagConfig (default=100)  
- update ChromaDB/Qdrant clients and factories to use batch_size  
- extract batch_size from the embedder config in KnowledgeStorage and RAGStorage, and filter it out before building the embedding function  
- fix large CSV files exceeding embedder token limits (#3574)  
- remove unneeded conditional for type  

Co-authored-by: Vini Brasil <vini@hey.com>
This commit is contained in:
Greyson LaLonde
2025-09-24 00:05:43 -04:00
committed by GitHub
parent 4ac65eb0a6
commit 1dbe8aab52
12 changed files with 558 additions and 56 deletions

View File

@@ -8,7 +8,7 @@ from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import get_embedding_function from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
from crewai.rag.factory import create_client from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult from crewai.rag.types import BaseRecord, SearchResult
from crewai.utilities.logger import Logger from crewai.utilities.logger import Logger
@@ -27,6 +27,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
) -> None: ) -> None:
self.collection_name = collection_name self.collection_name = collection_name
self._client: BaseClient | None = None self._client: BaseClient | None = None
self._embedder_config = embedder # Store embedder config
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "ignore",
@@ -35,7 +36,24 @@ class KnowledgeStorage(BaseKnowledgeStorage):
) )
if embedder: if embedder:
embedding_function = get_embedding_function(embedder) # Cast to EmbedderConfig for type checking
embedder_typed = cast(EmbedderConfig, embedder)
embedding_function = get_embedding_function(embedder_typed)
batch_size = None
if isinstance(embedder, dict) and "config" in embedder:
nested_config = embedder["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
# Create config with batch_size if provided
if batch_size is not None:
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
),
batch_size=batch_size,
)
else:
config = ChromaDBConfig( config = ChromaDBConfig(
embedding_function=cast( embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function ChromaEmbeddingFunctionWrapper, embedding_function
@@ -105,6 +123,20 @@ class KnowledgeStorage(BaseKnowledgeStorage):
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents] rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
batch_size = None
if self._embedder_config and isinstance(self._embedder_config, dict):
if "config" in self._embedder_config:
nested_config = self._embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
client.add_documents(
collection_name=collection_name,
documents=rag_documents,
batch_size=batch_size,
)
else:
client.add_documents( client.add_documents(
collection_name=collection_name, documents=rag_documents collection_name=collection_name, documents=rag_documents
) )

View File

@@ -66,6 +66,23 @@ class RAGStorage(BaseRAGStorage):
f"Error: {e}" f"Error: {e}"
) from e ) from e
batch_size = None
if (
isinstance(self.embedder_config, dict)
and "config" in self.embedder_config
):
nested_config = self.embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
),
batch_size=batch_size,
)
else:
config = ChromaDBConfig( config = ChromaDBConfig(
embedding_function=cast( embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function ChromaEmbeddingFunctionWrapper, embedding_function
@@ -111,7 +128,26 @@ class RAGStorage(BaseRAGStorage):
if metadata: if metadata:
document["metadata"] = metadata document["metadata"] = metadata
client.add_documents(collection_name=collection_name, documents=[document]) batch_size = None
if (
self.embedder_config
and isinstance(self.embedder_config, dict)
and "config" in self.embedder_config
):
nested_config = self.embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
client.add_documents(
collection_name=collection_name,
documents=[document],
batch_size=batch_size,
)
else:
client.add_documents(
collection_name=collection_name, documents=[document]
)
except Exception as e: except Exception as e:
logging.error( logging.error(
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}" f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"

View File

@@ -17,6 +17,7 @@ from crewai.rag.chromadb.types import (
ChromaDBCollectionSearchParams, ChromaDBCollectionSearchParams,
) )
from crewai.rag.chromadb.utils import ( from crewai.rag.chromadb.utils import (
_create_batch_slice,
_extract_search_params, _extract_search_params,
_is_async_client, _is_async_client,
_is_sync_client, _is_sync_client,
@@ -52,6 +53,7 @@ class ChromaDBClient(BaseClient):
embedding_function: ChromaEmbeddingFunction, embedding_function: ChromaEmbeddingFunction,
default_limit: int = 5, default_limit: int = 5,
default_score_threshold: float = 0.6, default_score_threshold: float = 0.6,
default_batch_size: int = 100,
) -> None: ) -> None:
"""Initialize ChromaDBClient with client and embedding function. """Initialize ChromaDBClient with client and embedding function.
@@ -60,11 +62,13 @@ class ChromaDBClient(BaseClient):
embedding_function: Embedding function for text to vector conversion. embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches. default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results. default_score_threshold: Default minimum score for search results.
default_batch_size: Default batch size for adding documents.
""" """
self.client = client self.client = client
self.embedding_function = embedding_function self.embedding_function = embedding_function
self.default_limit = default_limit self.default_limit = default_limit
self.default_score_threshold = default_score_threshold self.default_score_threshold = default_score_threshold
self.default_batch_size = default_batch_size
def create_collection( def create_collection(
self, **kwargs: Unpack[ChromaDBCollectionCreateParams] self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -291,6 +295,7 @@ class ChromaDBClient(BaseClient):
- content: The text content (required) - content: The text content (required)
- doc_id: Optional unique identifier (auto-generated if missing) - doc_id: Optional unique identifier (auto-generated if missing)
- metadata: Optional metadata dictionary - metadata: Optional metadata dictionary
batch_size: Optional batch size for processing documents (default: 100)
Raises: Raises:
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations. TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
@@ -305,6 +310,7 @@ class ChromaDBClient(BaseClient):
collection_name = kwargs["collection_name"] collection_name = kwargs["collection_name"]
documents = kwargs["documents"] documents = kwargs["documents"]
batch_size = kwargs.get("batch_size", self.default_batch_size)
if not documents: if not documents:
raise ValueError("Documents list cannot be empty") raise ValueError("Documents list cannot be empty")
@@ -315,12 +321,16 @@ class ChromaDBClient(BaseClient):
) )
prepared = _prepare_documents_for_chromadb(documents) prepared = _prepare_documents_for_chromadb(documents)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
)
collection.upsert( collection.upsert(
ids=prepared.ids, ids=batch_ids,
documents=prepared.texts, documents=batch_texts,
metadatas=metadatas, metadatas=batch_metadatas,
) )
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None: async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
@@ -335,6 +345,7 @@ class ChromaDBClient(BaseClient):
- content: The text content (required) - content: The text content (required)
- doc_id: Optional unique identifier (auto-generated if missing) - doc_id: Optional unique identifier (auto-generated if missing)
- metadata: Optional metadata dictionary - metadata: Optional metadata dictionary
batch_size: Optional batch size for processing documents (default: 100)
Raises: Raises:
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations. TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
@@ -349,6 +360,7 @@ class ChromaDBClient(BaseClient):
collection_name = kwargs["collection_name"] collection_name = kwargs["collection_name"]
documents = kwargs["documents"] documents = kwargs["documents"]
batch_size = kwargs.get("batch_size", self.default_batch_size)
if not documents: if not documents:
raise ValueError("Documents list cannot be empty") raise ValueError("Documents list cannot be empty")
@@ -358,12 +370,16 @@ class ChromaDBClient(BaseClient):
embedding_function=self.embedding_function, embedding_function=self.embedding_function,
) )
prepared = _prepare_documents_for_chromadb(documents) prepared = _prepare_documents_for_chromadb(documents)
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None for i in range(0, len(prepared.ids), batch_size):
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
prepared=prepared, start_index=i, batch_size=batch_size
)
await collection.upsert( await collection.upsert(
ids=prepared.ids, ids=batch_ids,
documents=prepared.texts, documents=batch_texts,
metadatas=metadatas, metadatas=batch_metadatas,
) )
def search( def search(

View File

@@ -41,4 +41,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
embedding_function=config.embedding_function, embedding_function=config.embedding_function,
default_limit=config.limit, default_limit=config.limit,
default_score_threshold=config.score_threshold, default_score_threshold=config.score_threshold,
default_batch_size=config.batch_size,
) )

View File

@@ -1,6 +1,7 @@
"""Utility functions for ChromaDB client implementation.""" """Utility functions for ChromaDB client implementation."""
import hashlib import hashlib
import json
from collections.abc import Mapping from collections.abc import Mapping
from typing import Literal, TypeGuard, cast from typing import Literal, TypeGuard, cast
@@ -72,7 +73,15 @@ def _prepare_documents_for_chromadb(
if "doc_id" in doc: if "doc_id" in doc:
ids.append(doc["doc_id"]) ids.append(doc["doc_id"])
else: else:
content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16] content_for_hash = doc["content"]
metadata = doc.get("metadata")
if metadata:
metadata_str = json.dumps(metadata, sort_keys=True)
content_for_hash = f"{content_for_hash}|{metadata_str}"
content_hash = hashlib.blake2b(
content_for_hash.encode(), digest_size=32
).hexdigest()
ids.append(content_hash) ids.append(content_hash)
texts.append(doc["content"]) texts.append(doc["content"])
@@ -88,6 +97,32 @@ def _prepare_documents_for_chromadb(
return PreparedDocuments(ids, texts, metadatas) return PreparedDocuments(ids, texts, metadatas)
def _create_batch_slice(
    prepared: PreparedDocuments, start_index: int, batch_size: int
) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]:
    """Create a batch slice from prepared documents.

    Args:
        prepared: PreparedDocuments containing ids, texts, and metadatas.
        start_index: Index of the first document in the batch.
        batch_size: Maximum number of documents in the batch. Must be >= 1.

    Returns:
        Tuple of (batch_ids, batch_texts, batch_metadatas). The batch is
        clamped to the number of prepared ids, so the last batch may be
        shorter than ``batch_size`` and a start index past the end yields
        empty lists. ``batch_metadatas`` is None when ``prepared.metadatas``
        is empty, or when the batch is non-empty and every metadata dict in
        it is empty.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    # A zero or negative size would silently produce an empty batch, which
    # hides the misconfiguration from the caller.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    batch = slice(start_index, min(start_index + batch_size, len(prepared.ids)))
    batch_ids = prepared.ids[batch]
    batch_texts = prepared.texts[batch]

    if not prepared.metadatas:
        return batch_ids, batch_texts, None

    batch_metadatas = prepared.metadatas[batch]
    # ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
    if batch_metadatas and not any(batch_metadatas):
        batch_metadatas = None

    return batch_ids, batch_texts, batch_metadatas
def _extract_search_params( def _extract_search_params(
kwargs: ChromaDBCollectionSearchParams, kwargs: ChromaDBCollectionSearchParams,
) -> ExtractedSearchParams: ) -> ExtractedSearchParams:

View File

@@ -16,3 +16,4 @@ class BaseRagConfig:
embedding_function: Any | None = field(default=None) embedding_function: Any | None = field(default=None)
limit: int = field(default=5) limit: int = field(default=5)
score_threshold: float = field(default=0.6) score_threshold: float = field(default=0.6)
batch_size: int = field(default=100)

View File

@@ -29,7 +29,7 @@ class BaseCollectionParams(TypedDict):
] ]
class BaseCollectionAddParams(BaseCollectionParams): class BaseCollectionAddParams(BaseCollectionParams, total=False):
"""Parameters for adding documents to a collection. """Parameters for adding documents to a collection.
Extends BaseCollectionParams with document-specific fields. Extends BaseCollectionParams with document-specific fields.
@@ -37,9 +37,11 @@ class BaseCollectionAddParams(BaseCollectionParams):
Attributes: Attributes:
collection_name: The name of the collection to add documents to. collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dictionaries containing document data. documents: List of BaseRecord dictionaries containing document data.
batch_size: Optional batch size for processing documents to avoid token limits.
""" """
documents: list[BaseRecord] documents: Required[list[BaseRecord]]
batch_size: int
class BaseCollectionSearchParams(BaseCollectionParams, total=False): class BaseCollectionSearchParams(BaseCollectionParams, total=False):

View File

@@ -244,4 +244,6 @@ def get_embedding_function(
_inject_api_key_from_env(provider, config_dict) _inject_api_key_from_env(provider, config_dict)
config_dict.pop("batch_size", None)
return EMBEDDING_PROVIDERS[provider](**config_dict) return EMBEDDING_PROVIDERS[provider](**config_dict)

View File

@@ -48,6 +48,7 @@ class QdrantClient(BaseClient):
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction, embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
default_limit: int = 5, default_limit: int = 5,
default_score_threshold: float = 0.6, default_score_threshold: float = 0.6,
default_batch_size: int = 100,
) -> None: ) -> None:
"""Initialize QdrantClient with client and embedding function. """Initialize QdrantClient with client and embedding function.
@@ -56,11 +57,13 @@ class QdrantClient(BaseClient):
embedding_function: Embedding function for text to vector conversion. embedding_function: Embedding function for text to vector conversion.
default_limit: Default number of results to return in searches. default_limit: Default number of results to return in searches.
default_score_threshold: Default minimum score for search results. default_score_threshold: Default minimum score for search results.
default_batch_size: Default batch size for adding documents.
""" """
self.client = client self.client = client
self.embedding_function = embedding_function self.embedding_function = embedding_function
self.default_limit = default_limit self.default_limit = default_limit
self.default_score_threshold = default_score_threshold self.default_score_threshold = default_score_threshold
self.default_batch_size = default_batch_size
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None: def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
"""Create a new collection in Qdrant. """Create a new collection in Qdrant.
@@ -234,6 +237,7 @@ class QdrantClient(BaseClient):
Keyword Args: Keyword Args:
collection_name: The name of the collection to add documents to. collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dicts containing document data. documents: List of BaseRecord dicts containing document data.
batch_size: Optional batch size for processing documents (default: 100)
Raises: Raises:
ValueError: If collection doesn't exist or documents list is empty. ValueError: If collection doesn't exist or documents list is empty.
@@ -249,6 +253,7 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"] collection_name = kwargs["collection_name"]
documents = kwargs["documents"] documents = kwargs["documents"]
batch_size = kwargs.get("batch_size", self.default_batch_size)
if not documents: if not documents:
raise ValueError("Documents list cannot be empty") raise ValueError("Documents list cannot be empty")
@@ -256,8 +261,10 @@ class QdrantClient(BaseClient):
if not self.client.collection_exists(collection_name): if not self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist") raise ValueError(f"Collection '{collection_name}' does not exist")
for i in range(0, len(documents), batch_size):
batch_docs = documents[i : min(i + batch_size, len(documents))]
points = [] points = []
for doc in documents: for doc in batch_docs:
if _is_async_embedding_function(self.embedding_function): if _is_async_embedding_function(self.embedding_function):
raise TypeError( raise TypeError(
"Async embedding function cannot be used with sync add_documents. " "Async embedding function cannot be used with sync add_documents. "
@@ -267,7 +274,6 @@ class QdrantClient(BaseClient):
embedding = sync_fn(doc["content"]) embedding = sync_fn(doc["content"])
point = _create_point_from_document(doc, embedding) point = _create_point_from_document(doc, embedding)
points.append(point) points.append(point)
self.client.upsert(collection_name=collection_name, points=points) self.client.upsert(collection_name=collection_name, points=points)
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None: async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
@@ -276,6 +282,7 @@ class QdrantClient(BaseClient):
Keyword Args: Keyword Args:
collection_name: The name of the collection to add documents to. collection_name: The name of the collection to add documents to.
documents: List of BaseRecord dicts containing document data. documents: List of BaseRecord dicts containing document data.
batch_size: Optional batch size for processing documents (default: 100)
Raises: Raises:
ValueError: If collection doesn't exist or documents list is empty. ValueError: If collection doesn't exist or documents list is empty.
@@ -291,6 +298,7 @@ class QdrantClient(BaseClient):
collection_name = kwargs["collection_name"] collection_name = kwargs["collection_name"]
documents = kwargs["documents"] documents = kwargs["documents"]
batch_size = kwargs.get("batch_size", self.default_batch_size)
if not documents: if not documents:
raise ValueError("Documents list cannot be empty") raise ValueError("Documents list cannot be empty")
@@ -298,8 +306,10 @@ class QdrantClient(BaseClient):
if not await self.client.collection_exists(collection_name): if not await self.client.collection_exists(collection_name):
raise ValueError(f"Collection '{collection_name}' does not exist") raise ValueError(f"Collection '{collection_name}' does not exist")
for i in range(0, len(documents), batch_size):
batch_docs = documents[i : min(i + batch_size, len(documents))]
points = [] points = []
for doc in documents: for doc in batch_docs:
if _is_async_embedding_function(self.embedding_function): if _is_async_embedding_function(self.embedding_function):
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function) async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
embedding = await async_fn(doc["content"]) embedding = await async_fn(doc["content"])
@@ -308,7 +318,6 @@ class QdrantClient(BaseClient):
embedding = sync_fn(doc["content"]) embedding = sync_fn(doc["content"])
point = _create_point_from_document(doc, embedding) point = _create_point_from_document(doc, embedding)
points.append(point) points.append(point)
await self.client.upsert(collection_name=collection_name, points=points) await self.client.upsert(collection_name=collection_name, points=points)
def search( def search(

View File

@@ -22,4 +22,5 @@ def create_client(config: QdrantConfig) -> QdrantClient:
embedding_function=config.embedding_function, embedding_function=config.embedding_function,
default_limit=config.limit, default_limit=config.limit,
default_score_threshold=config.score_threshold, default_score_threshold=config.score_threshold,
default_batch_size=config.batch_size,
) )

View File

@@ -34,6 +34,30 @@ def client(mock_chromadb_client) -> ChromaDBClient:
return client return client
@pytest.fixture
def client_with_batch_size(mock_chromadb_client) -> ChromaDBClient:
    """Provide a sync-backed ChromaDBClient whose default batch size is 2."""
    return ChromaDBClient(
        client=mock_chromadb_client,
        embedding_function=Mock(),
        default_batch_size=2,
    )
@pytest.fixture
def async_client_with_batch_size(mock_async_chromadb_client) -> ChromaDBClient:
    """Provide an async-backed ChromaDBClient whose default batch size is 2."""
    return ChromaDBClient(
        client=mock_async_chromadb_client,
        embedding_function=Mock(),
        default_batch_size=2,
    )
@pytest.fixture @pytest.fixture
def async_client(mock_async_chromadb_client) -> ChromaDBClient: def async_client(mock_async_chromadb_client) -> ChromaDBClient:
"""Create a ChromaDBClient instance with async client for testing.""" """Create a ChromaDBClient instance with async client for testing."""
@@ -612,3 +636,139 @@ class TestChromaDBClient:
await async_client.areset() await async_client.areset()
mock_async_chromadb_client.reset.assert_called_once_with() mock_async_chromadb_client.reset.assert_called_once_with()
def test_add_documents_with_batch_size(
    self, client_with_batch_size, mock_chromadb_client
) -> None:
    """Five documents with a default batch size of 2 are upserted as 2 + 2 + 1."""
    collection = Mock()
    mock_chromadb_client.get_or_create_collection.return_value = collection

    documents: list[BaseRecord] = [
        {
            "doc_id": f"id{n}",
            "content": f"Document {n}",
            "metadata": {"source": f"test{n}"},
        }
        for n in range(1, 6)
    ]

    client_with_batch_size.add_documents(
        collection_name="test_collection", documents=documents
    )

    assert collection.upsert.call_count == 3

    # Document numbers expected in each successive upsert call.
    expected_batches = [(1, 2), (3, 4), (5,)]
    for upsert_call, numbers in zip(
        collection.upsert.call_args_list, expected_batches
    ):
        sent = upsert_call.kwargs
        assert sent["ids"] == [f"id{n}" for n in numbers]
        assert sent["documents"] == [f"Document {n}" for n in numbers]
        assert sent["metadatas"] == [{"source": f"test{n}"} for n in numbers]
def test_add_documents_with_explicit_batch_size(
    self, client, mock_chromadb_client
) -> None:
    """A batch_size passed to add_documents is used: size 1 gives one upsert per doc."""
    collection = Mock()
    mock_chromadb_client.get_or_create_collection.return_value = collection

    documents: list[BaseRecord] = [
        {"doc_id": f"id{n}", "content": f"Document {n}"} for n in (1, 2, 3)
    ]

    client.add_documents(
        collection_name="test_collection", documents=documents, batch_size=1
    )

    assert collection.upsert.call_count == 3
    for position, upsert_call in enumerate(collection.upsert.call_args_list, start=1):
        sent_ids = upsert_call.kwargs["ids"]
        assert len(sent_ids) == 1
        assert sent_ids == [f"id{position}"]
@pytest.mark.asyncio
async def test_aadd_documents_with_batch_size(
    self, async_client_with_batch_size, mock_async_chromadb_client
) -> None:
    """Three documents with a default batch size of 2 are upserted as 2 + 1."""
    collection = AsyncMock()
    mock_async_chromadb_client.get_or_create_collection = AsyncMock(
        return_value=collection
    )

    documents: list[BaseRecord] = [
        {
            "doc_id": f"id{n}",
            "content": f"Document {n}",
            "metadata": {"source": f"test{n}"},
        }
        for n in (1, 2, 3)
    ]

    await async_client_with_batch_size.aadd_documents(
        collection_name="test_collection", documents=documents
    )

    assert collection.upsert.call_count == 2

    # Document numbers expected in each successive upsert call.
    expected_batches = [(1, 2), (3,)]
    for upsert_call, numbers in zip(
        collection.upsert.call_args_list, expected_batches
    ):
        sent = upsert_call.kwargs
        assert sent["ids"] == [f"id{n}" for n in numbers]
        assert sent["documents"] == [f"Document {n}" for n in numbers]
@pytest.mark.asyncio
async def test_aadd_documents_with_explicit_batch_size(
    self, async_client, mock_async_chromadb_client
) -> None:
    """A batch_size of 3 passed to aadd_documents splits four documents as 3 + 1."""
    collection = AsyncMock()
    mock_async_chromadb_client.get_or_create_collection = AsyncMock(
        return_value=collection
    )

    documents: list[BaseRecord] = [
        {"doc_id": f"id{n}", "content": f"Document {n}"} for n in range(1, 5)
    ]

    await async_client.aadd_documents(
        collection_name="test_collection", documents=documents, batch_size=3
    )

    assert collection.upsert.call_count == 2

    full_batch, remainder = collection.upsert.call_args_list
    assert len(full_batch.kwargs["ids"]) == 3
    assert len(remainder.kwargs["ids"]) == 1
def test_client_default_batch_size_initialization(self) -> None:
    """default_batch_size is 100 unless overridden at construction time."""
    shared_kwargs = {"client": Mock(), "embedding_function": Mock()}

    implicit = ChromaDBClient(**shared_kwargs)
    assert implicit.default_batch_size == 100

    overridden = ChromaDBClient(**shared_kwargs, default_batch_size=50)
    assert overridden.default_batch_size == 50

View File

@@ -1,11 +1,15 @@
"""Tests for ChromaDB utility functions.""" """Tests for ChromaDB utility functions."""
from crewai.rag.chromadb.types import PreparedDocuments
from crewai.rag.chromadb.utils import ( from crewai.rag.chromadb.utils import (
MAX_COLLECTION_LENGTH, MAX_COLLECTION_LENGTH,
MIN_COLLECTION_LENGTH, MIN_COLLECTION_LENGTH,
_create_batch_slice,
_is_ipv4_pattern, _is_ipv4_pattern,
_prepare_documents_for_chromadb,
_sanitize_collection_name, _sanitize_collection_name,
) )
from crewai.rag.types import BaseRecord
class TestChromaDBUtils: class TestChromaDBUtils:
@@ -93,3 +97,206 @@ class TestChromaDBUtils:
assert len(sanitized) >= MIN_COLLECTION_LENGTH assert len(sanitized) >= MIN_COLLECTION_LENGTH
assert sanitized[0].isalnum() assert sanitized[0].isalnum()
assert sanitized[-1].isalnum() assert sanitized[-1].isalnum()
class TestPrepareDocumentsForChromaDB:
    """Behavioral checks for _prepare_documents_for_chromadb."""

    def test_prepare_documents_with_doc_ids(self) -> None:
        """Caller-supplied doc_ids, contents and metadata are passed through."""
        first: BaseRecord = {
            "doc_id": "id1",
            "content": "First document",
            "metadata": {"source": "test1"},
        }
        second: BaseRecord = {
            "doc_id": "id2",
            "content": "Second document",
            "metadata": {"source": "test2"},
        }

        prepared = _prepare_documents_for_chromadb([first, second])

        assert prepared.ids == ["id1", "id2"]
        assert prepared.texts == ["First document", "Second document"]
        assert prepared.metadatas == [{"source": "test1"}, {"source": "test2"}]

    def test_prepare_documents_generate_ids(self) -> None:
        """Documents without doc_id receive a generated 64-character id."""
        with_metadata: BaseRecord = {
            "content": "Test content",
            "metadata": {"key": "value"},
        }
        without_metadata: BaseRecord = {"content": "Another test"}

        prepared = _prepare_documents_for_chromadb([with_metadata, without_metadata])

        assert len(prepared.ids) == 2
        assert all(len(generated) == 64 for generated in prepared.ids)
        assert prepared.texts == ["Test content", "Another test"]
        assert prepared.metadatas == [{"key": "value"}, {}]

    def test_prepare_documents_with_list_metadata(self) -> None:
        """List metadata collapses to its first item, or {} when the list is empty."""
        populated: BaseRecord = {
            "content": "Test",
            "metadata": [{"first": "item"}, {"second": "item"}],
        }
        empty: BaseRecord = {"content": "Test2", "metadata": []}

        prepared = _prepare_documents_for_chromadb([populated, empty])

        assert prepared.metadatas == [{"first": "item"}, {}]

    def test_prepare_documents_no_metadata(self) -> None:
        """Absent and None metadata both become empty dicts."""
        missing_key: BaseRecord = {"content": "Document 1"}
        explicit_none: BaseRecord = {"content": "Document 2", "metadata": None}

        prepared = _prepare_documents_for_chromadb([missing_key, explicit_none])

        assert prepared.metadatas == [{}, {}]

    def test_prepare_documents_hash_consistency(self) -> None:
        """Equal content and metadata always generate the same id."""

        def build_documents() -> list[BaseRecord]:
            # Fresh objects on every call, so equality is by value, not identity.
            return [{"content": "Same content", "metadata": {"key": "value"}}]

        first_run = _prepare_documents_for_chromadb(build_documents())
        second_run = _prepare_documents_for_chromadb(build_documents())

        assert first_run.ids == second_run.ids
class TestCreateBatchSlice:
    """Behavioral checks for _create_batch_slice."""

    def test_create_batch_slice_normal(self) -> None:
        """A window in the middle of the data returns exactly those items."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3", "id4", "id5"],
            texts=["doc1", "doc2", "doc3", "doc4", "doc5"],
            metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}, {"e": 5}],
        )

        ids, texts, metadatas = _create_batch_slice(
            source, start_index=1, batch_size=3
        )

        assert ids == ["id2", "id3", "id4"]
        assert texts == ["doc2", "doc3", "doc4"]
        assert metadatas == [{"b": 2}, {"c": 3}, {"d": 4}]

    def test_create_batch_slice_at_end(self) -> None:
        """A batch that would overrun the data is truncated at the last item."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3"],
            texts=["doc1", "doc2", "doc3"],
            metadatas=[{"a": 1}, {"b": 2}, {"c": 3}],
        )

        ids, texts, metadatas = _create_batch_slice(
            source, start_index=2, batch_size=5
        )

        assert ids == ["id3"]
        assert texts == ["doc3"]
        assert metadatas == [{"c": 3}]

    def test_create_batch_slice_empty_batch(self) -> None:
        """A start index past the end of the data yields three empty lists."""
        source = PreparedDocuments(
            ids=["id1", "id2"],
            texts=["doc1", "doc2"],
            metadatas=[{"a": 1}, {"b": 2}],
        )

        ids, texts, metadatas = _create_batch_slice(
            source, start_index=5, batch_size=3
        )

        assert ids == []
        assert texts == []
        assert metadatas == []

    def test_create_batch_slice_no_metadatas(self) -> None:
        """An empty metadatas list on the input is reported as None."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3"],
            texts=["doc1", "doc2", "doc3"],
            metadatas=[],
        )

        ids, texts, metadatas = _create_batch_slice(
            source, start_index=0, batch_size=2
        )

        assert ids == ["id1", "id2"]
        assert texts == ["doc1", "doc2"]
        assert metadatas is None

    def test_create_batch_slice_all_empty_metadatas(self) -> None:
        """A batch whose metadata dicts are all empty reports None."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3"],
            texts=["doc1", "doc2", "doc3"],
            metadatas=[{}, {}, {}],
        )

        ids, texts, metadatas = _create_batch_slice(
            source, start_index=0, batch_size=3
        )

        assert ids == ["id1", "id2", "id3"]
        assert texts == ["doc1", "doc2", "doc3"]
        assert metadatas is None

    def test_create_batch_slice_some_empty_metadatas(self) -> None:
        """One non-empty metadata dict is enough to keep the whole list."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3"],
            texts=["doc1", "doc2", "doc3"],
            metadatas=[{"a": 1}, {}, {"c": 3}],
        )

        ids, texts, metadatas = _create_batch_slice(
            source, start_index=0, batch_size=3
        )

        assert ids == ["id1", "id2", "id3"]
        assert texts == ["doc1", "doc2", "doc3"]
        assert metadatas == [{"a": 1}, {}, {"c": 3}]

    def test_create_batch_slice_zero_start_index(self) -> None:
        """Starting at index 0 returns the leading batch_size items."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3", "id4"],
            texts=["doc1", "doc2", "doc3", "doc4"],
            metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}],
        )

        ids, texts, metadatas = _create_batch_slice(
            source, start_index=0, batch_size=2
        )

        assert ids == ["id1", "id2"]
        assert texts == ["doc1", "doc2"]
        assert metadatas == [{"a": 1}, {"b": 2}]

    def test_create_batch_slice_single_item(self) -> None:
        """A batch size of 1 returns only the item at start_index."""
        source = PreparedDocuments(
            ids=["id1", "id2", "id3"],
            texts=["doc1", "doc2", "doc3"],
            metadatas=[{"a": 1}, {"b": 2}, {"c": 3}],
        )

        ids, texts, metadatas = _create_batch_slice(
            source, start_index=1, batch_size=1
        )

        assert ids == ["id2"]
        assert texts == ["doc2"]
        assert metadatas == [{"b": 2}]