Compare commits

...

1 Commit

Author: Devin AI
SHA1: 207079e562
Date: 2025-09-22 10:06:40 +00:00

Fix CSVKnowledgeSource token limit issue with batching

- Add batch_size parameter to BaseFileKnowledgeSource (default: 50)
- Modify _save_documents to process chunks in batches
- Add comprehensive tests for large file handling and batching
- Ensure backward compatibility with existing code

Fixes #3574

Co-Authored-By: João <joao@crewai.com>
2 changed files with 86 additions and 2 deletions
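For context, a minimal usage sketch of the new parameter. It mirrors the constructor and add() calls in the new tests below; the import path is the usual crewai one, the CSV path is illustrative (it must point at a real file), and the Mock storage stands in for the storage that crewai would normally attach.

from pathlib import Path
from unittest.mock import Mock

from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource

# "data/products.csv" is an illustrative path, not part of this change.
source = CSVKnowledgeSource(
    file_paths=[Path("data/products.csv")],
    batch_size=25,  # new optional field; defaults to 50 after this change
)
source.storage = Mock()                # stand-in storage, as in the tests below
source.add()                           # chunks are saved in batches of at most 25
print(source.storage.save.call_count)  # one call per batch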


@@ -23,6 +23,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
     content: dict[Path, str] = Field(init=False, default_factory=dict)
     storage: KnowledgeStorage | None = Field(default=None)
     safe_file_paths: list[Path] = Field(default_factory=list)
+    batch_size: int = Field(
+        default=50,
+        description="Number of chunks to process in each batch to avoid token limits",
+    )

     @field_validator("file_path", "file_paths", mode="before")
     def validate_file_path(cls, v, info):  # noqa: N805
@@ -66,9 +70,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
                 )

     def _save_documents(self):
-        """Save the documents to the storage."""
+        """Save the documents to the storage in batches to avoid token limits."""
         if self.storage:
-            self.storage.save(self.chunks)
+            for i in range(0, len(self.chunks), self.batch_size):
+                batch = self.chunks[i : i + self.batch_size]
+                self.storage.save(batch)
         else:
             raise ValueError("No storage found to save documents.")


@@ -602,3 +602,81 @@ def test_file_path_validation():
         match="file_path/file_paths must be a Path, str, or a list of these types",
     ):
         PDFKnowledgeSource()
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_csv_knowledge_source_large_file_batching(mock_vector_db, tmpdir):
+    """Test CSVKnowledgeSource with a large CSV file that would exceed token limits."""
+    from unittest.mock import Mock
+
+    # Create a large CSV file that would exceed token limits
+    large_csv_content = [["Name", "Description", "Details", "Notes", "Extra"]]
+    for i in range(200):  # This should generate enough content to test batching
+        row = [
+            f"Item_{i}",
+            f"This is a detailed description for item {i} with lots of text content that will contribute to token count",
+            f"Extended details about item {i} including technical specifications, usage instructions, and comprehensive information that adds to the overall token count when processed by the embedder",
+            f"Additional notes and commentary for item {i} with even more text to ensure we have substantial content",
+            f"Extra field with supplementary information for item {i} to maximize content size",
+        ]
+        large_csv_content.append(row)
+
+    csv_path = Path(tmpdir.join("large_data.csv"))
+    with open(csv_path, "w", encoding="utf-8") as f:
+        for row in large_csv_content:
+            f.write(",".join(row) + "\n")
+
+    # Create a CSVKnowledgeSource with custom batch size
+    csv_source = CSVKnowledgeSource(
+        file_paths=[csv_path],
+        batch_size=25,  # Smaller batch size for testing
+        metadata={"test": "large_file"},
+    )
+
+    # Mock the storage to track batch calls
+    mock_storage = Mock()
+    csv_source.storage = mock_storage
+
+    csv_source.add()
+
+    # Verify that storage.save was called multiple times (indicating batching)
+    assert mock_storage.save.call_count > 1, (
+        "Storage.save should be called multiple times for batching"
+    )
+
+    # Verify that each batch has the expected size or less
+    for call in mock_storage.save.call_args_list:
+        batch_chunks = call[0][0]  # First argument to save()
+        assert len(batch_chunks) <= 25, (
+            f"Batch size should not exceed 25, got {len(batch_chunks)}"
+        )
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_csv_knowledge_source_default_batch_size(mock_vector_db, tmpdir):
+    """Test CSVKnowledgeSource uses default batch size when not specified."""
+    from unittest.mock import Mock
+
+    # Create a small CSV file
+    csv_content = [
+        ["Name", "Age", "City"],
+        ["Alice", "25", "Boston"],
+        ["Bob", "30", "Seattle"],
+    ]
+    csv_path = Path(tmpdir.join("small_data.csv"))
+    with open(csv_path, "w", encoding="utf-8") as f:
+        for row in csv_content:
+            f.write(",".join(row) + "\n")
+
+    csv_source = CSVKnowledgeSource(file_paths=[csv_path])
+    assert csv_source.batch_size == 50, (
+        f"Default batch_size should be 50, got {csv_source.batch_size}"
+    )
+
+    mock_storage = Mock()
+    csv_source.storage = mock_storage
+    csv_source.add()
+
+    assert mock_storage.save.called, "Storage.save should be called"
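A note on the assertion style in the first test: unittest.mock records each invocation in call_args_list, where call[0] is the tuple of positional arguments, so call[0][0] is the batch passed to save(). A tiny standalone illustration of that indexing (the names here are illustrative):

from unittest.mock import Mock

storage = Mock()
storage.save([1, 2, 3])
call = storage.save.call_args_list[0]
print(call[0])     # the positional-arguments tuple: ([1, 2, 3],)
print(call[0][0])  # the first positional argument: [1, 2, 3]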