mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 15:52:34 +00:00
fix: add batch_size support to prevent embedder token limit errors
- add batch_size field to baseragconfig (default=100) - update chromadb/qdrant clients and factories to use batch_size - extract and filter batch_size from embedder config in knowledgestorage - fix large csv files exceeding embedder token limits (#3574) - remove unneeded conditional for type Co-authored-by: Vini Brasil <vini@hey.com>
This commit is contained in:
@@ -17,6 +17,7 @@ from crewai.rag.chromadb.types import (
|
||||
ChromaDBCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.chromadb.utils import (
|
||||
_create_batch_slice,
|
||||
_extract_search_params,
|
||||
_is_async_client,
|
||||
_is_sync_client,
|
||||
@@ -52,6 +53,7 @@ class ChromaDBClient(BaseClient):
|
||||
embedding_function: ChromaEmbeddingFunction,
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
default_batch_size: int = 100,
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
@@ -60,11 +62,13 @@ class ChromaDBClient(BaseClient):
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
default_batch_size: Default batch size for adding documents.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
self.default_batch_size = default_batch_size
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
@@ -291,6 +295,7 @@ class ChromaDBClient(BaseClient):
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
@@ -305,6 +310,7 @@ class ChromaDBClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
@@ -315,13 +321,17 @@ class ChromaDBClient(BaseClient):
|
||||
)
|
||||
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
|
||||
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
|
||||
collection.upsert(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas,
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
@@ -335,6 +345,7 @@ class ChromaDBClient(BaseClient):
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
@@ -349,6 +360,7 @@ class ChromaDBClient(BaseClient):
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
@@ -358,13 +370,17 @@ class ChromaDBClient(BaseClient):
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
# ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
|
||||
metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
|
||||
await collection.upsert(
|
||||
ids=prepared.ids,
|
||||
documents=prepared.texts,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
await collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas,
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
|
||||
@@ -41,4 +41,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
embedding_function=config.embedding_function,
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
default_batch_size=config.batch_size,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Utility functions for ChromaDB client implementation."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Literal, TypeGuard, cast
|
||||
|
||||
@@ -72,7 +73,15 @@ def _prepare_documents_for_chromadb(
|
||||
if "doc_id" in doc:
|
||||
ids.append(doc["doc_id"])
|
||||
else:
|
||||
content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16]
|
||||
content_for_hash = doc["content"]
|
||||
metadata = doc.get("metadata")
|
||||
if metadata:
|
||||
metadata_str = json.dumps(metadata, sort_keys=True)
|
||||
content_for_hash = f"{content_for_hash}|{metadata_str}"
|
||||
|
||||
content_hash = hashlib.blake2b(
|
||||
content_for_hash.encode(), digest_size=32
|
||||
).hexdigest()
|
||||
ids.append(content_hash)
|
||||
|
||||
texts.append(doc["content"])
|
||||
@@ -88,6 +97,32 @@ def _prepare_documents_for_chromadb(
|
||||
return PreparedDocuments(ids, texts, metadatas)
|
||||
|
||||
|
||||
def _create_batch_slice(
|
||||
prepared: PreparedDocuments, start_index: int, batch_size: int
|
||||
) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]:
|
||||
"""Create a batch slice from prepared documents.
|
||||
|
||||
Args:
|
||||
prepared: PreparedDocuments containing ids, texts, and metadatas.
|
||||
start_index: Starting index for the batch.
|
||||
batch_size: Size of the batch.
|
||||
|
||||
Returns:
|
||||
Tuple of (batch_ids, batch_texts, batch_metadatas).
|
||||
"""
|
||||
batch_end = min(start_index + batch_size, len(prepared.ids))
|
||||
batch_ids = prepared.ids[start_index:batch_end]
|
||||
batch_texts = prepared.texts[start_index:batch_end]
|
||||
batch_metadatas = (
|
||||
prepared.metadatas[start_index:batch_end] if prepared.metadatas else None
|
||||
)
|
||||
|
||||
if batch_metadatas and not any(m for m in batch_metadatas):
|
||||
batch_metadatas = None
|
||||
|
||||
return batch_ids, batch_texts, batch_metadatas
|
||||
|
||||
|
||||
def _extract_search_params(
|
||||
kwargs: ChromaDBCollectionSearchParams,
|
||||
) -> ExtractedSearchParams:
|
||||
|
||||
Reference in New Issue
Block a user