fix: add batch_size support to prevent embedder token limit errors

- add batch_size field to baseragconfig (default=100)  
- update chromadb/qdrant clients and factories to use batch_size  
- extract and filter batch_size from embedder config in knowledgestorage  
- fix large csv files exceeding embedder token limits (#3574)  
- remove unneeded conditional for type  

Co-authored-by: Vini Brasil <vini@hey.com>
This commit is contained in:
Greyson LaLonde
2025-09-24 00:05:43 -04:00
committed by GitHub
parent 4ac65eb0a6
commit 1dbe8aab52
12 changed files with 558 additions and 56 deletions

View File

@@ -34,6 +34,30 @@ def client(mock_chromadb_client) -> ChromaDBClient:
return client
@pytest.fixture
def client_with_batch_size(mock_chromadb_client) -> ChromaDBClient:
"""Create a ChromaDBClient instance with custom batch size for testing."""
mock_embedding = Mock()
client = ChromaDBClient(
client=mock_chromadb_client,
embedding_function=mock_embedding,
default_batch_size=2,
)
return client
@pytest.fixture
def async_client_with_batch_size(mock_async_chromadb_client) -> ChromaDBClient:
"""Create a ChromaDBClient instance with async client and custom batch size for testing."""
mock_embedding = Mock()
client = ChromaDBClient(
client=mock_async_chromadb_client,
embedding_function=mock_embedding,
default_batch_size=2,
)
return client
@pytest.fixture
def async_client(mock_async_chromadb_client) -> ChromaDBClient:
"""Create a ChromaDBClient instance with async client for testing."""
@@ -612,3 +636,139 @@ class TestChromaDBClient:
await async_client.areset()
mock_async_chromadb_client.reset.assert_called_once_with()
def test_add_documents_with_batch_size(
self, client_with_batch_size, mock_chromadb_client
) -> None:
"""Test add_documents with batch size splits documents into batches."""
mock_collection = Mock()
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}},
{"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}},
{"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}},
{"doc_id": "id4", "content": "Document 4", "metadata": {"source": "test4"}},
{"doc_id": "id5", "content": "Document 5", "metadata": {"source": "test5"}},
]
client_with_batch_size.add_documents(
collection_name="test_collection", documents=documents
)
assert mock_collection.upsert.call_count == 3
first_call = mock_collection.upsert.call_args_list[0]
assert first_call.kwargs["ids"] == ["id1", "id2"]
assert first_call.kwargs["documents"] == ["Document 1", "Document 2"]
assert first_call.kwargs["metadatas"] == [
{"source": "test1"},
{"source": "test2"},
]
second_call = mock_collection.upsert.call_args_list[1]
assert second_call.kwargs["ids"] == ["id3", "id4"]
assert second_call.kwargs["documents"] == ["Document 3", "Document 4"]
assert second_call.kwargs["metadatas"] == [
{"source": "test3"},
{"source": "test4"},
]
third_call = mock_collection.upsert.call_args_list[2]
assert third_call.kwargs["ids"] == ["id5"]
assert third_call.kwargs["documents"] == ["Document 5"]
assert third_call.kwargs["metadatas"] == [{"source": "test5"}]
def test_add_documents_with_explicit_batch_size(
self, client, mock_chromadb_client
) -> None:
"""Test add_documents with explicitly provided batch size."""
mock_collection = Mock()
mock_chromadb_client.get_or_create_collection.return_value = mock_collection
documents: list[BaseRecord] = [
{"doc_id": "id1", "content": "Document 1"},
{"doc_id": "id2", "content": "Document 2"},
{"doc_id": "id3", "content": "Document 3"},
]
client.add_documents(
collection_name="test_collection", documents=documents, batch_size=1
)
assert mock_collection.upsert.call_count == 3
for i, call in enumerate(mock_collection.upsert.call_args_list):
assert len(call.kwargs["ids"]) == 1
assert call.kwargs["ids"] == [f"id{i + 1}"]
@pytest.mark.asyncio
async def test_aadd_documents_with_batch_size(
self, async_client_with_batch_size, mock_async_chromadb_client
) -> None:
"""Test aadd_documents with batch size splits documents into batches."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
documents: list[BaseRecord] = [
{"doc_id": "id1", "content": "Document 1", "metadata": {"source": "test1"}},
{"doc_id": "id2", "content": "Document 2", "metadata": {"source": "test2"}},
{"doc_id": "id3", "content": "Document 3", "metadata": {"source": "test3"}},
]
await async_client_with_batch_size.aadd_documents(
collection_name="test_collection", documents=documents
)
assert mock_collection.upsert.call_count == 2
first_call = mock_collection.upsert.call_args_list[0]
assert first_call.kwargs["ids"] == ["id1", "id2"]
assert first_call.kwargs["documents"] == ["Document 1", "Document 2"]
second_call = mock_collection.upsert.call_args_list[1]
assert second_call.kwargs["ids"] == ["id3"]
assert second_call.kwargs["documents"] == ["Document 3"]
@pytest.mark.asyncio
async def test_aadd_documents_with_explicit_batch_size(
self, async_client, mock_async_chromadb_client
) -> None:
"""Test aadd_documents with explicitly provided batch size."""
mock_collection = AsyncMock()
mock_async_chromadb_client.get_or_create_collection = AsyncMock(
return_value=mock_collection
)
documents: list[BaseRecord] = [
{"doc_id": "id1", "content": "Document 1"},
{"doc_id": "id2", "content": "Document 2"},
{"doc_id": "id3", "content": "Document 3"},
{"doc_id": "id4", "content": "Document 4"},
]
await async_client.aadd_documents(
collection_name="test_collection", documents=documents, batch_size=3
)
assert mock_collection.upsert.call_count == 2
first_call = mock_collection.upsert.call_args_list[0]
assert len(first_call.kwargs["ids"]) == 3
second_call = mock_collection.upsert.call_args_list[1]
assert len(second_call.kwargs["ids"]) == 1
def test_client_default_batch_size_initialization(self) -> None:
"""Test that client initializes with correct default batch size."""
mock_client = Mock()
mock_embedding = Mock()
client = ChromaDBClient(client=mock_client, embedding_function=mock_embedding)
assert client.default_batch_size == 100
custom_client = ChromaDBClient(
client=mock_client, embedding_function=mock_embedding, default_batch_size=50
)
assert custom_client.default_batch_size == 50