mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Fix CSVKnowledgeSource token limit issue with batching
- Add batch_size parameter to BaseFileKnowledgeSource (default: 50) - Modify _save_documents to process chunks in batches - Add comprehensive tests for large file handling and batching - Ensure backward compatibility with existing code Fixes #3574 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -23,6 +23,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
content: dict[Path, str] = Field(init=False, default_factory=dict)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||
batch_size: int = Field(
|
||||
default=50,
|
||||
description="Number of chunks to process in each batch to avoid token limits",
|
||||
)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
def validate_file_path(cls, v, info): # noqa: N805
|
||||
@@ -66,9 +70,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
)
|
||||
|
||||
def _save_documents(self):
|
||||
"""Save the documents to the storage."""
|
||||
"""Save the documents to the storage in batches to avoid token limits."""
|
||||
if self.storage:
|
||||
self.storage.save(self.chunks)
|
||||
for i in range(0, len(self.chunks), self.batch_size):
|
||||
batch = self.chunks[i : i + self.batch_size]
|
||||
self.storage.save(batch)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
|
||||
@@ -602,3 +602,81 @@ def test_file_path_validation():
|
||||
match="file_path/file_paths must be a Path, str, or a list of these types",
|
||||
):
|
||||
PDFKnowledgeSource()
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_csv_knowledge_source_large_file_batching(mock_vector_db, tmpdir):
|
||||
"""Test CSVKnowledgeSource with a large CSV file that would exceed token limits."""
|
||||
from unittest.mock import Mock
|
||||
|
||||
# Create a large CSV file that would exceed token limits
|
||||
large_csv_content = [["Name", "Description", "Details", "Notes", "Extra"]]
|
||||
|
||||
for i in range(200): # This should generate enough content to test batching
|
||||
row = [
|
||||
f"Item_{i}",
|
||||
f"This is a detailed description for item {i} with lots of text content that will contribute to token count",
|
||||
f"Extended details about item {i} including technical specifications, usage instructions, and comprehensive information that adds to the overall token count when processed by the embedder",
|
||||
f"Additional notes and commentary for item {i} with even more text to ensure we have substantial content",
|
||||
f"Extra field with supplementary information for item {i} to maximize content size",
|
||||
]
|
||||
large_csv_content.append(row)
|
||||
|
||||
csv_path = Path(tmpdir.join("large_data.csv"))
|
||||
with open(csv_path, "w", encoding="utf-8") as f:
|
||||
for row in large_csv_content:
|
||||
f.write(",".join(row) + "\n")
|
||||
|
||||
# Create a CSVKnowledgeSource with custom batch size
|
||||
csv_source = CSVKnowledgeSource(
|
||||
file_paths=[csv_path],
|
||||
batch_size=25, # Smaller batch size for testing
|
||||
metadata={"test": "large_file"},
|
||||
)
|
||||
|
||||
# Mock the storage to track batch calls
|
||||
mock_storage = Mock()
|
||||
csv_source.storage = mock_storage
|
||||
|
||||
csv_source.add()
|
||||
|
||||
# Verify that storage.save was called multiple times (indicating batching)
|
||||
assert mock_storage.save.call_count > 1, (
|
||||
"Storage.save should be called multiple times for batching"
|
||||
)
|
||||
|
||||
# Verify that each batch has the expected size or less
|
||||
for call in mock_storage.save.call_args_list:
|
||||
batch_chunks = call[0][0] # First argument to save()
|
||||
assert len(batch_chunks) <= 25, (
|
||||
f"Batch size should not exceed 25, got {len(batch_chunks)}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_csv_knowledge_source_default_batch_size(mock_vector_db, tmpdir):
|
||||
"""Test CSVKnowledgeSource uses default batch size when not specified."""
|
||||
from unittest.mock import Mock
|
||||
|
||||
# Create a small CSV file
|
||||
csv_content = [
|
||||
["Name", "Age", "City"],
|
||||
["Alice", "25", "Boston"],
|
||||
["Bob", "30", "Seattle"],
|
||||
]
|
||||
csv_path = Path(tmpdir.join("small_data.csv"))
|
||||
with open(csv_path, "w", encoding="utf-8") as f:
|
||||
for row in csv_content:
|
||||
f.write(",".join(row) + "\n")
|
||||
|
||||
csv_source = CSVKnowledgeSource(file_paths=[csv_path])
|
||||
|
||||
assert csv_source.batch_size == 50, (
|
||||
f"Default batch_size should be 50, got {csv_source.batch_size}"
|
||||
)
|
||||
|
||||
mock_storage = Mock()
|
||||
csv_source.storage = mock_storage
|
||||
csv_source.add()
|
||||
|
||||
assert mock_storage.save.called, "Storage.save should be called"
|
||||
|
||||
Reference in New Issue
Block a user