mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
fix: add batch_size support to prevent embedder token limit errors
- add batch_size field to baseragconfig (default=100) - update chromadb/qdrant clients and factories to use batch_size - extract and filter batch_size from embedder config in knowledgestorage - fix large csv files exceeding embedder token limits (#3574) - remove unneeded conditional for type Co-authored-by: Vini Brasil <vini@hey.com>
This commit is contained in:
@@ -34,6 +34,30 @@ def client(mock_chromadb_client) -> ChromaDBClient:
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client_with_batch_size(mock_chromadb_client) -> ChromaDBClient:
|
||||
"""Create a ChromaDBClient instance with custom batch size for testing."""
|
||||
mock_embedding = Mock()
|
||||
client = ChromaDBClient(
|
||||
client=mock_chromadb_client,
|
||||
embedding_function=mock_embedding,
|
||||
default_batch_size=2,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client_with_batch_size(mock_async_chromadb_client) -> ChromaDBClient:
|
||||
"""Create a ChromaDBClient instance with async client and custom batch size for testing."""
|
||||
mock_embedding = Mock()
|
||||
client = ChromaDBClient(
|
||||
client=mock_async_chromadb_client,
|
||||
embedding_function=mock_embedding,
|
||||
default_batch_size=2,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client(mock_async_chromadb_client) -> ChromaDBClient:
|
||||
"""Create a ChromaDBClient instance with async client for testing."""
|
||||
@@ -612,3 +636,139 @@ class TestChromaDBClient:
|
||||
await async_client.areset()
|
||||
|
||||
mock_async_chromadb_client.reset.assert_called_once_with()
|
||||
|
||||
def test_add_documents_with_batch_size(
|
||||
self, client_with_batch_size, mock_chromadb_client
|
||||
) -> None:
|
||||
"""Test add_documents with batch size splits documents into batches."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}},
|
||||
{"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}},
|
||||
{"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}},
|
||||
{"doc_id": "id4", "content": "Document 4", "metadata": {"source": "test4"}},
|
||||
{"doc_id": "id5", "content": "Document 5", "metadata": {"source": "test5"}},
|
||||
]
|
||||
|
||||
client_with_batch_size.add_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
assert mock_collection.upsert.call_count == 3
|
||||
|
||||
first_call = mock_collection.upsert.call_args_list[0]
|
||||
assert first_call.kwargs["ids"] == ["id1", "id2"]
|
||||
assert first_call.kwargs["documents"] == ["Document 1", "Document 2"]
|
||||
assert first_call.kwargs["metadatas"] == [
|
||||
{"source": "test1"},
|
||||
{"source": "test2"},
|
||||
]
|
||||
|
||||
second_call = mock_collection.upsert.call_args_list[1]
|
||||
assert second_call.kwargs["ids"] == ["id3", "id4"]
|
||||
assert second_call.kwargs["documents"] == ["Document 3", "Document 4"]
|
||||
assert second_call.kwargs["metadatas"] == [
|
||||
{"source": "test3"},
|
||||
{"source": "test4"},
|
||||
]
|
||||
|
||||
third_call = mock_collection.upsert.call_args_list[2]
|
||||
assert third_call.kwargs["ids"] == ["id5"]
|
||||
assert third_call.kwargs["documents"] == ["Document 5"]
|
||||
assert third_call.kwargs["metadatas"] == [{"source": "test5"}]
|
||||
|
||||
def test_add_documents_with_explicit_batch_size(
|
||||
self, client, mock_chromadb_client
|
||||
) -> None:
|
||||
"""Test add_documents with explicitly provided batch size."""
|
||||
mock_collection = Mock()
|
||||
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"doc_id": "id1", "content": "Document 1"},
|
||||
{"doc_id": "id2", "content": "Document 2"},
|
||||
{"doc_id": "id3", "content": "Document 3"},
|
||||
]
|
||||
|
||||
client.add_documents(
|
||||
collection_name="test_collection", documents=documents, batch_size=1
|
||||
)
|
||||
|
||||
assert mock_collection.upsert.call_count == 3
|
||||
for i, call in enumerate(mock_collection.upsert.call_args_list):
|
||||
assert len(call.kwargs["ids"]) == 1
|
||||
assert call.kwargs["ids"] == [f"id{i + 1}"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_with_batch_size(
|
||||
self, async_client_with_batch_size, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test aadd_documents with batch size splits documents into batches."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}},
|
||||
{"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}},
|
||||
{"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}},
|
||||
]
|
||||
|
||||
await async_client_with_batch_size.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
assert mock_collection.upsert.call_count == 2
|
||||
|
||||
first_call = mock_collection.upsert.call_args_list[0]
|
||||
assert first_call.kwargs["ids"] == ["id1", "id2"]
|
||||
assert first_call.kwargs["documents"] == ["Document 1", "Document 2"]
|
||||
|
||||
second_call = mock_collection.upsert.call_args_list[1]
|
||||
assert second_call.kwargs["ids"] == ["id3"]
|
||||
assert second_call.kwargs["documents"] == ["Document 3"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_with_explicit_batch_size(
|
||||
self, async_client, mock_async_chromadb_client
|
||||
) -> None:
|
||||
"""Test aadd_documents with explicitly provided batch size."""
|
||||
mock_collection = AsyncMock()
|
||||
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
|
||||
return_value=mock_collection
|
||||
)
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{"doc_id": "id1", "content": "Document 1"},
|
||||
{"doc_id": "id2", "content": "Document 2"},
|
||||
{"doc_id": "id3", "content": "Document 3"},
|
||||
{"doc_id": "id4", "content": "Document 4"},
|
||||
]
|
||||
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents, batch_size=3
|
||||
)
|
||||
|
||||
assert mock_collection.upsert.call_count == 2
|
||||
|
||||
first_call = mock_collection.upsert.call_args_list[0]
|
||||
assert len(first_call.kwargs["ids"]) == 3
|
||||
|
||||
second_call = mock_collection.upsert.call_args_list[1]
|
||||
assert len(second_call.kwargs["ids"]) == 1
|
||||
|
||||
def test_client_default_batch_size_initialization(self) -> None:
|
||||
"""Test that client initializes with correct default batch size."""
|
||||
mock_client = Mock()
|
||||
mock_embedding = Mock()
|
||||
|
||||
client = ChromaDBClient(client=mock_client, embedding_function=mock_embedding)
|
||||
assert client.default_batch_size == 100
|
||||
|
||||
custom_client = ChromaDBClient(
|
||||
client=mock_client, embedding_function=mock_embedding, default_batch_size=50
|
||||
)
|
||||
assert custom_client.default_batch_size == 50
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""Tests for ChromaDB utility functions."""
|
||||
|
||||
from crewai.rag.chromadb.types import PreparedDocuments
|
||||
from crewai.rag.chromadb.utils import (
|
||||
MAX_COLLECTION_LENGTH,
|
||||
MIN_COLLECTION_LENGTH,
|
||||
_create_batch_slice,
|
||||
_is_ipv4_pattern,
|
||||
_prepare_documents_for_chromadb,
|
||||
_sanitize_collection_name,
|
||||
)
|
||||
from crewai.rag.types import BaseRecord
|
||||
|
||||
|
||||
class TestChromaDBUtils:
|
||||
@@ -93,3 +97,206 @@ class TestChromaDBUtils:
|
||||
assert len(sanitized) >= MIN_COLLECTION_LENGTH
|
||||
assert sanitized[0].isalnum()
|
||||
assert sanitized[-1].isalnum()
|
||||
|
||||
|
||||
class TestPrepareDocumentsForChromaDB:
|
||||
"""Test suite for _prepare_documents_for_chromadb function."""
|
||||
|
||||
def test_prepare_documents_with_doc_ids(self) -> None:
|
||||
"""Test preparing documents that already have doc_ids."""
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"doc_id": "id1",
|
||||
"content": "First document",
|
||||
"metadata": {"source": "test1"},
|
||||
},
|
||||
{
|
||||
"doc_id": "id2",
|
||||
"content": "Second document",
|
||||
"metadata": {"source": "test2"},
|
||||
},
|
||||
]
|
||||
|
||||
result = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
assert result.ids == ["id1", "id2"]
|
||||
assert result.texts == ["First document", "Second document"]
|
||||
assert result.metadatas == [{"source": "test1"}, {"source": "test2"}]
|
||||
|
||||
def test_prepare_documents_generate_ids(self) -> None:
|
||||
"""Test preparing documents without doc_ids (should generate hashes)."""
|
||||
documents: list[BaseRecord] = [
|
||||
{"content": "Test content", "metadata": {"key": "value"}},
|
||||
{"content": "Another test"},
|
||||
]
|
||||
|
||||
result = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
assert len(result.ids) == 2
|
||||
assert all(len(doc_id) == 64 for doc_id in result.ids)
|
||||
assert result.texts == ["Test content", "Another test"]
|
||||
assert result.metadatas == [{"key": "value"}, {}]
|
||||
|
||||
def test_prepare_documents_with_list_metadata(self) -> None:
|
||||
"""Test preparing documents with list metadata (should take first item)."""
|
||||
documents: list[BaseRecord] = [
|
||||
{"content": "Test", "metadata": [{"first": "item"}, {"second": "item"}]},
|
||||
{"content": "Test2", "metadata": []},
|
||||
]
|
||||
|
||||
result = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
assert result.metadatas == [{"first": "item"}, {}]
|
||||
|
||||
def test_prepare_documents_no_metadata(self) -> None:
|
||||
"""Test preparing documents without metadata."""
|
||||
documents: list[BaseRecord] = [
|
||||
{"content": "Document 1"},
|
||||
{"content": "Document 2", "metadata": None},
|
||||
]
|
||||
|
||||
result = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
assert result.metadatas == [{}, {}]
|
||||
|
||||
def test_prepare_documents_hash_consistency(self) -> None:
|
||||
"""Test that identical content produces identical hashes."""
|
||||
documents1: list[BaseRecord] = [
|
||||
{"content": "Same content", "metadata": {"key": "value"}}
|
||||
]
|
||||
documents2: list[BaseRecord] = [
|
||||
{"content": "Same content", "metadata": {"key": "value"}}
|
||||
]
|
||||
|
||||
result1 = _prepare_documents_for_chromadb(documents1)
|
||||
result2 = _prepare_documents_for_chromadb(documents2)
|
||||
|
||||
assert result1.ids == result2.ids
|
||||
|
||||
|
||||
class TestCreateBatchSlice:
|
||||
"""Test suite for _create_batch_slice function."""
|
||||
|
||||
def test_create_batch_slice_normal(self) -> None:
|
||||
"""Test creating a normal batch slice."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3", "id4", "id5"],
|
||||
texts=["doc1", "doc2", "doc3", "doc4", "doc5"],
|
||||
metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}, {"e": 5}],
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=1, batch_size=3
|
||||
)
|
||||
|
||||
assert batch_ids == ["id2", "id3", "id4"]
|
||||
assert batch_texts == ["doc2", "doc3", "doc4"]
|
||||
assert batch_metadatas == [{"b": 2}, {"c": 3}, {"d": 4}]
|
||||
|
||||
def test_create_batch_slice_at_end(self) -> None:
|
||||
"""Test creating a batch slice that goes beyond the end."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3"],
|
||||
texts=["doc1", "doc2", "doc3"],
|
||||
metadatas=[{"a": 1}, {"b": 2}, {"c": 3}],
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=2, batch_size=5
|
||||
)
|
||||
|
||||
assert batch_ids == ["id3"]
|
||||
assert batch_texts == ["doc3"]
|
||||
assert batch_metadatas == [{"c": 3}]
|
||||
|
||||
def test_create_batch_slice_empty_batch(self) -> None:
|
||||
"""Test creating a batch slice starting beyond the data."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2"], texts=["doc1", "doc2"], metadatas=[{"a": 1}, {"b": 2}]
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=5, batch_size=3
|
||||
)
|
||||
|
||||
assert batch_ids == []
|
||||
assert batch_texts == []
|
||||
assert batch_metadatas == []
|
||||
|
||||
def test_create_batch_slice_no_metadatas(self) -> None:
|
||||
"""Test creating a batch slice with no metadatas."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3"], texts=["doc1", "doc2", "doc3"], metadatas=[]
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=0, batch_size=2
|
||||
)
|
||||
|
||||
assert batch_ids == ["id1", "id2"]
|
||||
assert batch_texts == ["doc1", "doc2"]
|
||||
assert batch_metadatas is None
|
||||
|
||||
def test_create_batch_slice_all_empty_metadatas(self) -> None:
|
||||
"""Test creating a batch slice where all metadatas are empty."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3"],
|
||||
texts=["doc1", "doc2", "doc3"],
|
||||
metadatas=[{}, {}, {}],
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=0, batch_size=3
|
||||
)
|
||||
|
||||
assert batch_ids == ["id1", "id2", "id3"]
|
||||
assert batch_texts == ["doc1", "doc2", "doc3"]
|
||||
assert batch_metadatas is None
|
||||
|
||||
def test_create_batch_slice_some_empty_metadatas(self) -> None:
|
||||
"""Test creating a batch slice where some metadatas are empty."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3"],
|
||||
texts=["doc1", "doc2", "doc3"],
|
||||
metadatas=[{"a": 1}, {}, {"c": 3}],
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=0, batch_size=3
|
||||
)
|
||||
|
||||
assert batch_ids == ["id1", "id2", "id3"]
|
||||
assert batch_texts == ["doc1", "doc2", "doc3"]
|
||||
assert batch_metadatas == [{"a": 1}, {}, {"c": 3}]
|
||||
|
||||
def test_create_batch_slice_zero_start_index(self) -> None:
|
||||
"""Test creating a batch slice starting from index 0."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3", "id4"],
|
||||
texts=["doc1", "doc2", "doc3", "doc4"],
|
||||
metadatas=[{"a": 1}, {"b": 2}, {"c": 3}, {"d": 4}],
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=0, batch_size=2
|
||||
)
|
||||
|
||||
assert batch_ids == ["id1", "id2"]
|
||||
assert batch_texts == ["doc1", "doc2"]
|
||||
assert batch_metadatas == [{"a": 1}, {"b": 2}]
|
||||
|
||||
def test_create_batch_slice_single_item(self) -> None:
|
||||
"""Test creating a batch slice with batch size 1."""
|
||||
prepared = PreparedDocuments(
|
||||
ids=["id1", "id2", "id3"],
|
||||
texts=["doc1", "doc2", "doc3"],
|
||||
metadatas=[{"a": 1}, {"b": 2}, {"c": 3}],
|
||||
)
|
||||
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared, start_index=1, batch_size=1
|
||||
)
|
||||
|
||||
assert batch_ids == ["id2"]
|
||||
assert batch_texts == ["doc2"]
|
||||
assert batch_metadatas == [{"b": 2}]
|
||||
|
||||
Reference in New Issue
Block a user