Fix CSVKnowledgeSource token limit issue with batching

- Add batch_size parameter to BaseFileKnowledgeSource (default: 50)
- Modify _save_documents to process chunks in batches
- Add comprehensive tests for large file handling and batching
- Ensure backward compatibility with existing code

Fixes #3574

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-09-22 10:06:40 +00:00
parent aa8dc9d77f
commit 207079e562
2 changed files with 86 additions and 2 deletions

View File

@@ -602,3 +602,81 @@ def test_file_path_validation():
match="file_path/file_paths must be a Path, str, or a list of these types",
):
PDFKnowledgeSource()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_csv_knowledge_source_large_file_batching(mock_vector_db, tmpdir):
    """Test that CSVKnowledgeSource saves a large CSV file in bounded batches.

    A 200-row CSV with long text fields is written to ``tmpdir`` and loaded
    with an explicit ``batch_size``. The storage is replaced by a ``Mock`` so
    the test can check that ``save`` is invoked more than once and that no
    single call receives more than ``batch_size`` chunks.
    """
    import csv
    from unittest.mock import Mock

    # Single source of truth for the limit used by the source and the asserts.
    batch_size = 25  # Smaller than the default so batching is easy to trigger

    # Create a large CSV file that would exceed token limits
    large_csv_content = [["Name", "Description", "Details", "Notes", "Extra"]]
    large_csv_content.extend(
        [
            f"Item_{i}",
            f"This is a detailed description for item {i} with lots of text content that will contribute to token count",
            f"Extended details about item {i} including technical specifications, usage instructions, and comprehensive information that adds to the overall token count when processed by the embedder",
            f"Additional notes and commentary for item {i} with even more text to ensure we have substantial content",
            f"Extra field with supplementary information for item {i} to maximize content size",
        ]
        for i in range(200)  # This should generate enough content to test batching
    )

    csv_path = Path(tmpdir.join("large_data.csv"))
    # The "Details" field contains commas, so a plain ",".join(row) would emit
    # rows with inconsistent column counts. csv.writer quotes such fields;
    # newline="" is what the csv module requires of the file it writes to.
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        csv.writer(f).writerows(large_csv_content)

    # Create a CSVKnowledgeSource with custom batch size
    csv_source = CSVKnowledgeSource(
        file_paths=[csv_path],
        batch_size=batch_size,
        metadata={"test": "large_file"},
    )

    # Mock the storage to track batch calls
    mock_storage = Mock()
    csv_source.storage = mock_storage
    csv_source.add()

    # Verify that storage.save was called multiple times (indicating batching)
    assert mock_storage.save.call_count > 1, (
        "Storage.save should be called multiple times for batching"
    )

    # Verify that each batch has the expected size or less
    for call in mock_storage.save.call_args_list:
        batch_chunks = call[0][0]  # First positional argument to save()
        assert len(batch_chunks) <= batch_size, (
            f"Batch size should not exceed {batch_size}, got {len(batch_chunks)}"
        )
@pytest.mark.vcr(filter_headers=["authorization"])
def test_csv_knowledge_source_default_batch_size(mock_vector_db, tmpdir):
    """Check the default batch_size of CSVKnowledgeSource and that add() saves.

    A three-line CSV is written to ``tmpdir`` and loaded without passing
    ``batch_size``; the attribute is expected to be 50. The storage is then
    swapped for a ``Mock`` to confirm that ``add()`` calls ``save`` on it.
    """
    from unittest.mock import Mock

    # Header plus two data rows: small enough that batching is irrelevant here.
    rows = (
        ("Name", "Age", "City"),
        ("Alice", "25", "Boston"),
        ("Bob", "30", "Seattle"),
    )
    csv_path = Path(tmpdir.join("small_data.csv"))
    csv_path.write_text(
        "".join(",".join(fields) + "\n" for fields in rows),
        encoding="utf-8",
    )

    csv_source = CSVKnowledgeSource(file_paths=[csv_path])
    assert csv_source.batch_size == 50, (
        f"Default batch_size should be 50, got {csv_source.batch_size}"
    )

    # Replace the real storage so the save call can be observed.
    storage_double = Mock()
    csv_source.storage = storage_double
    csv_source.add()

    assert storage_double.save.called, "Storage.save should be called"