mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 23:28:30 +00:00
Fix CSVKnowledgeSource token limit issue with batching
- Add batch_size parameter to BaseFileKnowledgeSource (default: 50)
- Modify _save_documents to process chunks in batches
- Add comprehensive tests for large file handling and batching
- Ensure backward compatibility with existing code

Fixes #3574

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -23,6 +23,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
content: dict[Path, str] = Field(init=False, default_factory=dict)
|
content: dict[Path, str] = Field(init=False, default_factory=dict)
|
||||||
storage: KnowledgeStorage | None = Field(default=None)
|
storage: KnowledgeStorage | None = Field(default=None)
|
||||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||||
|
# Maximum number of chunks passed to storage.save() per call. Must be
# strictly positive: a zero or negative value cannot be used as the step
# when slicing the chunk list into batches.
batch_size: int = Field(
    default=50,
    gt=0,
    description="Number of chunks to process in each batch to avoid token limits",
)
|
||||||
|
|
||||||
@field_validator("file_path", "file_paths", mode="before")
|
@field_validator("file_path", "file_paths", mode="before")
|
||||||
def validate_file_path(cls, v, info): # noqa: N805
|
def validate_file_path(cls, v, info): # noqa: N805
|
||||||
@@ -66,9 +70,11 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _save_documents(self):
|
def _save_documents(self):
|
||||||
"""Save the documents to the storage."""
|
"""Save the documents to the storage in batches to avoid token limits."""
|
||||||
if self.storage:
|
if self.storage:
|
||||||
self.storage.save(self.chunks)
|
for i in range(0, len(self.chunks), self.batch_size):
|
||||||
|
batch = self.chunks[i : i + self.batch_size]
|
||||||
|
self.storage.save(batch)
|
||||||
else:
|
else:
|
||||||
raise ValueError("No storage found to save documents.")
|
raise ValueError("No storage found to save documents.")
|
||||||
|
|
||||||
|
|||||||
@@ -602,3 +602,81 @@ def test_file_path_validation():
|
|||||||
match="file_path/file_paths must be a Path, str, or a list of these types",
|
match="file_path/file_paths must be a Path, str, or a list of these types",
|
||||||
):
|
):
|
||||||
PDFKnowledgeSource()
|
PDFKnowledgeSource()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.vcr(filter_headers=["authorization"])
def test_csv_knowledge_source_large_file_batching(mock_vector_db, tmpdir):
    """Test CSVKnowledgeSource with a large CSV file that would exceed token limits.

    Writes a 200-row CSV, loads it with a reduced ``batch_size`` and checks
    that the mocked storage receives several ``save`` calls, that no call
    exceeds the batch size, and that the batches together contain every
    chunk in its original order.
    """
    import csv
    from unittest.mock import Mock

    batch_size = 25  # Smaller batch size for testing

    # Create a large CSV file that would exceed token limits
    large_csv_content = [["Name", "Description", "Details", "Notes", "Extra"]]
    large_csv_content.extend(
        [
            f"Item_{i}",
            f"This is a detailed description for item {i} with lots of text content that will contribute to token count",
            f"Extended details about item {i} including technical specifications, usage instructions, and comprehensive information that adds to the overall token count when processed by the embedder",
            f"Additional notes and commentary for item {i} with even more text to ensure we have substantial content",
            f"Extra field with supplementary information for item {i} to maximize content size",
        ]
        for i in range(200)  # This should generate enough content to test batching
    )

    csv_path = Path(tmpdir.join("large_data.csv"))
    # The "Details" field contains commas, so a plain ",".join() would write
    # rows with a varying number of columns. csv.writer quotes such fields.
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        csv.writer(f, lineterminator="\n").writerows(large_csv_content)

    # Create a CSVKnowledgeSource with custom batch size
    csv_source = CSVKnowledgeSource(
        file_paths=[csv_path],
        batch_size=batch_size,
        metadata={"test": "large_file"},
    )

    # Mock the storage to track batch calls
    mock_storage = Mock()
    csv_source.storage = mock_storage

    csv_source.add()

    # Verify that storage.save was called multiple times (indicating batching)
    assert mock_storage.save.call_count > 1, (
        "Storage.save should be called multiple times for batching"
    )

    # Verify that each batch has the expected size or less
    saved_chunks = []
    for call in mock_storage.save.call_args_list:
        batch_chunks = call.args[0]  # First argument to save()
        assert len(batch_chunks) <= batch_size, (
            f"Batch size should not exceed {batch_size}, got {len(batch_chunks)}"
        )
        saved_chunks.extend(batch_chunks)

    # Verify that batching neither dropped, duplicated nor reordered chunks
    assert saved_chunks == list(csv_source.chunks), (
        "Batches passed to Storage.save should cover every chunk exactly once"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.vcr(filter_headers=["authorization"])
def test_csv_knowledge_source_default_batch_size(mock_vector_db, tmpdir):
    """Check that batch_size falls back to 50 and add() still reaches storage.save."""
    from unittest.mock import Mock

    # A header plus two data rows is enough; only the default value and the
    # fact that save() gets invoked are checked here.
    rows = (
        ("Name", "Age", "City"),
        ("Alice", "25", "Boston"),
        ("Bob", "30", "Seattle"),
    )
    csv_path = Path(tmpdir.join("small_data.csv"))
    with open(csv_path, "w", encoding="utf-8") as handle:
        handle.writelines(",".join(row) + "\n" for row in rows)

    # No batch_size argument: the field default must apply.
    csv_source = CSVKnowledgeSource(file_paths=[csv_path])
    assert csv_source.batch_size == 50, (
        f"Default batch_size should be 50, got {csv_source.batch_size}"
    )

    # Swap in a mock so the save() call can be observed.
    storage_double = Mock()
    csv_source.storage = storage_double
    csv_source.add()

    assert storage_double.save.called, "Storage.save should be called"
|
||||||
|
|||||||
Reference in New Issue
Block a user