From 1dbe8aab52b9acbf1647bbc0ecf91e7bd318e5f0 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Wed, 24 Sep 2025 00:05:43 -0400 Subject: [PATCH] fix: add batch_size support to prevent embedder token limit errors - add batch_size field to BaseRagConfig (default=100) - update chromadb/qdrant clients and factories to use batch_size - extract and filter batch_size from embedder config in KnowledgeStorage - fix large CSV files exceeding embedder token limits (#3574) - remove unneeded conditional for type Co-authored-by: Vini Brasil --- .../knowledge/storage/knowledge_storage.py | 50 ++++- src/crewai/memory/storage/rag_storage.py | 46 +++- src/crewai/rag/chromadb/client.py | 44 ++-- src/crewai/rag/chromadb/factory.py | 1 + src/crewai/rag/chromadb/utils.py | 37 +++- src/crewai/rag/config/base.py | 1 + src/crewai/rag/core/base_client.py | 6 +- src/crewai/rag/embeddings/factory.py | 2 + src/crewai/rag/qdrant/client.py | 59 ++--- src/crewai/rag/qdrant/factory.py | 1 + tests/rag/chromadb/test_client.py | 160 ++++++++++++++ tests/rag/chromadb/test_utils.py | 207 ++++++++++++++++++ 12 files changed, 558 insertions(+), 56 deletions(-) diff --git a/src/crewai/knowledge/storage/knowledge_storage.py b/src/crewai/knowledge/storage/knowledge_storage.py index 3eb70946f..a526ec98b 100644 --- a/src/crewai/knowledge/storage/knowledge_storage.py +++ b/src/crewai/knowledge/storage/knowledge_storage.py @@ -8,7 +8,7 @@ from crewai.rag.chromadb.config import ChromaDBConfig from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper from crewai.rag.config.utils import get_rag_client from crewai.rag.core.base_client import BaseClient -from crewai.rag.embeddings.factory import get_embedding_function +from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function from crewai.rag.factory import create_client from crewai.rag.types import BaseRecord, SearchResult from crewai.utilities.logger import Logger @@ -27,6 +27,7 @@ class KnowledgeStorage(BaseKnowledgeStorage): ) 
-> None: self.collection_name = collection_name self._client: BaseClient | None = None + self._embedder_config = embedder # Store embedder config warnings.filterwarnings( "ignore", @@ -35,12 +36,29 @@ class KnowledgeStorage(BaseKnowledgeStorage): ) if embedder: - embedding_function = get_embedding_function(embedder) - config = ChromaDBConfig( - embedding_function=cast( - ChromaEmbeddingFunctionWrapper, embedding_function + # Cast to EmbedderConfig for type checking + embedder_typed = cast(EmbedderConfig, embedder) + embedding_function = get_embedding_function(embedder_typed) + batch_size = None + if isinstance(embedder, dict) and "config" in embedder: + nested_config = embedder["config"] + if isinstance(nested_config, dict): + batch_size = nested_config.get("batch_size") + + # Create config with batch_size if provided + if batch_size is not None: + config = ChromaDBConfig( + embedding_function=cast( + ChromaEmbeddingFunctionWrapper, embedding_function + ), + batch_size=batch_size, + ) + else: + config = ChromaDBConfig( + embedding_function=cast( + ChromaEmbeddingFunctionWrapper, embedding_function + ) ) - ) self._client = create_client(config) def _get_client(self) -> BaseClient: @@ -105,9 +123,23 @@ class KnowledgeStorage(BaseKnowledgeStorage): rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents] - client.add_documents( - collection_name=collection_name, documents=rag_documents - ) + batch_size = None + if self._embedder_config and isinstance(self._embedder_config, dict): + if "config" in self._embedder_config: + nested_config = self._embedder_config["config"] + if isinstance(nested_config, dict): + batch_size = nested_config.get("batch_size") + + if batch_size is not None: + client.add_documents( + collection_name=collection_name, + documents=rag_documents, + batch_size=batch_size, + ) + else: + client.add_documents( + collection_name=collection_name, documents=rag_documents + ) except Exception as e: if "dimension mismatch" in 
str(e).lower(): Logger(verbose=True).log( diff --git a/src/crewai/memory/storage/rag_storage.py b/src/crewai/memory/storage/rag_storage.py index 4f6526c59..a0e08d4dc 100644 --- a/src/crewai/memory/storage/rag_storage.py +++ b/src/crewai/memory/storage/rag_storage.py @@ -66,11 +66,28 @@ class RAGStorage(BaseRAGStorage): f"Error: {e}" ) from e - config = ChromaDBConfig( - embedding_function=cast( - ChromaEmbeddingFunctionWrapper, embedding_function + batch_size = None + if ( + isinstance(self.embedder_config, dict) + and "config" in self.embedder_config + ): + nested_config = self.embedder_config["config"] + if isinstance(nested_config, dict): + batch_size = nested_config.get("batch_size") + + if batch_size is not None: + config = ChromaDBConfig( + embedding_function=cast( + ChromaEmbeddingFunctionWrapper, embedding_function + ), + batch_size=batch_size, + ) + else: + config = ChromaDBConfig( + embedding_function=cast( + ChromaEmbeddingFunctionWrapper, embedding_function + ) ) - ) self._client = create_client(config) def _get_client(self) -> BaseClient: @@ -111,7 +128,26 @@ class RAGStorage(BaseRAGStorage): if metadata: document["metadata"] = metadata - client.add_documents(collection_name=collection_name, documents=[document]) + batch_size = None + if ( + self.embedder_config + and isinstance(self.embedder_config, dict) + and "config" in self.embedder_config + ): + nested_config = self.embedder_config["config"] + if isinstance(nested_config, dict): + batch_size = nested_config.get("batch_size") + + if batch_size is not None: + client.add_documents( + collection_name=collection_name, + documents=[document], + batch_size=batch_size, + ) + else: + client.add_documents( + collection_name=collection_name, documents=[document] + ) except Exception as e: logging.error( f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}" diff --git a/src/crewai/rag/chromadb/client.py b/src/crewai/rag/chromadb/client.py index 0caa4f39c..53a4189dd 100644 --- 
a/src/crewai/rag/chromadb/client.py +++ b/src/crewai/rag/chromadb/client.py @@ -17,6 +17,7 @@ from crewai.rag.chromadb.types import ( ChromaDBCollectionSearchParams, ) from crewai.rag.chromadb.utils import ( + _create_batch_slice, _extract_search_params, _is_async_client, _is_sync_client, @@ -52,6 +53,7 @@ class ChromaDBClient(BaseClient): embedding_function: ChromaEmbeddingFunction, default_limit: int = 5, default_score_threshold: float = 0.6, + default_batch_size: int = 100, ) -> None: """Initialize ChromaDBClient with client and embedding function. @@ -60,11 +62,13 @@ class ChromaDBClient(BaseClient): embedding_function: Embedding function for text to vector conversion. default_limit: Default number of results to return in searches. default_score_threshold: Default minimum score for search results. + default_batch_size: Default batch size for adding documents. """ self.client = client self.embedding_function = embedding_function self.default_limit = default_limit self.default_score_threshold = default_score_threshold + self.default_batch_size = default_batch_size def create_collection( self, **kwargs: Unpack[ChromaDBCollectionCreateParams] @@ -291,6 +295,7 @@ class ChromaDBClient(BaseClient): - content: The text content (required) - doc_id: Optional unique identifier (auto-generated if missing) - metadata: Optional metadata dictionary + batch_size: Optional batch size for processing documents (default: 100) Raises: TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations. 
@@ -305,6 +310,7 @@ class ChromaDBClient(BaseClient): collection_name = kwargs["collection_name"] documents = kwargs["documents"] + batch_size = kwargs.get("batch_size", self.default_batch_size) if not documents: raise ValueError("Documents list cannot be empty") @@ -315,13 +321,17 @@ class ChromaDBClient(BaseClient): ) prepared = _prepare_documents_for_chromadb(documents) - # ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty - metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None - collection.upsert( - ids=prepared.ids, - documents=prepared.texts, - metadatas=metadatas, - ) + + for i in range(0, len(prepared.ids), batch_size): + batch_ids, batch_texts, batch_metadatas = _create_batch_slice( + prepared=prepared, start_index=i, batch_size=batch_size + ) + + collection.upsert( + ids=batch_ids, + documents=batch_texts, + metadatas=batch_metadatas, + ) async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None: """Add documents with their embeddings to a collection asynchronously. @@ -335,6 +345,7 @@ class ChromaDBClient(BaseClient): - content: The text content (required) - doc_id: Optional unique identifier (auto-generated if missing) - metadata: Optional metadata dictionary + batch_size: Optional batch size for processing documents (default: 100) Raises: TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations. 
@@ -349,6 +360,7 @@ class ChromaDBClient(BaseClient): collection_name = kwargs["collection_name"] documents = kwargs["documents"] + batch_size = kwargs.get("batch_size", self.default_batch_size) if not documents: raise ValueError("Documents list cannot be empty") @@ -358,13 +370,17 @@ class ChromaDBClient(BaseClient): embedding_function=self.embedding_function, ) prepared = _prepare_documents_for_chromadb(documents) - # ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty - metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None - await collection.upsert( - ids=prepared.ids, - documents=prepared.texts, - metadatas=metadatas, - ) + + for i in range(0, len(prepared.ids), batch_size): + batch_ids, batch_texts, batch_metadatas = _create_batch_slice( + prepared=prepared, start_index=i, batch_size=batch_size + ) + + await collection.upsert( + ids=batch_ids, + documents=batch_texts, + metadatas=batch_metadatas, + ) def search( self, **kwargs: Unpack[ChromaDBCollectionSearchParams] diff --git a/src/crewai/rag/chromadb/factory.py b/src/crewai/rag/chromadb/factory.py index a02d350ac..7c9532390 100644 --- a/src/crewai/rag/chromadb/factory.py +++ b/src/crewai/rag/chromadb/factory.py @@ -41,4 +41,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient: embedding_function=config.embedding_function, default_limit=config.limit, default_score_threshold=config.score_threshold, + default_batch_size=config.batch_size, ) diff --git a/src/crewai/rag/chromadb/utils.py b/src/crewai/rag/chromadb/utils.py index 3a6a6369c..978725628 100644 --- a/src/crewai/rag/chromadb/utils.py +++ b/src/crewai/rag/chromadb/utils.py @@ -1,6 +1,7 @@ """Utility functions for ChromaDB client implementation.""" import hashlib +import json from collections.abc import Mapping from typing import Literal, TypeGuard, cast @@ -72,7 +73,15 @@ def _prepare_documents_for_chromadb( if "doc_id" in doc: ids.append(doc["doc_id"]) else: - content_hash = 
hashlib.sha256(doc["content"].encode()).hexdigest()[:16] + content_for_hash = doc["content"] + metadata = doc.get("metadata") + if metadata: + metadata_str = json.dumps(metadata, sort_keys=True) + content_for_hash = f"{content_for_hash}|{metadata_str}" + + content_hash = hashlib.blake2b( + content_for_hash.encode(), digest_size=32 + ).hexdigest() ids.append(content_hash) texts.append(doc["content"]) @@ -88,6 +97,32 @@ def _prepare_documents_for_chromadb( return PreparedDocuments(ids, texts, metadatas) +def _create_batch_slice( + prepared: PreparedDocuments, start_index: int, batch_size: int +) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]: + """Create a batch slice from prepared documents. + + Args: + prepared: PreparedDocuments containing ids, texts, and metadatas. + start_index: Starting index for the batch. + batch_size: Size of the batch. + + Returns: + Tuple of (batch_ids, batch_texts, batch_metadatas). + """ + batch_end = min(start_index + batch_size, len(prepared.ids)) + batch_ids = prepared.ids[start_index:batch_end] + batch_texts = prepared.texts[start_index:batch_end] + batch_metadatas = ( + prepared.metadatas[start_index:batch_end] if prepared.metadatas else None + ) + + if batch_metadatas and not any(m for m in batch_metadatas): + batch_metadatas = None + + return batch_ids, batch_texts, batch_metadatas + + def _extract_search_params( kwargs: ChromaDBCollectionSearchParams, ) -> ExtractedSearchParams: diff --git a/src/crewai/rag/config/base.py b/src/crewai/rag/config/base.py index 411c4f7bc..107f644cf 100644 --- a/src/crewai/rag/config/base.py +++ b/src/crewai/rag/config/base.py @@ -16,3 +16,4 @@ class BaseRagConfig: embedding_function: Any | None = field(default=None) limit: int = field(default=5) score_threshold: float = field(default=0.6) + batch_size: int = field(default=100) diff --git a/src/crewai/rag/core/base_client.py b/src/crewai/rag/core/base_client.py index f526d2faa..bd7bd5d08 100644 --- 
a/src/crewai/rag/core/base_client.py +++ b/src/crewai/rag/core/base_client.py @@ -29,7 +29,7 @@ class BaseCollectionParams(TypedDict): ] -class BaseCollectionAddParams(BaseCollectionParams): +class BaseCollectionAddParams(BaseCollectionParams, total=False): """Parameters for adding documents to a collection. Extends BaseCollectionParams with document-specific fields. @@ -37,9 +37,11 @@ class BaseCollectionAddParams(BaseCollectionParams): Attributes: collection_name: The name of the collection to add documents to. documents: List of BaseRecord dictionaries containing document data. + batch_size: Optional batch size for processing documents to avoid token limits. """ - documents: list[BaseRecord] + documents: Required[list[BaseRecord]] + batch_size: int class BaseCollectionSearchParams(BaseCollectionParams, total=False): diff --git a/src/crewai/rag/embeddings/factory.py b/src/crewai/rag/embeddings/factory.py index 3ced72655..cc756f314 100644 --- a/src/crewai/rag/embeddings/factory.py +++ b/src/crewai/rag/embeddings/factory.py @@ -244,4 +244,6 @@ def get_embedding_function( _inject_api_key_from_env(provider, config_dict) + config_dict.pop("batch_size", None) + return EMBEDDING_PROVIDERS[provider](**config_dict) diff --git a/src/crewai/rag/qdrant/client.py b/src/crewai/rag/qdrant/client.py index c82ad9f8e..8e889544a 100644 --- a/src/crewai/rag/qdrant/client.py +++ b/src/crewai/rag/qdrant/client.py @@ -48,6 +48,7 @@ class QdrantClient(BaseClient): embedding_function: EmbeddingFunction | AsyncEmbeddingFunction, default_limit: int = 5, default_score_threshold: float = 0.6, + default_batch_size: int = 100, ) -> None: """Initialize QdrantClient with client and embedding function. @@ -56,11 +57,13 @@ class QdrantClient(BaseClient): embedding_function: Embedding function for text to vector conversion. default_limit: Default number of results to return in searches. default_score_threshold: Default minimum score for search results. 
+ default_batch_size: Default batch size for adding documents. """ self.client = client self.embedding_function = embedding_function self.default_limit = default_limit self.default_score_threshold = default_score_threshold + self.default_batch_size = default_batch_size def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None: """Create a new collection in Qdrant. @@ -234,6 +237,7 @@ class QdrantClient(BaseClient): Keyword Args: collection_name: The name of the collection to add documents to. documents: List of BaseRecord dicts containing document data. + batch_size: Optional batch size for processing documents (default: 100) Raises: ValueError: If collection doesn't exist or documents list is empty. @@ -249,6 +253,7 @@ class QdrantClient(BaseClient): collection_name = kwargs["collection_name"] documents = kwargs["documents"] + batch_size = kwargs.get("batch_size", self.default_batch_size) if not documents: raise ValueError("Documents list cannot be empty") @@ -256,19 +261,20 @@ class QdrantClient(BaseClient): if not self.client.collection_exists(collection_name): raise ValueError(f"Collection '{collection_name}' does not exist") - points = [] - for doc in documents: - if _is_async_embedding_function(self.embedding_function): - raise TypeError( - "Async embedding function cannot be used with sync add_documents. " - "Use aadd_documents instead." - ) - sync_fn = cast(EmbeddingFunction, self.embedding_function) - embedding = sync_fn(doc["content"]) - point = _create_point_from_document(doc, embedding) - points.append(point) - - self.client.upsert(collection_name=collection_name, points=points) + for i in range(0, len(documents), batch_size): + batch_docs = documents[i : min(i + batch_size, len(documents))] + points = [] + for doc in batch_docs: + if _is_async_embedding_function(self.embedding_function): + raise TypeError( + "Async embedding function cannot be used with sync add_documents. " + "Use aadd_documents instead." 
+ ) + sync_fn = cast(EmbeddingFunction, self.embedding_function) + embedding = sync_fn(doc["content"]) + point = _create_point_from_document(doc, embedding) + points.append(point) + self.client.upsert(collection_name=collection_name, points=points) async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None: """Add documents with their embeddings to a collection asynchronously. @@ -276,6 +282,7 @@ class QdrantClient(BaseClient): Keyword Args: collection_name: The name of the collection to add documents to. documents: List of BaseRecord dicts containing document data. + batch_size: Optional batch size for processing documents (default: 100) Raises: ValueError: If collection doesn't exist or documents list is empty. @@ -291,6 +298,7 @@ class QdrantClient(BaseClient): collection_name = kwargs["collection_name"] documents = kwargs["documents"] + batch_size = kwargs.get("batch_size", self.default_batch_size) if not documents: raise ValueError("Documents list cannot be empty") @@ -298,18 +306,19 @@ class QdrantClient(BaseClient): if not await self.client.collection_exists(collection_name): raise ValueError(f"Collection '{collection_name}' does not exist") - points = [] - for doc in documents: - if _is_async_embedding_function(self.embedding_function): - async_fn = cast(AsyncEmbeddingFunction, self.embedding_function) - embedding = await async_fn(doc["content"]) - else: - sync_fn = cast(EmbeddingFunction, self.embedding_function) - embedding = sync_fn(doc["content"]) - point = _create_point_from_document(doc, embedding) - points.append(point) - - await self.client.upsert(collection_name=collection_name, points=points) + for i in range(0, len(documents), batch_size): + batch_docs = documents[i : min(i + batch_size, len(documents))] + points = [] + for doc in batch_docs: + if _is_async_embedding_function(self.embedding_function): + async_fn = cast(AsyncEmbeddingFunction, self.embedding_function) + embedding = await async_fn(doc["content"]) + else: + 
sync_fn = cast(EmbeddingFunction, self.embedding_function) + embedding = sync_fn(doc["content"]) + point = _create_point_from_document(doc, embedding) + points.append(point) + await self.client.upsert(collection_name=collection_name, points=points) def search( self, **kwargs: Unpack[BaseCollectionSearchParams] diff --git a/src/crewai/rag/qdrant/factory.py b/src/crewai/rag/qdrant/factory.py index 512e7a562..d0692dc26 100644 --- a/src/crewai/rag/qdrant/factory.py +++ b/src/crewai/rag/qdrant/factory.py @@ -22,4 +22,5 @@ def create_client(config: QdrantConfig) -> QdrantClient: embedding_function=config.embedding_function, default_limit=config.limit, default_score_threshold=config.score_threshold, + default_batch_size=config.batch_size, ) diff --git a/tests/rag/chromadb/test_client.py b/tests/rag/chromadb/test_client.py index 8fef2ff8d..ab31549e7 100644 --- a/tests/rag/chromadb/test_client.py +++ b/tests/rag/chromadb/test_client.py @@ -34,6 +34,30 @@ def client(mock_chromadb_client) -> ChromaDBClient: return client +@pytest.fixture +def client_with_batch_size(mock_chromadb_client) -> ChromaDBClient: + """Create a ChromaDBClient instance with custom batch size for testing.""" + mock_embedding = Mock() + client = ChromaDBClient( + client=mock_chromadb_client, + embedding_function=mock_embedding, + default_batch_size=2, + ) + return client + + +@pytest.fixture +def async_client_with_batch_size(mock_async_chromadb_client) -> ChromaDBClient: + """Create a ChromaDBClient instance with async client and custom batch size for testing.""" + mock_embedding = Mock() + client = ChromaDBClient( + client=mock_async_chromadb_client, + embedding_function=mock_embedding, + default_batch_size=2, + ) + return client + + @pytest.fixture def async_client(mock_async_chromadb_client) -> ChromaDBClient: """Create a ChromaDBClient instance with async client for testing.""" @@ -612,3 +636,139 @@ class TestChromaDBClient: await async_client.areset() 
mock_async_chromadb_client.reset.assert_called_once_with() + + def test_add_documents_with_batch_size( + self, client_with_batch_size, mock_chromadb_client + ) -> None: + """Test add_documents with batch size splits documents into batches.""" + mock_collection = Mock() + mock_chromadb_client.get_or_create_collection.return_value = mock_collection + + documents: list[BaseRecord] = [ + {"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}}, + {"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}}, + {"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}}, + {"doc_id": "id4", "content": "Document 4", "metadata": {"source": "test4"}}, + {"doc_id": "id5", "content": "Document 5", "metadata": {"source": "test5"}}, + ] + + client_with_batch_size.add_documents( + collection_name="test_collection", documents=documents + ) + + assert mock_collection.upsert.call_count == 3 + + first_call = mock_collection.upsert.call_args_list[0] + assert first_call.kwargs["ids"] == ["id1", "id2"] + assert first_call.kwargs["documents"] == ["Document 1", "Document 2"] + assert first_call.kwargs["metadatas"] == [ + {"source": "test1"}, + {"source": "test2"}, + ] + + second_call = mock_collection.upsert.call_args_list[1] + assert second_call.kwargs["ids"] == ["id3", "id4"] + assert second_call.kwargs["documents"] == ["Document 3", "Document 4"] + assert second_call.kwargs["metadatas"] == [ + {"source": "test3"}, + {"source": "test4"}, + ] + + third_call = mock_collection.upsert.call_args_list[2] + assert third_call.kwargs["ids"] == ["id5"] + assert third_call.kwargs["documents"] == ["Document 5"] + assert third_call.kwargs["metadatas"] == [{"source": "test5"}] + + def test_add_documents_with_explicit_batch_size( + self, client, mock_chromadb_client + ) -> None: + """Test add_documents with explicitly provided batch size.""" + mock_collection = Mock() + mock_chromadb_client.get_or_create_collection.return_value = mock_collection + + 
documents: list[BaseRecord] = [ + {"doc_id": "id1", "content": "Document 1"}, + {"doc_id": "id2", "content": "Document 2"}, + {"doc_id": "id3", "content": "Document 3"}, + ] + + client.add_documents( + collection_name="test_collection", documents=documents, batch_size=1 + ) + + assert mock_collection.upsert.call_count == 3 + for i, call in enumerate(mock_collection.upsert.call_args_list): + assert len(call.kwargs["ids"]) == 1 + assert call.kwargs["ids"] == [f"id{i + 1}"] + + @pytest.mark.asyncio + async def test_aadd_documents_with_batch_size( + self, async_client_with_batch_size, mock_async_chromadb_client + ) -> None: + """Test aadd_documents with batch size splits documents into batches.""" + mock_collection = AsyncMock() + mock_async_chromadb_client.get_or_create_collection = AsyncMock( + return_value=mock_collection + ) + + documents: list[BaseRecord] = [ + {"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}}, + {"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}}, + {"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}}, + ] + + await async_client_with_batch_size.aadd_documents( + collection_name="test_collection", documents=documents + ) + + assert mock_collection.upsert.call_count == 2 + + first_call = mock_collection.upsert.call_args_list[0] + assert first_call.kwargs["ids"] == ["id1", "id2"] + assert first_call.kwargs["documents"] == ["Document 1", "Document 2"] + + second_call = mock_collection.upsert.call_args_list[1] + assert second_call.kwargs["ids"] == ["id3"] + assert second_call.kwargs["documents"] == ["Document 3"] + + @pytest.mark.asyncio + async def test_aadd_documents_with_explicit_batch_size( + self, async_client, mock_async_chromadb_client + ) -> None: + """Test aadd_documents with explicitly provided batch size.""" + mock_collection = AsyncMock() + mock_async_chromadb_client.get_or_create_collection = AsyncMock( + return_value=mock_collection + ) + + documents: 
list[BaseRecord] = [ + {"doc_id": "id1", "content": "Document 1"}, + {"doc_id": "id2", "content": "Document 2"}, + {"doc_id": "id3", "content": "Document 3"}, + {"doc_id": "id4", "content": "Document 4"}, + ] + + await async_client.aadd_documents( + collection_name="test_collection", documents=documents, batch_size=3 + ) + + assert mock_collection.upsert.call_count == 2 + + first_call = mock_collection.upsert.call_args_list[0] + assert len(first_call.kwargs["ids"]) == 3 + + second_call = mock_collection.upsert.call_args_list[1] + assert len(second_call.kwargs["ids"]) == 1 + + def test_client_default_batch_size_initialization(self) -> None: + """Test that client initializes with correct default batch size.""" + mock_client = Mock() + mock_embedding = Mock() + + client = ChromaDBClient(client=mock_client, embedding_function=mock_embedding) + assert client.default_batch_size == 100 + + custom_client = ChromaDBClient( + client=mock_client, embedding_function=mock_embedding, default_batch_size=50 + ) + assert custom_client.default_batch_size == 50 diff --git a/tests/rag/chromadb/test_utils.py b/tests/rag/chromadb/test_utils.py index ac7a8f5a9..9bede2ee9 100644 --- a/tests/rag/chromadb/test_utils.py +++ b/tests/rag/chromadb/test_utils.py @@ -1,11 +1,15 @@ """Tests for ChromaDB utility functions.""" +from crewai.rag.chromadb.types import PreparedDocuments from crewai.rag.chromadb.utils import ( MAX_COLLECTION_LENGTH, MIN_COLLECTION_LENGTH, + _create_batch_slice, _is_ipv4_pattern, + _prepare_documents_for_chromadb, _sanitize_collection_name, ) +from crewai.rag.types import BaseRecord class TestChromaDBUtils: @@ -93,3 +97,206 @@ class TestChromaDBUtils: assert len(sanitized) >= MIN_COLLECTION_LENGTH assert sanitized[0].isalnum() assert sanitized[-1].isalnum() + + +class TestPrepareDocumentsForChromaDB: + """Test suite for _prepare_documents_for_chromadb function.""" + + def test_prepare_documents_with_doc_ids(self) -> None: + """Test preparing documents that already have 
doc_ids.""" + documents: list[BaseRecord] = [ + { + "doc_id": "id1", + "content": "First document", + "metadata": {"source": "test1"}, + }, + { + "doc_id": "id2", + "content": "Second document", + "metadata": {"source": "test2"}, + }, + ] + + result = _prepare_documents_for_chromadb(documents) + + assert result.ids == ["id1", "id2"] + assert result.texts == ["First document", "Second document"] + assert result.metadatas == [{"source": "test1"}, {"source": "test2"}] + + def test_prepare_documents_generate_ids(self) -> None: + """Test preparing documents without doc_ids (should generate hashes).""" + documents: list[BaseRecord] = [ + {"content": "Test content", "metadata": {"key": "value"}}, + {"content": "Another test"}, + ] + + result = _prepare_documents_for_chromadb(documents) + + assert len(result.ids) == 2 + assert all(len(doc_id) == 64 for doc_id in result.ids) + assert result.texts == ["Test content", "Another test"] + assert result.metadatas == [{"key": "value"}, {}] + + def test_prepare_documents_with_list_metadata(self) -> None: + """Test preparing documents with list metadata (should take first item).""" + documents: list[BaseRecord] = [ + {"content": "Test", "metadata": [{"first": "item"}, {"second": "item"}]}, + {"content": "Test2", "metadata": []}, + ] + + result = _prepare_documents_for_chromadb(documents) + + assert result.metadatas == [{"first": "item"}, {}] + + def test_prepare_documents_no_metadata(self) -> None: + """Test preparing documents without metadata.""" + documents: list[BaseRecord] = [ + {"content": "Document 1"}, + {"content": "Document 2", "metadata": None}, + ] + + result = _prepare_documents_for_chromadb(documents) + + assert result.metadatas == [{}, {}] + + def test_prepare_documents_hash_consistency(self) -> None: + """Test that identical content produces identical hashes.""" + documents1: list[BaseRecord] = [ + {"content": "Same content", "metadata": {"key": "value"}} + ] + documents2: list[BaseRecord] = [ + {"content": "Same 
content", "metadata": {"key": "value"}} + ] + + result1 = _prepare_documents_for_chromadb(documents1) + result2 = _prepare_documents_for_chromadb(documents2) + + assert result1.ids == result2.ids + + +class TestCreateBatchSlice: + """Test suite for _create_batch_slice function.""" + + def test_create_batch_slice_normal(self) -> None: + """Test creating a normal batch slice.""" + prepared = PreparedDocuments( + ids=["id1", "id2", "id3", "id4", "id5"], + texts=["doc1", "doc2", "doc3", "doc4", "doc5"], + metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}, {"e": 5}], + ) + + batch_ids, batch_texts, batch_metadatas = _create_batch_slice( + prepared, start_index=1, batch_size=3 + ) + + assert batch_ids == ["id2", "id3", "id4"] + assert batch_texts == ["doc2", "doc3", "doc4"] + assert batch_metadatas == [{"b": 2}, {"c": 3}, {"d": 4}] + + def test_create_batch_slice_at_end(self) -> None: + """Test creating a batch slice that goes beyond the end.""" + prepared = PreparedDocuments( + ids=["id1", "id2", "id3"], + texts=["doc1", "doc2", "doc3"], + metadatas=[{"a": 1}, {"b": 2}, {"c": 3}], + ) + + batch_ids, batch_texts, batch_metadatas = _create_batch_slice( + prepared, start_index=2, batch_size=5 + ) + + assert batch_ids == ["id3"] + assert batch_texts == ["doc3"] + assert batch_metadatas == [{"c": 3}] + + def test_create_batch_slice_empty_batch(self) -> None: + """Test creating a batch slice starting beyond the data.""" + prepared = PreparedDocuments( + ids=["id1", "id2"], texts=["doc1", "doc2"], metadatas=[{"a": 1}, {"b": 2}] + ) + + batch_ids, batch_texts, batch_metadatas = _create_batch_slice( + prepared, start_index=5, batch_size=3 + ) + + assert batch_ids == [] + assert batch_texts == [] + assert batch_metadatas == [] + + def test_create_batch_slice_no_metadatas(self) -> None: + """Test creating a batch slice with no metadatas.""" + prepared = PreparedDocuments( + ids=["id1", "id2", "id3"], texts=["doc1", "doc2", "doc3"], metadatas=[] + ) + + batch_ids, batch_texts, 
batch_metadatas = _create_batch_slice( + prepared, start_index=0, batch_size=2 + ) + + assert batch_ids == ["id1", "id2"] + assert batch_texts == ["doc1", "doc2"] + assert batch_metadatas is None + + def test_create_batch_slice_all_empty_metadatas(self) -> None: + """Test creating a batch slice where all metadatas are empty.""" + prepared = PreparedDocuments( + ids=["id1", "id2", "id3"], + texts=["doc1", "doc2", "doc3"], + metadatas=[{}, {}, {}], + ) + + batch_ids, batch_texts, batch_metadatas = _create_batch_slice( + prepared, start_index=0, batch_size=3 + ) + + assert batch_ids == ["id1", "id2", "id3"] + assert batch_texts == ["doc1", "doc2", "doc3"] + assert batch_metadatas is None + + def test_create_batch_slice_some_empty_metadatas(self) -> None: + """Test creating a batch slice where some metadatas are empty.""" + prepared = PreparedDocuments( + ids=["id1", "id2", "id3"], + texts=["doc1", "doc2", "doc3"], + metadatas=[{"a": 1}, {}, {"c": 3}], + ) + + batch_ids, batch_texts, batch_metadatas = _create_batch_slice( + prepared, start_index=0, batch_size=3 + ) + + assert batch_ids == ["id1", "id2", "id3"] + assert batch_texts == ["doc1", "doc2", "doc3"] + assert batch_metadatas == [{"a": 1}, {}, {"c": 3}] + + def test_create_batch_slice_zero_start_index(self) -> None: + """Test creating a batch slice starting from index 0.""" + prepared = PreparedDocuments( + ids=["id1", "id2", "id3", "id4"], + texts=["doc1", "doc2", "doc3", "doc4"], + metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}], + ) + + batch_ids, batch_texts, batch_metadatas = _create_batch_slice( + prepared, start_index=0, batch_size=2 + ) + + assert batch_ids == ["id1", "id2"] + assert batch_texts == ["doc1", "doc2"] + assert batch_metadatas == [{"a": 1}, {"b": 2}] + + def test_create_batch_slice_single_item(self) -> None: + """Test creating a batch slice with batch size 1.""" + prepared = PreparedDocuments( + ids=["id1", "id2", "id3"], + texts=["doc1", "doc2", "doc3"], + metadatas=[{"a": 1}, {"b": 2}, 
{"c": 3}], + ) + + batch_ids, batch_texts, batch_metadatas = _create_batch_slice( + prepared, start_index=1, batch_size=1 + ) + + assert batch_ids == ["id2"] + assert batch_texts == ["doc2"] + assert batch_metadatas == [{"b": 2}]