Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-10 08:38:30 +00:00)
fix: add batch_size support to prevent embedder token limit errors
- Add a batch_size field to BaseRagConfig (default=100)
- Update the ChromaDB/Qdrant clients and factories to use batch_size
- Extract and filter batch_size from the embedder config in KnowledgeStorage
- Fix large CSV files exceeding embedder token limits (#3574)
- Remove an unneeded conditional for type checking

Co-authored-by: Vini Brasil <vini@hey.com>
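For illustration (not part of the commit): the storages below read an optional batch_size from the nested "config" dict of the embedder configuration, forward it to the RAG client, and strip it before the embedding provider is built. A minimal sketch of such a configuration; the provider and model values are placeholders, only the nesting matters.

# Illustrative only: provider/model values are placeholders, not part of this commit.
# The nested "config" dict is where KnowledgeStorage and RAGStorage now look for
# an optional "batch_size" key.
embedder = {
    "provider": "openai",
    "config": {
        "model": "text-embedding-3-small",
        "batch_size": 50,  # documents are upserted in chunks of 50
    },
}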
@@ -8,7 +8,7 @@ from crewai.rag.chromadb.config import ChromaDBConfig
 from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
 from crewai.rag.config.utils import get_rag_client
 from crewai.rag.core.base_client import BaseClient
-from crewai.rag.embeddings.factory import get_embedding_function
+from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
 from crewai.rag.factory import create_client
 from crewai.rag.types import BaseRecord, SearchResult
 from crewai.utilities.logger import Logger
@@ -27,6 +27,7 @@ class KnowledgeStorage(BaseKnowledgeStorage):
     ) -> None:
         self.collection_name = collection_name
         self._client: BaseClient | None = None
+        self._embedder_config = embedder  # Store embedder config

         warnings.filterwarnings(
             "ignore",
@@ -35,7 +36,24 @@ class KnowledgeStorage(BaseKnowledgeStorage):
         )

         if embedder:
-            embedding_function = get_embedding_function(embedder)
+            # Cast to EmbedderConfig for type checking
+            embedder_typed = cast(EmbedderConfig, embedder)
+            embedding_function = get_embedding_function(embedder_typed)
+            batch_size = None
+            if isinstance(embedder, dict) and "config" in embedder:
+                nested_config = embedder["config"]
+                if isinstance(nested_config, dict):
+                    batch_size = nested_config.get("batch_size")
+
+            # Create config with batch_size if provided
+            if batch_size is not None:
+                config = ChromaDBConfig(
+                    embedding_function=cast(
+                        ChromaEmbeddingFunctionWrapper, embedding_function
+                    ),
+                    batch_size=batch_size,
+                )
+            else:
                 config = ChromaDBConfig(
                     embedding_function=cast(
                         ChromaEmbeddingFunctionWrapper, embedding_function
@@ -105,6 +123,20 @@ class KnowledgeStorage(BaseKnowledgeStorage):

         rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]

+        batch_size = None
+        if self._embedder_config and isinstance(self._embedder_config, dict):
+            if "config" in self._embedder_config:
+                nested_config = self._embedder_config["config"]
+                if isinstance(nested_config, dict):
+                    batch_size = nested_config.get("batch_size")
+
+        if batch_size is not None:
+            client.add_documents(
+                collection_name=collection_name,
+                documents=rag_documents,
+                batch_size=batch_size,
+            )
+        else:
             client.add_documents(
                 collection_name=collection_name, documents=rag_documents
             )
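The same nested lookup appears three times in this commit (KnowledgeStorage.__init__, KnowledgeStorage.save, and RAGStorage below). A minimal standalone sketch of that extraction logic; the helper name is hypothetical and not present in the diff.

# Hypothetical helper, not in the commit; it restates the repeated lookup above.
def _extract_batch_size(embedder) -> int | None:
    """Return embedder["config"]["batch_size"] when the nested dicts exist, else None."""
    if isinstance(embedder, dict) and "config" in embedder:
        nested = embedder["config"]
        if isinstance(nested, dict):
            return nested.get("batch_size")
    return None

# _extract_batch_size({"provider": "x", "config": {"batch_size": 25}}) -> 25
# _extract_batch_size({"provider": "x", "config": {}}) -> None
# _extract_batch_size(None) -> None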
@@ -66,6 +66,23 @@ class RAGStorage(BaseRAGStorage):
                 f"Error: {e}"
             ) from e

+        batch_size = None
+        if (
+            isinstance(self.embedder_config, dict)
+            and "config" in self.embedder_config
+        ):
+            nested_config = self.embedder_config["config"]
+            if isinstance(nested_config, dict):
+                batch_size = nested_config.get("batch_size")
+
+        if batch_size is not None:
+            config = ChromaDBConfig(
+                embedding_function=cast(
+                    ChromaEmbeddingFunctionWrapper, embedding_function
+                ),
+                batch_size=batch_size,
+            )
+        else:
             config = ChromaDBConfig(
                 embedding_function=cast(
                     ChromaEmbeddingFunctionWrapper, embedding_function
@@ -111,7 +128,26 @@ class RAGStorage(BaseRAGStorage):
             if metadata:
                 document["metadata"] = metadata

-            client.add_documents(collection_name=collection_name, documents=[document])
+            batch_size = None
+            if (
+                self.embedder_config
+                and isinstance(self.embedder_config, dict)
+                and "config" in self.embedder_config
+            ):
+                nested_config = self.embedder_config["config"]
+                if isinstance(nested_config, dict):
+                    batch_size = nested_config.get("batch_size")
+
+            if batch_size is not None:
+                client.add_documents(
+                    collection_name=collection_name,
+                    documents=[document],
+                    batch_size=batch_size,
+                )
+            else:
+                client.add_documents(
+                    collection_name=collection_name, documents=[document]
+                )
         except Exception as e:
             logging.error(
                 f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
@@ -17,6 +17,7 @@ from crewai.rag.chromadb.types import (
     ChromaDBCollectionSearchParams,
 )
 from crewai.rag.chromadb.utils import (
+    _create_batch_slice,
     _extract_search_params,
     _is_async_client,
     _is_sync_client,
@@ -52,6 +53,7 @@ class ChromaDBClient(BaseClient):
         embedding_function: ChromaEmbeddingFunction,
         default_limit: int = 5,
         default_score_threshold: float = 0.6,
+        default_batch_size: int = 100,
     ) -> None:
         """Initialize ChromaDBClient with client and embedding function.

@@ -60,11 +62,13 @@ class ChromaDBClient(BaseClient):
             embedding_function: Embedding function for text to vector conversion.
             default_limit: Default number of results to return in searches.
             default_score_threshold: Default minimum score for search results.
+            default_batch_size: Default batch size for adding documents.
         """
         self.client = client
         self.embedding_function = embedding_function
         self.default_limit = default_limit
         self.default_score_threshold = default_score_threshold
+        self.default_batch_size = default_batch_size

     def create_collection(
         self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
@@ -291,6 +295,7 @@ class ChromaDBClient(BaseClient):
                 - content: The text content (required)
                 - doc_id: Optional unique identifier (auto-generated if missing)
                 - metadata: Optional metadata dictionary
+            batch_size: Optional batch size for processing documents (default: 100)

         Raises:
             TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
@@ -305,6 +310,7 @@ class ChromaDBClient(BaseClient):

         collection_name = kwargs["collection_name"]
         documents = kwargs["documents"]
+        batch_size = kwargs.get("batch_size", self.default_batch_size)

         if not documents:
             raise ValueError("Documents list cannot be empty")
@@ -315,12 +321,16 @@ class ChromaDBClient(BaseClient):
         )

         prepared = _prepare_documents_for_chromadb(documents)
-        # ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
-        metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
+        for i in range(0, len(prepared.ids), batch_size):
+            batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
+                prepared=prepared, start_index=i, batch_size=batch_size
+            )

             collection.upsert(
-                ids=prepared.ids,
-                documents=prepared.texts,
-                metadatas=metadatas,
+                ids=batch_ids,
+                documents=batch_texts,
+                metadatas=batch_metadatas,
             )

     async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
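The batching loop in add_documents/aadd_documents steps through the prepared ids in strides of batch_size and upserts one slice per iteration, so a large CSV no longer reaches the embedder as a single oversized request. A minimal sketch of the same stride-and-slice pattern, independent of ChromaDB.

# Minimal sketch of the stride-and-slice pattern used above (not library code).
ids = [f"id{n}" for n in range(250)]
batch_size = 100

batches = [ids[i : i + batch_size] for i in range(0, len(ids), batch_size)]
assert [len(b) for b in batches] == [100, 100, 50]  # three upserts instead of one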
@@ -335,6 +345,7 @@ class ChromaDBClient(BaseClient):
                 - content: The text content (required)
                 - doc_id: Optional unique identifier (auto-generated if missing)
                 - metadata: Optional metadata dictionary
+            batch_size: Optional batch size for processing documents (default: 100)

         Raises:
             TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
@@ -349,6 +360,7 @@ class ChromaDBClient(BaseClient):

         collection_name = kwargs["collection_name"]
         documents = kwargs["documents"]
+        batch_size = kwargs.get("batch_size", self.default_batch_size)

         if not documents:
             raise ValueError("Documents list cannot be empty")
@@ -358,12 +370,16 @@ class ChromaDBClient(BaseClient):
             embedding_function=self.embedding_function,
         )
         prepared = _prepare_documents_for_chromadb(documents)
-        # ChromaDB doesn't accept empty metadata dicts, so pass None if all are empty
-        metadatas = prepared.metadatas if any(m for m in prepared.metadatas) else None
+        for i in range(0, len(prepared.ids), batch_size):
+            batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
+                prepared=prepared, start_index=i, batch_size=batch_size
+            )

             await collection.upsert(
-                ids=prepared.ids,
-                documents=prepared.texts,
-                metadatas=metadatas,
+                ids=batch_ids,
+                documents=batch_texts,
+                metadatas=batch_metadatas,
             )

     def search(
@@ -41,4 +41,5 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
         embedding_function=config.embedding_function,
         default_limit=config.limit,
         default_score_threshold=config.score_threshold,
+        default_batch_size=config.batch_size,
     )
@@ -1,6 +1,7 @@
 """Utility functions for ChromaDB client implementation."""

 import hashlib
+import json
 from collections.abc import Mapping
 from typing import Literal, TypeGuard, cast

@@ -72,7 +73,15 @@ def _prepare_documents_for_chromadb(
         if "doc_id" in doc:
             ids.append(doc["doc_id"])
         else:
-            content_hash = hashlib.sha256(doc["content"].encode()).hexdigest()[:16]
+            content_for_hash = doc["content"]
+            metadata = doc.get("metadata")
+            if metadata:
+                metadata_str = json.dumps(metadata, sort_keys=True)
+                content_for_hash = f"{content_for_hash}|{metadata_str}"
+
+            content_hash = hashlib.blake2b(
+                content_for_hash.encode(), digest_size=32
+            ).hexdigest()
             ids.append(content_hash)

         texts.append(doc["content"])
@@ -88,6 +97,32 @@ def _prepare_documents_for_chromadb(
     return PreparedDocuments(ids, texts, metadatas)


+def _create_batch_slice(
+    prepared: PreparedDocuments, start_index: int, batch_size: int
+) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]:
+    """Create a batch slice from prepared documents.
+
+    Args:
+        prepared: PreparedDocuments containing ids, texts, and metadatas.
+        start_index: Starting index for the batch.
+        batch_size: Size of the batch.
+
+    Returns:
+        Tuple of (batch_ids, batch_texts, batch_metadatas).
+    """
+    batch_end = min(start_index + batch_size, len(prepared.ids))
+    batch_ids = prepared.ids[start_index:batch_end]
+    batch_texts = prepared.texts[start_index:batch_end]
+    batch_metadatas = (
+        prepared.metadatas[start_index:batch_end] if prepared.metadatas else None
+    )
+
+    if batch_metadatas and not any(m for m in batch_metadatas):
+        batch_metadatas = None
+
+    return batch_ids, batch_texts, batch_metadatas
+
+
 def _extract_search_params(
     kwargs: ChromaDBCollectionSearchParams,
 ) -> ExtractedSearchParams:
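_create_batch_slice keeps the earlier "all-empty metadata becomes None" behaviour, but now applies it per batch. A small usage sketch based on the function as added above; PreparedDocuments is assumed to hold plain lists, as the new tests further down suggest.

# Usage sketch based on the diff above; assumes PreparedDocuments holds plain lists.
prepared = PreparedDocuments(
    ids=["id1", "id2", "id3"],
    texts=["doc1", "doc2", "doc3"],
    metadatas=[{}, {}, {}],  # all empty
)

batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
    prepared=prepared, start_index=0, batch_size=2
)
# batch_ids == ["id1", "id2"], batch_texts == ["doc1", "doc2"]
# batch_metadatas is None, because every metadata dict in the slice is empty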
@@ -16,3 +16,4 @@ class BaseRagConfig:
     embedding_function: Any | None = field(default=None)
     limit: int = field(default=5)
     score_threshold: float = field(default=0.6)
+    batch_size: int = field(default=100)
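Because batch_size lives on BaseRagConfig, both the ChromaDB and Qdrant configs inherit it, and the factories above and below forward it to their clients as default_batch_size. A hedged construction sketch; it assumes the remaining config fields have workable defaults and that the generic create_client dispatches on the config type.

# Sketch only, not from the diff; assumes ChromaDBConfig's other fields default sensibly
# and that crewai.rag.factory.create_client(config) returns the matching client.
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.factory import create_client

config = ChromaDBConfig(batch_size=50)
client = create_client(config)
# Per the factory change above, the resulting client's default_batch_size should be 50.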
@@ -29,7 +29,7 @@ class BaseCollectionParams(TypedDict):
     ]


-class BaseCollectionAddParams(BaseCollectionParams):
+class BaseCollectionAddParams(BaseCollectionParams, total=False):
     """Parameters for adding documents to a collection.

     Extends BaseCollectionParams with document-specific fields.
@@ -37,9 +37,11 @@ class BaseCollectionAddParams(BaseCollectionParams):
     Attributes:
         collection_name: The name of the collection to add documents to.
         documents: List of BaseRecord dictionaries containing document data.
+        batch_size: Optional batch size for processing documents to avoid token limits.
     """

-    documents: list[BaseRecord]
+    documents: Required[list[BaseRecord]]
+    batch_size: int


 class BaseCollectionSearchParams(BaseCollectionParams, total=False):
@@ -244,4 +244,6 @@ def get_embedding_function(

    _inject_api_key_from_env(provider, config_dict)

+    config_dict.pop("batch_size", None)
+
     return EMBEDDING_PROVIDERS[provider](**config_dict)
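batch_size is a storage-level knob rather than an embedding-provider parameter, so get_embedding_function strips it from the config dict before instantiating the provider; otherwise the provider constructor would receive an unexpected keyword argument. A tiny sketch of the effect; the provider class here is hypothetical.

# Hypothetical provider to illustrate why the pop() above is needed; the real entries
# in EMBEDDING_PROVIDERS take their own keyword arguments and no batch_size.
class FakeProvider:
    def __init__(self, model: str) -> None:
        self.model = model

config_dict = {"model": "some-model", "batch_size": 50}
config_dict.pop("batch_size", None)  # without this, FakeProvider(**config_dict) raises TypeError
provider = FakeProvider(**config_dict)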
@@ -48,6 +48,7 @@ class QdrantClient(BaseClient):
         embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
         default_limit: int = 5,
         default_score_threshold: float = 0.6,
+        default_batch_size: int = 100,
     ) -> None:
         """Initialize QdrantClient with client and embedding function.

@@ -56,11 +57,13 @@ class QdrantClient(BaseClient):
             embedding_function: Embedding function for text to vector conversion.
             default_limit: Default number of results to return in searches.
             default_score_threshold: Default minimum score for search results.
+            default_batch_size: Default batch size for adding documents.
         """
         self.client = client
         self.embedding_function = embedding_function
         self.default_limit = default_limit
         self.default_score_threshold = default_score_threshold
+        self.default_batch_size = default_batch_size

     def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
         """Create a new collection in Qdrant.
@@ -234,6 +237,7 @@ class QdrantClient(BaseClient):
         Keyword Args:
             collection_name: The name of the collection to add documents to.
             documents: List of BaseRecord dicts containing document data.
+            batch_size: Optional batch size for processing documents (default: 100)

         Raises:
             ValueError: If collection doesn't exist or documents list is empty.
@@ -249,6 +253,7 @@ class QdrantClient(BaseClient):

         collection_name = kwargs["collection_name"]
         documents = kwargs["documents"]
+        batch_size = kwargs.get("batch_size", self.default_batch_size)

         if not documents:
             raise ValueError("Documents list cannot be empty")
@@ -256,8 +261,10 @@ class QdrantClient(BaseClient):
         if not self.client.collection_exists(collection_name):
             raise ValueError(f"Collection '{collection_name}' does not exist")

+        for i in range(0, len(documents), batch_size):
+            batch_docs = documents[i : min(i + batch_size, len(documents))]
             points = []
-        for doc in documents:
+            for doc in batch_docs:
                 if _is_async_embedding_function(self.embedding_function):
                     raise TypeError(
                         "Async embedding function cannot be used with sync add_documents. "
@@ -267,7 +274,6 @@ class QdrantClient(BaseClient):
                 embedding = sync_fn(doc["content"])
                 point = _create_point_from_document(doc, embedding)
                 points.append(point)
-
             self.client.upsert(collection_name=collection_name, points=points)

     async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
@@ -276,6 +282,7 @@ class QdrantClient(BaseClient):
         Keyword Args:
             collection_name: The name of the collection to add documents to.
             documents: List of BaseRecord dicts containing document data.
+            batch_size: Optional batch size for processing documents (default: 100)

         Raises:
             ValueError: If collection doesn't exist or documents list is empty.
@@ -291,6 +298,7 @@ class QdrantClient(BaseClient):

         collection_name = kwargs["collection_name"]
         documents = kwargs["documents"]
+        batch_size = kwargs.get("batch_size", self.default_batch_size)

         if not documents:
             raise ValueError("Documents list cannot be empty")
@@ -298,8 +306,10 @@ class QdrantClient(BaseClient):
         if not await self.client.collection_exists(collection_name):
             raise ValueError(f"Collection '{collection_name}' does not exist")

+        for i in range(0, len(documents), batch_size):
+            batch_docs = documents[i : min(i + batch_size, len(documents))]
             points = []
-        for doc in documents:
+            for doc in batch_docs:
                 if _is_async_embedding_function(self.embedding_function):
                     async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
                     embedding = await async_fn(doc["content"])
@@ -308,7 +318,6 @@ class QdrantClient(BaseClient):
                 embedding = sync_fn(doc["content"])
                 point = _create_point_from_document(doc, embedding)
                 points.append(point)
-
             await self.client.upsert(collection_name=collection_name, points=points)

     def search(
@@ -22,4 +22,5 @@ def create_client(config: QdrantConfig) -> QdrantClient:
         embedding_function=config.embedding_function,
         default_limit=config.limit,
         default_score_threshold=config.score_threshold,
+        default_batch_size=config.batch_size,
     )
@@ -34,6 +34,30 @@ def client(mock_chromadb_client) -> ChromaDBClient:
     return client


+@pytest.fixture
+def client_with_batch_size(mock_chromadb_client) -> ChromaDBClient:
+    """Create a ChromaDBClient instance with custom batch size for testing."""
+    mock_embedding = Mock()
+    client = ChromaDBClient(
+        client=mock_chromadb_client,
+        embedding_function=mock_embedding,
+        default_batch_size=2,
+    )
+    return client
+
+
+@pytest.fixture
+def async_client_with_batch_size(mock_async_chromadb_client) -> ChromaDBClient:
+    """Create a ChromaDBClient instance with async client and custom batch size for testing."""
+    mock_embedding = Mock()
+    client = ChromaDBClient(
+        client=mock_async_chromadb_client,
+        embedding_function=mock_embedding,
+        default_batch_size=2,
+    )
+    return client
+
+
 @pytest.fixture
 def async_client(mock_async_chromadb_client) -> ChromaDBClient:
     """Create a ChromaDBClient instance with async client for testing."""
@@ -612,3 +636,139 @@ class TestChromaDBClient:
         await async_client.areset()

         mock_async_chromadb_client.reset.assert_called_once_with()
+
+    def test_add_documents_with_batch_size(
+        self, client_with_batch_size, mock_chromadb_client
+    ) -> None:
+        """Test add_documents with batch size splits documents into batches."""
+        mock_collection = Mock()
+        mock_chromadb_client.get_or_create_collection.return_value = mock_collection
+
+        documents: list[BaseRecord] = [
+            {"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}},
+            {"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}},
+            {"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}},
+            {"doc_id": "id4", "content": "Document 4", "metadata": {"source": "test4"}},
+            {"doc_id": "id5", "content": "Document 5", "metadata": {"source": "test5"}},
+        ]
+
+        client_with_batch_size.add_documents(
+            collection_name="test_collection", documents=documents
+        )
+
+        assert mock_collection.upsert.call_count == 3
+
+        first_call = mock_collection.upsert.call_args_list[0]
+        assert first_call.kwargs["ids"] == ["id1", "id2"]
+        assert first_call.kwargs["documents"] == ["Document 1", "Document 2"]
+        assert first_call.kwargs["metadatas"] == [
+            {"source": "test1"},
+            {"source": "test2"},
+        ]
+
+        second_call = mock_collection.upsert.call_args_list[1]
+        assert second_call.kwargs["ids"] == ["id3", "id4"]
+        assert second_call.kwargs["documents"] == ["Document 3", "Document 4"]
+        assert second_call.kwargs["metadatas"] == [
+            {"source": "test3"},
+            {"source": "test4"},
+        ]
+
+        third_call = mock_collection.upsert.call_args_list[2]
+        assert third_call.kwargs["ids"] == ["id5"]
+        assert third_call.kwargs["documents"] == ["Document 5"]
+        assert third_call.kwargs["metadatas"] == [{"source": "test5"}]
+
+    def test_add_documents_with_explicit_batch_size(
+        self, client, mock_chromadb_client
+    ) -> None:
+        """Test add_documents with explicitly provided batch size."""
+        mock_collection = Mock()
+        mock_chromadb_client.get_or_create_collection.return_value = mock_collection
+
+        documents: list[BaseRecord] = [
+            {"doc_id": "id1", "content": "Document 1"},
+            {"doc_id": "id2", "content": "Document 2"},
+            {"doc_id": "id3", "content": "Document 3"},
+        ]
+
+        client.add_documents(
+            collection_name="test_collection", documents=documents, batch_size=1
+        )
+
+        assert mock_collection.upsert.call_count == 3
+        for i, call in enumerate(mock_collection.upsert.call_args_list):
+            assert len(call.kwargs["ids"]) == 1
+            assert call.kwargs["ids"] == [f"id{i + 1}"]
+
+    @pytest.mark.asyncio
+    async def test_aadd_documents_with_batch_size(
+        self, async_client_with_batch_size, mock_async_chromadb_client
+    ) -> None:
+        """Test aadd_documents with batch size splits documents into batches."""
+        mock_collection = AsyncMock()
+        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
+            return_value=mock_collection
+        )
+
+        documents: list[BaseRecord] = [
+            {"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}},
+            {"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}},
+            {"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}},
+        ]
+
+        await async_client_with_batch_size.aadd_documents(
+            collection_name="test_collection", documents=documents
+        )
+
+        assert mock_collection.upsert.call_count == 2
+
+        first_call = mock_collection.upsert.call_args_list[0]
+        assert first_call.kwargs["ids"] == ["id1", "id2"]
+        assert first_call.kwargs["documents"] == ["Document 1", "Document 2"]
+
+        second_call = mock_collection.upsert.call_args_list[1]
+        assert second_call.kwargs["ids"] == ["id3"]
+        assert second_call.kwargs["documents"] == ["Document 3"]
+
+    @pytest.mark.asyncio
+    async def test_aadd_documents_with_explicit_batch_size(
+        self, async_client, mock_async_chromadb_client
+    ) -> None:
+        """Test aadd_documents with explicitly provided batch size."""
+        mock_collection = AsyncMock()
+        mock_async_chromadb_client.get_or_create_collection = AsyncMock(
+            return_value=mock_collection
+        )
+
+        documents: list[BaseRecord] = [
+            {"doc_id": "id1", "content": "Document 1"},
+            {"doc_id": "id2", "content": "Document 2"},
+            {"doc_id": "id3", "content": "Document 3"},
+            {"doc_id": "id4", "content": "Document 4"},
+        ]
+
+        await async_client.aadd_documents(
+            collection_name="test_collection", documents=documents, batch_size=3
+        )
+
+        assert mock_collection.upsert.call_count == 2
+
+        first_call = mock_collection.upsert.call_args_list[0]
+        assert len(first_call.kwargs["ids"]) == 3
+
+        second_call = mock_collection.upsert.call_args_list[1]
+        assert len(second_call.kwargs["ids"]) == 1
+
+    def test_client_default_batch_size_initialization(self) -> None:
+        """Test that client initializes with correct default batch size."""
+        mock_client = Mock()
+        mock_embedding = Mock()
+
+        client = ChromaDBClient(client=mock_client, embedding_function=mock_embedding)
+        assert client.default_batch_size == 100
+
+        custom_client = ChromaDBClient(
+            client=mock_client, embedding_function=mock_embedding, default_batch_size=50
+        )
+        assert custom_client.default_batch_size == 50
@@ -1,11 +1,15 @@
 """Tests for ChromaDB utility functions."""

+from crewai.rag.chromadb.types import PreparedDocuments
 from crewai.rag.chromadb.utils import (
     MAX_COLLECTION_LENGTH,
     MIN_COLLECTION_LENGTH,
+    _create_batch_slice,
     _is_ipv4_pattern,
+    _prepare_documents_for_chromadb,
     _sanitize_collection_name,
 )
+from crewai.rag.types import BaseRecord


 class TestChromaDBUtils:
@@ -93,3 +97,206 @@ class TestChromaDBUtils:
         assert len(sanitized) >= MIN_COLLECTION_LENGTH
         assert sanitized[0].isalnum()
         assert sanitized[-1].isalnum()
+
+
+class TestPrepareDocumentsForChromaDB:
+    """Test suite for _prepare_documents_for_chromadb function."""
+
+    def test_prepare_documents_with_doc_ids(self) -> None:
+        """Test preparing documents that already have doc_ids."""
+        documents: list[BaseRecord] = [
+            {
+                "doc_id": "id1",
+                "content": "First document",
+                "metadata": {"source": "test1"},
+            },
+            {
+                "doc_id": "id2",
+                "content": "Second document",
+                "metadata": {"source": "test2"},
+            },
+        ]
+
+        result = _prepare_documents_for_chromadb(documents)
+
+        assert result.ids == ["id1", "id2"]
+        assert result.texts == ["First document", "Second document"]
+        assert result.metadatas == [{"source": "test1"}, {"source": "test2"}]
+
+    def test_prepare_documents_generate_ids(self) -> None:
+        """Test preparing documents without doc_ids (should generate hashes)."""
+        documents: list[BaseRecord] = [
+            {"content": "Test content", "metadata": {"key": "value"}},
+            {"content": "Another test"},
+        ]
+
+        result = _prepare_documents_for_chromadb(documents)
+
+        assert len(result.ids) == 2
+        assert all(len(doc_id) == 64 for doc_id in result.ids)
+        assert result.texts == ["Test content", "Another test"]
+        assert result.metadatas == [{"key": "value"}, {}]
+
+    def test_prepare_documents_with_list_metadata(self) -> None:
+        """Test preparing documents with list metadata (should take first item)."""
+        documents: list[BaseRecord] = [
+            {"content": "Test", "metadata": [{"first": "item"}, {"second": "item"}]},
+            {"content": "Test2", "metadata": []},
+        ]
+
+        result = _prepare_documents_for_chromadb(documents)
+
+        assert result.metadatas == [{"first": "item"}, {}]
+
+    def test_prepare_documents_no_metadata(self) -> None:
+        """Test preparing documents without metadata."""
+        documents: list[BaseRecord] = [
+            {"content": "Document 1"},
+            {"content": "Document 2", "metadata": None},
+        ]
+
+        result = _prepare_documents_for_chromadb(documents)
+
+        assert result.metadatas == [{}, {}]
+
+    def test_prepare_documents_hash_consistency(self) -> None:
+        """Test that identical content produces identical hashes."""
+        documents1: list[BaseRecord] = [
+            {"content": "Same content", "metadata": {"key": "value"}}
+        ]
+        documents2: list[BaseRecord] = [
+            {"content": "Same content", "metadata": {"key": "value"}}
+        ]
+
+        result1 = _prepare_documents_for_chromadb(documents1)
+        result2 = _prepare_documents_for_chromadb(documents2)
+
+        assert result1.ids == result2.ids
+
+
+class TestCreateBatchSlice:
+    """Test suite for _create_batch_slice function."""
+
+    def test_create_batch_slice_normal(self) -> None:
+        """Test creating a normal batch slice."""
+        prepared = PreparedDocuments(
+            ids=["id1", "id2", "id3", "id4", "id5"],
+            texts=["doc1", "doc2", "doc3", "doc4", "doc5"],
+            metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}, {"e": 5}],
+        )
+
+        batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
+            prepared, start_index=1, batch_size=3
+        )
+
+        assert batch_ids == ["id2", "id3", "id4"]
+        assert batch_texts == ["doc2", "doc3", "doc4"]
+        assert batch_metadatas == [{"b": 2}, {"c": 3}, {"d": 4}]
+
+    def test_create_batch_slice_at_end(self) -> None:
+        """Test creating a batch slice that goes beyond the end."""
+        prepared = PreparedDocuments(
+            ids=["id1", "id2", "id3"],
+            texts=["doc1", "doc2", "doc3"],
+            metadatas=[{"a": 1}, {"b": 2}, {"c": 3}],
+        )
+
+        batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
+            prepared, start_index=2, batch_size=5
+        )
+
+        assert batch_ids == ["id3"]
+        assert batch_texts == ["doc3"]
+        assert batch_metadatas == [{"c": 3}]
+
+    def test_create_batch_slice_empty_batch(self) -> None:
+        """Test creating a batch slice starting beyond the data."""
+        prepared = PreparedDocuments(
+            ids=["id1", "id2"], texts=["doc1", "doc2"], metadatas=[{"a": 1}, {"b": 2}]
+        )
+
+        batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
+            prepared, start_index=5, batch_size=3
+        )
+
+        assert batch_ids == []
+        assert batch_texts == []
+        assert batch_metadatas == []
+
+    def test_create_batch_slice_no_metadatas(self) -> None:
+        """Test creating a batch slice with no metadatas."""
+        prepared = PreparedDocuments(
+            ids=["id1", "id2", "id3"], texts=["doc1", "doc2", "doc3"], metadatas=[]
+        )
+
+        batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
+            prepared, start_index=0, batch_size=2
+        )
+
+        assert batch_ids == ["id1", "id2"]
+        assert batch_texts == ["doc1", "doc2"]
+        assert batch_metadatas is None
+
+    def test_create_batch_slice_all_empty_metadatas(self) -> None:
+        """Test creating a batch slice where all metadatas are empty."""
+        prepared = PreparedDocuments(
+            ids=["id1", "id2", "id3"],
+            texts=["doc1", "doc2", "doc3"],
+            metadatas=[{}, {}, {}],
+        )
+
+        batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
+            prepared, start_index=0, batch_size=3
+        )
+
+        assert batch_ids == ["id1", "id2", "id3"]
+        assert batch_texts == ["doc1", "doc2", "doc3"]
+        assert batch_metadatas is None
+
+    def test_create_batch_slice_some_empty_metadatas(self) -> None:
+        """Test creating a batch slice where some metadatas are empty."""
+        prepared = PreparedDocuments(
+            ids=["id1", "id2", "id3"],
+            texts=["doc1", "doc2", "doc3"],
+            metadatas=[{"a": 1}, {}, {"c": 3}],
+        )
+
+        batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
+            prepared, start_index=0, batch_size=3
+        )
+
+        assert batch_ids == ["id1", "id2", "id3"]
+        assert batch_texts == ["doc1", "doc2", "doc3"]
+        assert batch_metadatas == [{"a": 1}, {}, {"c": 3}]
+
+    def test_create_batch_slice_zero_start_index(self) -> None:
+        """Test creating a batch slice starting from index 0."""
+        prepared = PreparedDocuments(
+            ids=["id1", "id2", "id3", "id4"],
+            texts=["doc1", "doc2", "doc3", "doc4"],
+            metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}],
+        )
+
+        batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
+            prepared, start_index=0, batch_size=2
+        )
+
+        assert batch_ids == ["id1", "id2"]
+        assert batch_texts == ["doc1", "doc2"]
+        assert batch_metadatas == [{"a": 1}, {"b": 2}]
+
+    def test_create_batch_slice_single_item(self) -> None:
+        """Test creating a batch slice with batch size 1."""
+        prepared = PreparedDocuments(
+            ids=["id1", "id2", "id3"],
+            texts=["doc1", "doc2", "doc3"],
+            metadatas=[{"a": 1}, {"b": 2}, {"c": 3}],
+        )
+
+        batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
+            prepared, start_index=1, batch_size=1
+        )
+
+        assert batch_ids == ["id2"]
+        assert batch_texts == ["doc2"]
+        assert batch_metadatas == [{"b": 2}]