refactor: extract files module to standalone crewai-files package

This commit is contained in:
Greyson LaLonde
2026-01-22 15:06:20 -05:00
parent a064b84ead
commit b95a3a9bc8
62 changed files with 639 additions and 582 deletions

View File

@@ -1 +0,0 @@
"""Tests for file processing utilities."""

View File

@@ -1 +0,0 @@
"""Tests for file processing module."""

View File

@@ -1,226 +0,0 @@
"""Tests for provider constraints."""
import pytest
from crewai.files.processing.constraints import (
ANTHROPIC_CONSTRAINTS,
BEDROCK_CONSTRAINTS,
GEMINI_CONSTRAINTS,
OPENAI_CONSTRAINTS,
AudioConstraints,
ImageConstraints,
PDFConstraints,
ProviderConstraints,
VideoConstraints,
get_constraints_for_provider,
)
class TestImageConstraints:
"""Tests for ImageConstraints dataclass."""
def test_image_constraints_creation(self):
"""Test creating image constraints with all fields."""
constraints = ImageConstraints(
max_size_bytes=5 * 1024 * 1024,
max_width=8000,
max_height=8000,
max_images_per_request=10,
)
assert constraints.max_size_bytes == 5 * 1024 * 1024
assert constraints.max_width == 8000
assert constraints.max_height == 8000
assert constraints.max_images_per_request == 10
def test_image_constraints_defaults(self):
"""Test image constraints with default values."""
constraints = ImageConstraints(max_size_bytes=1000)
assert constraints.max_size_bytes == 1000
assert constraints.max_width is None
assert constraints.max_height is None
assert constraints.max_images_per_request is None
assert "image/png" in constraints.supported_formats
def test_image_constraints_frozen(self):
"""Test that image constraints are immutable."""
constraints = ImageConstraints(max_size_bytes=1000)
with pytest.raises(Exception):
constraints.max_size_bytes = 2000
class TestPDFConstraints:
"""Tests for PDFConstraints dataclass."""
def test_pdf_constraints_creation(self):
"""Test creating PDF constraints."""
constraints = PDFConstraints(
max_size_bytes=30 * 1024 * 1024,
max_pages=100,
)
assert constraints.max_size_bytes == 30 * 1024 * 1024
assert constraints.max_pages == 100
def test_pdf_constraints_defaults(self):
"""Test PDF constraints with default values."""
constraints = PDFConstraints(max_size_bytes=1000)
assert constraints.max_size_bytes == 1000
assert constraints.max_pages is None
class TestAudioConstraints:
"""Tests for AudioConstraints dataclass."""
def test_audio_constraints_creation(self):
"""Test creating audio constraints."""
constraints = AudioConstraints(
max_size_bytes=100 * 1024 * 1024,
max_duration_seconds=3600,
)
assert constraints.max_size_bytes == 100 * 1024 * 1024
assert constraints.max_duration_seconds == 3600
assert "audio/mp3" in constraints.supported_formats
class TestVideoConstraints:
"""Tests for VideoConstraints dataclass."""
def test_video_constraints_creation(self):
"""Test creating video constraints."""
constraints = VideoConstraints(
max_size_bytes=2 * 1024 * 1024 * 1024,
max_duration_seconds=7200,
)
assert constraints.max_size_bytes == 2 * 1024 * 1024 * 1024
assert constraints.max_duration_seconds == 7200
assert "video/mp4" in constraints.supported_formats
class TestProviderConstraints:
"""Tests for ProviderConstraints dataclass."""
def test_provider_constraints_creation(self):
"""Test creating full provider constraints."""
constraints = ProviderConstraints(
name="test-provider",
image=ImageConstraints(max_size_bytes=5 * 1024 * 1024),
pdf=PDFConstraints(max_size_bytes=30 * 1024 * 1024),
supports_file_upload=True,
file_upload_threshold_bytes=10 * 1024 * 1024,
)
assert constraints.name == "test-provider"
assert constraints.image is not None
assert constraints.pdf is not None
assert constraints.supports_file_upload is True
def test_provider_constraints_defaults(self):
"""Test provider constraints with default values."""
constraints = ProviderConstraints(name="test")
assert constraints.name == "test"
assert constraints.image is None
assert constraints.pdf is None
assert constraints.audio is None
assert constraints.video is None
assert constraints.supports_file_upload is False
class TestPredefinedConstraints:
"""Tests for predefined provider constraints."""
def test_anthropic_constraints(self):
"""Test Anthropic constraints are properly defined."""
assert ANTHROPIC_CONSTRAINTS.name == "anthropic"
assert ANTHROPIC_CONSTRAINTS.image is not None
assert ANTHROPIC_CONSTRAINTS.image.max_size_bytes == 5 * 1024 * 1024
assert ANTHROPIC_CONSTRAINTS.image.max_width == 8000
assert ANTHROPIC_CONSTRAINTS.pdf is not None
assert ANTHROPIC_CONSTRAINTS.pdf.max_pages == 100
assert ANTHROPIC_CONSTRAINTS.supports_file_upload is True
def test_openai_constraints(self):
"""Test OpenAI constraints are properly defined."""
assert OPENAI_CONSTRAINTS.name == "openai"
assert OPENAI_CONSTRAINTS.image is not None
assert OPENAI_CONSTRAINTS.image.max_size_bytes == 20 * 1024 * 1024
assert OPENAI_CONSTRAINTS.pdf is None # OpenAI doesn't support PDFs
def test_gemini_constraints(self):
"""Test Gemini constraints are properly defined."""
assert GEMINI_CONSTRAINTS.name == "gemini"
assert GEMINI_CONSTRAINTS.image is not None
assert GEMINI_CONSTRAINTS.pdf is not None
assert GEMINI_CONSTRAINTS.audio is not None
assert GEMINI_CONSTRAINTS.video is not None
assert GEMINI_CONSTRAINTS.supports_file_upload is True
def test_bedrock_constraints(self):
"""Test Bedrock constraints are properly defined."""
assert BEDROCK_CONSTRAINTS.name == "bedrock"
assert BEDROCK_CONSTRAINTS.image is not None
assert BEDROCK_CONSTRAINTS.image.max_size_bytes == 4_608_000
assert BEDROCK_CONSTRAINTS.pdf is not None
assert BEDROCK_CONSTRAINTS.supports_file_upload is False
class TestGetConstraintsForProvider:
"""Tests for get_constraints_for_provider function."""
def test_get_by_exact_name(self):
"""Test getting constraints by exact provider name."""
result = get_constraints_for_provider("anthropic")
assert result == ANTHROPIC_CONSTRAINTS
result = get_constraints_for_provider("openai")
assert result == OPENAI_CONSTRAINTS
result = get_constraints_for_provider("gemini")
assert result == GEMINI_CONSTRAINTS
def test_get_by_alias(self):
"""Test getting constraints by alias name."""
result = get_constraints_for_provider("claude")
assert result == ANTHROPIC_CONSTRAINTS
result = get_constraints_for_provider("gpt")
assert result == OPENAI_CONSTRAINTS
result = get_constraints_for_provider("google")
assert result == GEMINI_CONSTRAINTS
def test_get_case_insensitive(self):
"""Test case-insensitive lookup."""
result = get_constraints_for_provider("ANTHROPIC")
assert result == ANTHROPIC_CONSTRAINTS
result = get_constraints_for_provider("OpenAI")
assert result == OPENAI_CONSTRAINTS
def test_get_with_provider_constraints_object(self):
"""Test passing ProviderConstraints object returns it unchanged."""
custom = ProviderConstraints(name="custom")
result = get_constraints_for_provider(custom)
assert result is custom
def test_get_unknown_provider(self):
"""Test unknown provider returns None."""
result = get_constraints_for_provider("unknown-provider")
assert result is None
def test_get_by_partial_match(self):
"""Test partial match in provider string."""
result = get_constraints_for_provider("claude-3-sonnet")
assert result == ANTHROPIC_CONSTRAINTS
result = get_constraints_for_provider("gpt-4o")
assert result == OPENAI_CONSTRAINTS
result = get_constraints_for_provider("gemini-pro")
assert result == GEMINI_CONSTRAINTS

View File

@@ -1,220 +0,0 @@
"""Tests for FileProcessor class."""
import pytest
from crewai.files import FileBytes, ImageFile, PDFFile, TextFile
from crewai.files.processing.constraints import (
ANTHROPIC_CONSTRAINTS,
ImageConstraints,
PDFConstraints,
ProviderConstraints,
)
from crewai.files.processing.enums import FileHandling
from crewai.files.processing.exceptions import (
FileTooLargeError,
FileValidationError,
)
from crewai.files.processing.processor import FileProcessor
# Minimal valid PNG: 8x8 pixel RGB image (valid for PIL)
MINIMAL_PNG = bytes([
0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, 0x00, 0x00, 0x00, 0x0d,
0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08,
0x08, 0x02, 0x00, 0x00, 0x00, 0x4b, 0x6d, 0x29, 0xdc, 0x00, 0x00, 0x00,
0x12, 0x49, 0x44, 0x41, 0x54, 0x78, 0x9c, 0x63, 0xfc, 0xcf, 0x80, 0x1d,
0x30, 0xe1, 0x10, 0x1f, 0xa4, 0x12, 0x00, 0xcd, 0x41, 0x01, 0x0f, 0xe8,
0x41, 0xe2, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, 0x44, 0xae,
0x42, 0x60, 0x82,
])
# Minimal valid PDF
MINIMAL_PDF = (
b"%PDF-1.4\n1 0 obj<</Type/Catalog/Pages 2 0 R>>endobj "
b"2 0 obj<</Type/Pages/Kids[3 0 R]/Count 1>>endobj "
b"3 0 obj<</Type/Page/MediaBox[0 0 612 792]/Parent 2 0 R>>endobj "
b"xref\n0 4\n0000000000 65535 f \n0000000009 00000 n \n"
b"0000000052 00000 n \n0000000101 00000 n \n"
b"trailer<</Size 4/Root 1 0 R>>\nstartxref\n178\n%%EOF"
)
class TestFileProcessorInit:
"""Tests for FileProcessor initialization."""
def test_init_with_constraints(self):
"""Test initialization with ProviderConstraints."""
processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)
assert processor.constraints == ANTHROPIC_CONSTRAINTS
def test_init_with_provider_string(self):
"""Test initialization with provider name string."""
processor = FileProcessor(constraints="anthropic")
assert processor.constraints == ANTHROPIC_CONSTRAINTS
def test_init_with_unknown_provider(self):
"""Test initialization with unknown provider sets constraints to None."""
processor = FileProcessor(constraints="unknown")
assert processor.constraints is None
def test_init_with_none_constraints(self):
"""Test initialization with None constraints."""
processor = FileProcessor(constraints=None)
assert processor.constraints is None
class TestFileProcessorValidate:
"""Tests for FileProcessor.validate method."""
def test_validate_valid_file(self):
"""Test validating a valid file returns no errors."""
processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
errors = processor.validate(file)
assert len(errors) == 0
def test_validate_without_constraints(self):
"""Test validating without constraints returns empty list."""
processor = FileProcessor(constraints=None)
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
errors = processor.validate(file)
assert len(errors) == 0
def test_validate_strict_raises_on_error(self):
"""Test STRICT mode raises on validation error."""
constraints = ProviderConstraints(
name="test",
image=ImageConstraints(max_size_bytes=10),
)
processor = FileProcessor(constraints=constraints)
# Set mode to strict on the file
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict")
with pytest.raises(FileTooLargeError):
processor.validate(file)
class TestFileProcessorProcess:
"""Tests for FileProcessor.process method."""
def test_process_valid_file(self):
"""Test processing a valid file returns it unchanged."""
processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
result = processor.process(file)
assert result == file
def test_process_without_constraints(self):
"""Test processing without constraints returns file unchanged."""
processor = FileProcessor(constraints=None)
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
result = processor.process(file)
assert result == file
def test_process_strict_raises_on_error(self):
"""Test STRICT mode raises on processing error."""
constraints = ProviderConstraints(
name="test",
image=ImageConstraints(max_size_bytes=10),
)
processor = FileProcessor(constraints=constraints)
# Set mode to strict on the file
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict")
with pytest.raises(FileTooLargeError):
processor.process(file)
def test_process_warn_returns_file(self):
"""Test WARN mode returns file with warning."""
constraints = ProviderConstraints(
name="test",
image=ImageConstraints(max_size_bytes=10),
)
processor = FileProcessor(constraints=constraints)
# Set mode to warn on the file
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="warn")
result = processor.process(file)
assert result == file
class TestFileProcessorProcessFiles:
"""Tests for FileProcessor.process_files method."""
def test_process_files_multiple(self):
"""Test processing multiple files."""
processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)
files = {
"image1": ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png")),
"image2": ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test2.png")),
}
result = processor.process_files(files)
assert len(result) == 2
assert "image1" in result
assert "image2" in result
def test_process_files_empty(self):
"""Test processing empty files dict."""
processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)
result = processor.process_files({})
assert result == {}
class TestFileHandlingEnum:
"""Tests for FileHandling enum."""
def test_enum_values(self):
"""Test all enum values are accessible."""
assert FileHandling.STRICT.value == "strict"
assert FileHandling.AUTO.value == "auto"
assert FileHandling.WARN.value == "warn"
assert FileHandling.CHUNK.value == "chunk"
class TestFileProcessorPerFileMode:
"""Tests for per-file mode handling."""
def test_file_default_mode_is_auto(self):
"""Test that files default to auto mode."""
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
assert file.mode == "auto"
def test_file_custom_mode(self):
"""Test setting custom mode on file."""
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict")
assert file.mode == "strict"
def test_processor_respects_file_mode(self):
"""Test processor uses each file's mode setting."""
constraints = ProviderConstraints(
name="test",
image=ImageConstraints(max_size_bytes=10),
)
processor = FileProcessor(constraints=constraints)
# File with strict mode should raise
strict_file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict")
with pytest.raises(FileTooLargeError):
processor.process(strict_file)
# File with warn mode should not raise
warn_file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="warn")
result = processor.process(warn_file)
assert result == warn_file

View File

@@ -1,359 +0,0 @@
"""Unit tests for file transformers."""
import io
from unittest.mock import MagicMock, patch
import pytest
from crewai.files import ImageFile, PDFFile, TextFile
from crewai.files.file import FileBytes
from crewai.files.processing.exceptions import ProcessingDependencyError
from crewai.files.processing.transformers import (
chunk_pdf,
chunk_text,
get_image_dimensions,
get_pdf_page_count,
optimize_image,
resize_image,
)
def create_test_png(width: int = 100, height: int = 100) -> bytes:
"""Create a minimal valid PNG for testing."""
from PIL import Image
img = Image.new("RGB", (width, height), color="red")
buffer = io.BytesIO()
img.save(buffer, format="PNG")
return buffer.getvalue()
def create_test_pdf(num_pages: int = 1) -> bytes:
"""Create a minimal valid PDF for testing."""
from pypdf import PdfWriter
writer = PdfWriter()
for _ in range(num_pages):
writer.add_blank_page(width=612, height=792)
buffer = io.BytesIO()
writer.write(buffer)
return buffer.getvalue()
class TestResizeImage:
"""Tests for resize_image function."""
def test_resize_larger_image(self) -> None:
"""Test resizing an image larger than max dimensions."""
png_bytes = create_test_png(200, 150)
img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))
result = resize_image(img, max_width=100, max_height=100)
dims = get_image_dimensions(result)
assert dims is not None
width, height = dims
assert width <= 100
assert height <= 100
def test_no_resize_if_within_bounds(self) -> None:
"""Test that small images are returned unchanged."""
png_bytes = create_test_png(50, 50)
img = ImageFile(source=FileBytes(data=png_bytes, filename="small.png"))
result = resize_image(img, max_width=100, max_height=100)
assert result is img
def test_preserve_aspect_ratio(self) -> None:
"""Test that aspect ratio is preserved during resize."""
png_bytes = create_test_png(200, 100)
img = ImageFile(source=FileBytes(data=png_bytes, filename="wide.png"))
result = resize_image(img, max_width=100, max_height=100)
dims = get_image_dimensions(result)
assert dims is not None
width, height = dims
assert width == 100
assert height == 50
def test_resize_without_aspect_ratio(self) -> None:
"""Test resizing without preserving aspect ratio."""
png_bytes = create_test_png(200, 100)
img = ImageFile(source=FileBytes(data=png_bytes, filename="wide.png"))
result = resize_image(
img, max_width=50, max_height=50, preserve_aspect_ratio=False
)
dims = get_image_dimensions(result)
assert dims is not None
width, height = dims
assert width == 50
assert height == 50
def test_resize_returns_image_file(self) -> None:
"""Test that resize returns an ImageFile instance."""
png_bytes = create_test_png(200, 200)
img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))
result = resize_image(img, max_width=100, max_height=100)
assert isinstance(result, ImageFile)
def test_raises_without_pillow(self) -> None:
"""Test that ProcessingDependencyError is raised without Pillow."""
img = ImageFile(source=FileBytes(data=b"fake", filename="test.png"))
with patch.dict("sys.modules", {"PIL": None, "PIL.Image": None}):
with pytest.raises(ProcessingDependencyError) as exc_info:
# Force reimport to trigger ImportError
import importlib
import crewai.files.processing.transformers as t
importlib.reload(t)
t.resize_image(img, 100, 100)
assert "Pillow" in str(exc_info.value)
class TestOptimizeImage:
"""Tests for optimize_image function."""
def test_optimize_reduces_size(self) -> None:
"""Test that optimization reduces file size."""
png_bytes = create_test_png(500, 500)
original_size = len(png_bytes)
img = ImageFile(source=FileBytes(data=png_bytes, filename="large.png"))
result = optimize_image(img, target_size_bytes=original_size // 2)
result_size = len(result.read())
assert result_size < original_size
def test_no_optimize_if_under_target(self) -> None:
"""Test that small images are returned unchanged."""
png_bytes = create_test_png(50, 50)
img = ImageFile(source=FileBytes(data=png_bytes, filename="small.png"))
result = optimize_image(img, target_size_bytes=1024 * 1024)
assert result is img
def test_optimize_returns_image_file(self) -> None:
"""Test that optimize returns an ImageFile instance."""
png_bytes = create_test_png(200, 200)
img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))
result = optimize_image(img, target_size_bytes=100)
assert isinstance(result, ImageFile)
def test_optimize_respects_min_quality(self) -> None:
"""Test that optimization stops at minimum quality."""
png_bytes = create_test_png(100, 100)
img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))
# Request impossibly small size - should stop at min quality
result = optimize_image(img, target_size_bytes=10, min_quality=50)
assert isinstance(result, ImageFile)
assert len(result.read()) > 10
class TestChunkPdf:
"""Tests for chunk_pdf function."""
def test_chunk_splits_large_pdf(self) -> None:
"""Test that large PDFs are split into chunks."""
pdf_bytes = create_test_pdf(num_pages=10)
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="large.pdf"))
result = list(chunk_pdf(pdf, max_pages=3))
assert len(result) == 4
assert all(isinstance(chunk, PDFFile) for chunk in result)
def test_no_chunk_if_within_limit(self) -> None:
"""Test that small PDFs are returned unchanged."""
pdf_bytes = create_test_pdf(num_pages=3)
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="small.pdf"))
result = list(chunk_pdf(pdf, max_pages=5))
assert len(result) == 1
assert result[0] is pdf
def test_chunk_filenames(self) -> None:
"""Test that chunked files have indexed filenames."""
pdf_bytes = create_test_pdf(num_pages=6)
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="document.pdf"))
result = list(chunk_pdf(pdf, max_pages=2))
assert result[0].filename == "document_chunk_0.pdf"
assert result[1].filename == "document_chunk_1.pdf"
assert result[2].filename == "document_chunk_2.pdf"
def test_chunk_with_overlap(self) -> None:
"""Test chunking with overlapping pages."""
pdf_bytes = create_test_pdf(num_pages=10)
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="doc.pdf"))
result = list(chunk_pdf(pdf, max_pages=4, overlap_pages=1))
# With overlap, we get more chunks
assert len(result) >= 3
def test_chunk_page_counts(self) -> None:
"""Test that each chunk has correct page count."""
pdf_bytes = create_test_pdf(num_pages=7)
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="doc.pdf"))
result = list(chunk_pdf(pdf, max_pages=3))
page_counts = [get_pdf_page_count(chunk) for chunk in result]
assert page_counts == [3, 3, 1]
class TestChunkText:
"""Tests for chunk_text function."""
def test_chunk_splits_large_text(self) -> None:
"""Test that large text files are split into chunks."""
content = "Hello world. " * 100
text = TextFile(source=content.encode(), filename="large.txt")
result = list(chunk_text(text, max_chars=200, overlap_chars=0))
assert len(result) > 1
assert all(isinstance(chunk, TextFile) for chunk in result)
def test_no_chunk_if_within_limit(self) -> None:
"""Test that small text files are returned unchanged."""
content = "Short text"
text = TextFile(source=content.encode(), filename="small.txt")
result = list(chunk_text(text, max_chars=1000, overlap_chars=0))
assert len(result) == 1
assert result[0] is text
def test_chunk_filenames(self) -> None:
"""Test that chunked files have indexed filenames."""
content = "A" * 500
text = TextFile(source=FileBytes(data=content.encode(), filename="data.txt"))
result = list(chunk_text(text, max_chars=200, overlap_chars=0))
assert result[0].filename == "data_chunk_0.txt"
assert result[1].filename == "data_chunk_1.txt"
assert len(result) == 3
def test_chunk_preserves_extension(self) -> None:
"""Test that file extension is preserved in chunks."""
content = "A" * 500
text = TextFile(source=FileBytes(data=content.encode(), filename="script.py"))
result = list(chunk_text(text, max_chars=200, overlap_chars=0))
assert all(chunk.filename.endswith(".py") for chunk in result)
def test_chunk_prefers_newline_boundaries(self) -> None:
"""Test that chunking prefers to split at newlines."""
content = "Line one\nLine two\nLine three\nLine four\nLine five"
text = TextFile(source=content.encode(), filename="lines.txt")
result = list(chunk_text(text, max_chars=25, overlap_chars=0, split_on_newlines=True))
# Should split at newline boundaries
for chunk in result:
chunk_text_content = chunk.read().decode()
# Chunks should end at newlines (except possibly the last)
if chunk != result[-1]:
assert chunk_text_content.endswith("\n") or len(chunk_text_content) <= 25
def test_chunk_with_overlap(self) -> None:
"""Test chunking with overlapping characters."""
content = "ABCDEFGHIJ" * 10
text = TextFile(source=content.encode(), filename="data.txt")
result = list(chunk_text(text, max_chars=30, overlap_chars=5))
# With overlap, chunks should share some content
assert len(result) >= 3
def test_chunk_overlap_larger_than_max_chars(self) -> None:
"""Test that overlap > max_chars doesn't cause infinite loop."""
content = "A" * 100
text = TextFile(source=content.encode(), filename="data.txt")
# overlap_chars > max_chars should still work (just with max overlap)
result = list(chunk_text(text, max_chars=20, overlap_chars=50))
assert len(result) > 1
# Should still complete without hanging
class TestGetImageDimensions:
"""Tests for get_image_dimensions function."""
def test_get_dimensions(self) -> None:
"""Test getting image dimensions."""
png_bytes = create_test_png(150, 100)
img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))
dims = get_image_dimensions(img)
assert dims == (150, 100)
def test_returns_none_for_invalid_image(self) -> None:
"""Test that None is returned for invalid image data."""
img = ImageFile(source=FileBytes(data=b"not an image", filename="bad.png"))
dims = get_image_dimensions(img)
assert dims is None
def test_returns_none_without_pillow(self) -> None:
"""Test that None is returned when Pillow is not installed."""
png_bytes = create_test_png(100, 100)
img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))
with patch.dict("sys.modules", {"PIL": None}):
# Can't easily test this without unloading module
# Just verify the function handles the case gracefully
pass
class TestGetPdfPageCount:
"""Tests for get_pdf_page_count function."""
def test_get_page_count(self) -> None:
"""Test getting PDF page count."""
pdf_bytes = create_test_pdf(num_pages=5)
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="test.pdf"))
count = get_pdf_page_count(pdf)
assert count == 5
def test_single_page(self) -> None:
"""Test page count for single page PDF."""
pdf_bytes = create_test_pdf(num_pages=1)
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="single.pdf"))
count = get_pdf_page_count(pdf)
assert count == 1
def test_returns_none_for_invalid_pdf(self) -> None:
"""Test that None is returned for invalid PDF data."""
pdf = PDFFile(source=FileBytes(data=b"not a pdf", filename="bad.pdf"))
count = get_pdf_page_count(pdf)
assert count is None

View File

@@ -1,575 +0,0 @@
"""Tests for file validators."""
from unittest.mock import patch
import pytest
from crewai.files import AudioFile, FileBytes, ImageFile, PDFFile, TextFile, VideoFile
from crewai.files.processing.constraints import (
ANTHROPIC_CONSTRAINTS,
AudioConstraints,
ImageConstraints,
PDFConstraints,
ProviderConstraints,
VideoConstraints,
)
from crewai.files.processing.exceptions import (
FileTooLargeError,
FileValidationError,
UnsupportedFileTypeError,
)
from crewai.files.processing.validators import (
_get_audio_duration,
_get_video_duration,
validate_audio,
validate_file,
validate_image,
validate_pdf,
validate_text,
validate_video,
)
# Minimal valid PNG: 8x8 pixel RGB image (valid for PIL)
MINIMAL_PNG = bytes([
0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, 0x00, 0x00, 0x00, 0x0d,
0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08,
0x08, 0x02, 0x00, 0x00, 0x00, 0x4b, 0x6d, 0x29, 0xdc, 0x00, 0x00, 0x00,
0x12, 0x49, 0x44, 0x41, 0x54, 0x78, 0x9c, 0x63, 0xfc, 0xcf, 0x80, 0x1d,
0x30, 0xe1, 0x10, 0x1f, 0xa4, 0x12, 0x00, 0xcd, 0x41, 0x01, 0x0f, 0xe8,
0x41, 0xe2, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, 0x44, 0xae,
0x42, 0x60, 0x82,
])
# Minimal valid PDF
MINIMAL_PDF = (
b"%PDF-1.4\n1 0 obj<</Type/Catalog/Pages 2 0 R>>endobj "
b"2 0 obj<</Type/Pages/Kids[3 0 R]/Count 1>>endobj "
b"3 0 obj<</Type/Page/MediaBox[0 0 612 792]/Parent 2 0 R>>endobj "
b"xref\n0 4\n0000000000 65535 f \n0000000009 00000 n \n"
b"0000000052 00000 n \n0000000101 00000 n \n"
b"trailer<</Size 4/Root 1 0 R>>\nstartxref\n178\n%%EOF"
)
class TestValidateImage:
"""Tests for validate_image function."""
def test_validate_valid_image(self):
"""Test validating a valid image within constraints."""
constraints = ImageConstraints(
max_size_bytes=10 * 1024 * 1024,
supported_formats=("image/png",),
)
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
errors = validate_image(file, constraints, raise_on_error=False)
assert len(errors) == 0
def test_validate_image_too_large(self):
"""Test validating an image that exceeds size limit."""
constraints = ImageConstraints(
max_size_bytes=10, # Very small limit
supported_formats=("image/png",),
)
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
with pytest.raises(FileTooLargeError) as exc_info:
validate_image(file, constraints)
assert "exceeds" in str(exc_info.value)
assert exc_info.value.file_name == "test.png"
def test_validate_image_unsupported_format(self):
"""Test validating an image with unsupported format."""
constraints = ImageConstraints(
max_size_bytes=10 * 1024 * 1024,
supported_formats=("image/jpeg",), # Only JPEG
)
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
with pytest.raises(UnsupportedFileTypeError) as exc_info:
validate_image(file, constraints)
assert "not supported" in str(exc_info.value)
def test_validate_image_no_raise(self):
"""Test validating with raise_on_error=False returns errors list."""
constraints = ImageConstraints(
max_size_bytes=10,
supported_formats=("image/jpeg",),
)
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
errors = validate_image(file, constraints, raise_on_error=False)
assert len(errors) == 2 # Size error and format error
class TestValidatePDF:
"""Tests for validate_pdf function."""
def test_validate_valid_pdf(self):
"""Test validating a valid PDF within constraints."""
constraints = PDFConstraints(
max_size_bytes=10 * 1024 * 1024,
)
file = PDFFile(source=FileBytes(data=MINIMAL_PDF, filename="test.pdf"))
errors = validate_pdf(file, constraints, raise_on_error=False)
assert len(errors) == 0
def test_validate_pdf_too_large(self):
"""Test validating a PDF that exceeds size limit."""
constraints = PDFConstraints(
max_size_bytes=10, # Very small limit
)
file = PDFFile(source=FileBytes(data=MINIMAL_PDF, filename="test.pdf"))
with pytest.raises(FileTooLargeError) as exc_info:
validate_pdf(file, constraints)
assert "exceeds" in str(exc_info.value)
class TestValidateText:
"""Tests for validate_text function."""
def test_validate_valid_text(self):
"""Test validating a valid text file."""
constraints = ProviderConstraints(
name="test",
general_max_size_bytes=10 * 1024 * 1024,
)
file = TextFile(source=FileBytes(data=b"Hello, World!", filename="test.txt"))
errors = validate_text(file, constraints, raise_on_error=False)
assert len(errors) == 0
def test_validate_text_too_large(self):
"""Test validating text that exceeds size limit."""
constraints = ProviderConstraints(
name="test",
general_max_size_bytes=5,
)
file = TextFile(source=FileBytes(data=b"Hello, World!", filename="test.txt"))
with pytest.raises(FileTooLargeError):
validate_text(file, constraints)
def test_validate_text_no_limit(self):
"""Test validating text with no size limit."""
constraints = ProviderConstraints(name="test")
file = TextFile(source=FileBytes(data=b"Hello, World!", filename="test.txt"))
errors = validate_text(file, constraints, raise_on_error=False)
assert len(errors) == 0
class TestValidateFile:
"""Tests for validate_file function."""
def test_validate_file_dispatches_to_image(self):
"""Test validate_file dispatches to image validator."""
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
errors = validate_file(file, ANTHROPIC_CONSTRAINTS, raise_on_error=False)
assert len(errors) == 0
def test_validate_file_dispatches_to_pdf(self):
"""Test validate_file dispatches to PDF validator."""
file = PDFFile(source=FileBytes(data=MINIMAL_PDF, filename="test.pdf"))
errors = validate_file(file, ANTHROPIC_CONSTRAINTS, raise_on_error=False)
assert len(errors) == 0
def test_validate_file_unsupported_type(self):
"""Test validating a file type not supported by provider."""
constraints = ProviderConstraints(
name="test",
image=None, # No image support
)
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
with pytest.raises(UnsupportedFileTypeError) as exc_info:
validate_file(file, constraints)
assert "does not support images" in str(exc_info.value)
def test_validate_file_pdf_not_supported(self):
"""Test validating PDF when provider doesn't support it."""
constraints = ProviderConstraints(
name="test",
pdf=None, # No PDF support
)
file = PDFFile(source=FileBytes(data=MINIMAL_PDF, filename="test.pdf"))
with pytest.raises(UnsupportedFileTypeError) as exc_info:
validate_file(file, constraints)
assert "does not support PDFs" in str(exc_info.value)
# Minimal audio bytes for testing (not a valid audio file, used for mocked tests)
MINIMAL_AUDIO = b"\x00" * 100
# Minimal video bytes for testing (not a valid video file, used for mocked tests)
MINIMAL_VIDEO = b"\x00" * 100
# Fallback content type when python-magic cannot detect
FALLBACK_CONTENT_TYPE = "application/octet-stream"
class TestValidateAudio:
"""Tests for validate_audio function and audio duration validation."""
def test_validate_valid_audio(self):
"""Test validating a valid audio file within constraints."""
constraints = AudioConstraints(
max_size_bytes=10 * 1024 * 1024,
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
)
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
errors = validate_audio(file, constraints, raise_on_error=False)
assert len(errors) == 0
def test_validate_audio_too_large(self):
"""Test validating an audio file that exceeds size limit."""
constraints = AudioConstraints(
max_size_bytes=10, # Very small limit
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
)
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
with pytest.raises(FileTooLargeError) as exc_info:
validate_audio(file, constraints)
assert "exceeds" in str(exc_info.value)
assert exc_info.value.file_name == "test.mp3"
def test_validate_audio_unsupported_format(self):
"""Test validating an audio file with unsupported format."""
constraints = AudioConstraints(
max_size_bytes=10 * 1024 * 1024,
supported_formats=("audio/wav",), # Only WAV
)
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
with pytest.raises(UnsupportedFileTypeError) as exc_info:
validate_audio(file, constraints)
assert "not supported" in str(exc_info.value)
@patch("crewai.files.processing.validators._get_audio_duration")
def test_validate_audio_duration_passes(self, mock_get_duration):
"""Test validating audio when duration is under limit."""
mock_get_duration.return_value = 30.0
constraints = AudioConstraints(
max_size_bytes=10 * 1024 * 1024,
max_duration_seconds=60,
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
)
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
errors = validate_audio(file, constraints, raise_on_error=False)
assert len(errors) == 0
mock_get_duration.assert_called_once()
@patch("crewai.files.processing.validators._get_audio_duration")
def test_validate_audio_duration_fails(self, mock_get_duration):
"""Test validating audio when duration exceeds limit."""
mock_get_duration.return_value = 120.5
constraints = AudioConstraints(
max_size_bytes=10 * 1024 * 1024,
max_duration_seconds=60,
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
)
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
with pytest.raises(FileValidationError) as exc_info:
validate_audio(file, constraints)
assert "duration" in str(exc_info.value).lower()
assert "120.5s" in str(exc_info.value)
assert "60s" in str(exc_info.value)
@patch("crewai.files.processing.validators._get_audio_duration")
def test_validate_audio_duration_no_raise(self, mock_get_duration):
"""Test audio duration validation with raise_on_error=False."""
mock_get_duration.return_value = 120.5
constraints = AudioConstraints(
max_size_bytes=10 * 1024 * 1024,
max_duration_seconds=60,
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
)
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
errors = validate_audio(file, constraints, raise_on_error=False)
assert len(errors) == 1
assert "duration" in errors[0].lower()
@patch("crewai.files.processing.validators._get_audio_duration")
def test_validate_audio_duration_none_skips(self, mock_get_duration):
"""Test that duration validation is skipped when max_duration_seconds is None."""
constraints = AudioConstraints(
max_size_bytes=10 * 1024 * 1024,
max_duration_seconds=None,
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
)
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
errors = validate_audio(file, constraints, raise_on_error=False)
assert len(errors) == 0
mock_get_duration.assert_not_called()
@patch("crewai.files.processing.validators._get_audio_duration")
def test_validate_audio_duration_detection_returns_none(self, mock_get_duration):
"""Test that validation passes when duration detection returns None."""
mock_get_duration.return_value = None
constraints = AudioConstraints(
max_size_bytes=10 * 1024 * 1024,
max_duration_seconds=60,
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
)
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
errors = validate_audio(file, constraints, raise_on_error=False)
assert len(errors) == 0
class TestValidateVideo:
"""Tests for validate_video function and video duration validation."""
def test_validate_valid_video(self):
"""Test validating a valid video file within constraints."""
constraints = VideoConstraints(
max_size_bytes=10 * 1024 * 1024,
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
)
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
errors = validate_video(file, constraints, raise_on_error=False)
assert len(errors) == 0
def test_validate_video_too_large(self):
"""Test validating a video file that exceeds size limit."""
constraints = VideoConstraints(
max_size_bytes=10, # Very small limit
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
)
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
with pytest.raises(FileTooLargeError) as exc_info:
validate_video(file, constraints)
assert "exceeds" in str(exc_info.value)
assert exc_info.value.file_name == "test.mp4"
def test_validate_video_unsupported_format(self):
"""Test validating a video file with unsupported format."""
constraints = VideoConstraints(
max_size_bytes=10 * 1024 * 1024,
supported_formats=("video/webm",), # Only WebM
)
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
with pytest.raises(UnsupportedFileTypeError) as exc_info:
validate_video(file, constraints)
assert "not supported" in str(exc_info.value)
@patch("crewai.files.processing.validators._get_video_duration")
def test_validate_video_duration_passes(self, mock_get_duration):
"""Test validating video when duration is under limit."""
mock_get_duration.return_value = 30.0
constraints = VideoConstraints(
max_size_bytes=10 * 1024 * 1024,
max_duration_seconds=60,
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
)
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
errors = validate_video(file, constraints, raise_on_error=False)
assert len(errors) == 0
mock_get_duration.assert_called_once()
@patch("crewai.files.processing.validators._get_video_duration")
def test_validate_video_duration_fails(self, mock_get_duration):
"""Test validating video when duration exceeds limit."""
mock_get_duration.return_value = 180.0
constraints = VideoConstraints(
max_size_bytes=10 * 1024 * 1024,
max_duration_seconds=60,
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
)
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
with pytest.raises(FileValidationError) as exc_info:
validate_video(file, constraints)
assert "duration" in str(exc_info.value).lower()
assert "180.0s" in str(exc_info.value)
assert "60s" in str(exc_info.value)
@patch("crewai.files.processing.validators._get_video_duration")
def test_validate_video_duration_no_raise(self, mock_get_duration):
"""Test video duration validation with raise_on_error=False."""
mock_get_duration.return_value = 180.0
constraints = VideoConstraints(
max_size_bytes=10 * 1024 * 1024,
max_duration_seconds=60,
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
)
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
errors = validate_video(file, constraints, raise_on_error=False)
assert len(errors) == 1
assert "duration" in errors[0].lower()
@patch("crewai.files.processing.validators._get_video_duration")
def test_validate_video_duration_none_skips(self, mock_get_duration):
"""Test that duration validation is skipped when max_duration_seconds is None."""
constraints = VideoConstraints(
max_size_bytes=10 * 1024 * 1024,
max_duration_seconds=None,
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
)
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
errors = validate_video(file, constraints, raise_on_error=False)
assert len(errors) == 0
mock_get_duration.assert_not_called()
@patch("crewai.files.processing.validators._get_video_duration")
def test_validate_video_duration_detection_returns_none(self, mock_get_duration):
"""Test that validation passes when duration detection returns None."""
mock_get_duration.return_value = None
constraints = VideoConstraints(
max_size_bytes=10 * 1024 * 1024,
max_duration_seconds=60,
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
)
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
errors = validate_video(file, constraints, raise_on_error=False)
assert len(errors) == 0
class TestGetAudioDuration:
"""Tests for _get_audio_duration helper function."""
def test_get_audio_duration_corrupt_file(self):
"""Test handling of corrupt audio data."""
corrupt_data = b"not valid audio data at all"
result = _get_audio_duration(corrupt_data)
assert result is None
class TestGetVideoDuration:
"""Tests for _get_video_duration helper function."""
def test_get_video_duration_corrupt_file(self):
"""Test handling of corrupt video data."""
corrupt_data = b"not valid video data at all"
result = _get_video_duration(corrupt_data)
assert result is None
class TestRealVideoFile:
"""Tests using real video fixture file."""
@pytest.fixture
def sample_video_path(self):
"""Path to sample video fixture."""
from pathlib import Path
path = Path(__file__).parent.parent.parent / "fixtures" / "sample_video.mp4"
if not path.exists():
pytest.skip("sample_video.mp4 fixture not found")
return path
@pytest.fixture
def sample_video_content(self, sample_video_path):
"""Read sample video content."""
return sample_video_path.read_bytes()
def test_get_video_duration_real_file(self, sample_video_content):
"""Test duration detection with real video file."""
try:
import av # noqa: F401
except ImportError:
pytest.skip("PyAV not installed")
duration = _get_video_duration(sample_video_content, "video/mp4")
assert duration is not None
assert 4.5 <= duration <= 5.5 # ~5 seconds with tolerance
def test_get_video_duration_real_file_no_format_hint(self, sample_video_content):
"""Test duration detection without format hint."""
try:
import av # noqa: F401
except ImportError:
pytest.skip("PyAV not installed")
duration = _get_video_duration(sample_video_content)
assert duration is not None
assert 4.5 <= duration <= 5.5
def test_validate_video_real_file_passes(self, sample_video_path):
"""Test validating real video file within constraints."""
try:
import av # noqa: F401
except ImportError:
pytest.skip("PyAV not installed")
constraints = VideoConstraints(
max_size_bytes=10 * 1024 * 1024,
max_duration_seconds=60,
supported_formats=("video/mp4",),
)
file = VideoFile(source=str(sample_video_path))
errors = validate_video(file, constraints, raise_on_error=False)
assert len(errors) == 0
def test_validate_video_real_file_duration_exceeded(self, sample_video_path):
"""Test validating real video file that exceeds duration limit."""
try:
import av # noqa: F401
except ImportError:
pytest.skip("PyAV not installed")
constraints = VideoConstraints(
max_size_bytes=10 * 1024 * 1024,
max_duration_seconds=2, # Video is ~5 seconds
supported_formats=("video/mp4",),
)
file = VideoFile(source=str(sample_video_path))
with pytest.raises(FileValidationError) as exc_info:
validate_video(file, constraints)
assert "duration" in str(exc_info.value).lower()
assert "2s" in str(exc_info.value)

View File

@@ -1,312 +0,0 @@
"""Tests for FileUrl source type and URL resolution."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from crewai.files import FileBytes, FileUrl, ImageFile
from crewai.files.file import _normalize_source, FilePath
from crewai.files.resolved import InlineBase64, UrlReference
from crewai.files.resolver import FileResolver
class TestFileUrl:
"""Tests for FileUrl source type."""
def test_create_file_url(self):
"""Test creating FileUrl with valid URL."""
url = FileUrl(url="https://example.com/image.png")
assert url.url == "https://example.com/image.png"
assert url.filename is None
def test_create_file_url_with_filename(self):
"""Test creating FileUrl with custom filename."""
url = FileUrl(url="https://example.com/image.png", filename="custom.png")
assert url.url == "https://example.com/image.png"
assert url.filename == "custom.png"
def test_invalid_url_scheme_raises(self):
"""Test that non-http(s) URLs raise ValueError."""
with pytest.raises(ValueError, match="Invalid URL scheme"):
FileUrl(url="ftp://example.com/file.txt")
def test_invalid_url_scheme_file_raises(self):
"""Test that file:// URLs raise ValueError."""
with pytest.raises(ValueError, match="Invalid URL scheme"):
FileUrl(url="file:///path/to/file.txt")
def test_http_url_valid(self):
"""Test that HTTP URLs are valid."""
url = FileUrl(url="http://example.com/image.jpg")
assert url.url == "http://example.com/image.jpg"
def test_https_url_valid(self):
"""Test that HTTPS URLs are valid."""
url = FileUrl(url="https://example.com/image.jpg")
assert url.url == "https://example.com/image.jpg"
def test_content_type_guessing_png(self):
"""Test content type guessing for PNG files."""
url = FileUrl(url="https://example.com/image.png")
assert url.content_type == "image/png"
def test_content_type_guessing_jpeg(self):
"""Test content type guessing for JPEG files."""
url = FileUrl(url="https://example.com/photo.jpg")
assert url.content_type == "image/jpeg"
def test_content_type_guessing_pdf(self):
"""Test content type guessing for PDF files."""
url = FileUrl(url="https://example.com/document.pdf")
assert url.content_type == "application/pdf"
def test_content_type_guessing_with_query_params(self):
"""Test content type guessing with URL query parameters."""
url = FileUrl(url="https://example.com/image.png?v=123&token=abc")
assert url.content_type == "image/png"
def test_content_type_fallback_unknown(self):
"""Test content type falls back to octet-stream for unknown extensions."""
url = FileUrl(url="https://example.com/file.unknownext123")
assert url.content_type == "application/octet-stream"
def test_content_type_no_extension(self):
"""Test content type for URL without extension."""
url = FileUrl(url="https://example.com/file")
assert url.content_type == "application/octet-stream"
def test_read_fetches_content(self):
"""Test that read() fetches content from URL."""
url = FileUrl(url="https://example.com/image.png")
mock_response = MagicMock()
mock_response.content = b"fake image content"
mock_response.headers = {"content-type": "image/png"}
with patch("httpx.get", return_value=mock_response) as mock_get:
content = url.read()
mock_get.assert_called_once_with(
"https://example.com/image.png", follow_redirects=True
)
assert content == b"fake image content"
def test_read_caches_content(self):
"""Test that read() caches content."""
url = FileUrl(url="https://example.com/image.png")
mock_response = MagicMock()
mock_response.content = b"fake content"
mock_response.headers = {}
with patch("httpx.get", return_value=mock_response) as mock_get:
content1 = url.read()
content2 = url.read()
mock_get.assert_called_once()
assert content1 == content2
def test_read_updates_content_type_from_response(self):
"""Test that read() updates content type from response headers."""
url = FileUrl(url="https://example.com/file")
mock_response = MagicMock()
mock_response.content = b"fake content"
mock_response.headers = {"content-type": "image/webp; charset=utf-8"}
with patch("httpx.get", return_value=mock_response):
url.read()
assert url.content_type == "image/webp"
@pytest.mark.asyncio
async def test_aread_fetches_content(self):
"""Test that aread() fetches content from URL asynchronously."""
url = FileUrl(url="https://example.com/image.png")
mock_response = MagicMock()
mock_response.content = b"async fake content"
mock_response.headers = {"content-type": "image/png"}
mock_response.raise_for_status = MagicMock()
mock_client = MagicMock()
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
with patch("httpx.AsyncClient", return_value=mock_client):
content = await url.aread()
assert content == b"async fake content"
@pytest.mark.asyncio
async def test_aread_caches_content(self):
"""Test that aread() caches content."""
url = FileUrl(url="https://example.com/image.png")
mock_response = MagicMock()
mock_response.content = b"cached content"
mock_response.headers = {}
mock_response.raise_for_status = MagicMock()
mock_client = MagicMock()
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
with patch("httpx.AsyncClient", return_value=mock_client):
content1 = await url.aread()
content2 = await url.aread()
mock_client.get.assert_called_once()
assert content1 == content2
class TestNormalizeSource:
"""Tests for _normalize_source with URL detection."""
def test_normalize_url_string(self):
"""Test that URL strings are converted to FileUrl."""
result = _normalize_source("https://example.com/image.png")
assert isinstance(result, FileUrl)
assert result.url == "https://example.com/image.png"
def test_normalize_http_url_string(self):
"""Test that HTTP URL strings are converted to FileUrl."""
result = _normalize_source("http://example.com/file.pdf")
assert isinstance(result, FileUrl)
assert result.url == "http://example.com/file.pdf"
def test_normalize_file_path_string(self, tmp_path):
"""Test that file path strings are converted to FilePath."""
test_file = tmp_path / "test.png"
test_file.write_bytes(b"test content")
result = _normalize_source(str(test_file))
assert isinstance(result, FilePath)
def test_normalize_relative_path_is_not_url(self):
"""Test that relative path strings are not treated as URLs."""
result = _normalize_source("https://example.com/file.png")
assert isinstance(result, FileUrl)
assert not isinstance(result, FilePath)
def test_normalize_file_url_passthrough(self):
"""Test that FileUrl instances pass through unchanged."""
original = FileUrl(url="https://example.com/image.png")
result = _normalize_source(original)
assert result is original
class TestResolverUrlHandling:
"""Tests for FileResolver URL handling."""
def test_resolve_url_source_for_supported_provider(self):
"""Test URL source resolves to UrlReference for supported providers."""
resolver = FileResolver()
file = ImageFile(source=FileUrl(url="https://example.com/image.png"))
resolved = resolver.resolve(file, "anthropic")
assert isinstance(resolved, UrlReference)
assert resolved.url == "https://example.com/image.png"
assert resolved.content_type == "image/png"
def test_resolve_url_source_openai(self):
"""Test URL source resolves to UrlReference for OpenAI."""
resolver = FileResolver()
file = ImageFile(source=FileUrl(url="https://example.com/photo.jpg"))
resolved = resolver.resolve(file, "openai")
assert isinstance(resolved, UrlReference)
assert resolved.url == "https://example.com/photo.jpg"
def test_resolve_url_source_gemini(self):
"""Test URL source resolves to UrlReference for Gemini."""
resolver = FileResolver()
file = ImageFile(source=FileUrl(url="https://example.com/image.webp"))
resolved = resolver.resolve(file, "gemini")
assert isinstance(resolved, UrlReference)
assert resolved.url == "https://example.com/image.webp"
def test_resolve_url_source_azure(self):
"""Test URL source resolves to UrlReference for Azure."""
resolver = FileResolver()
file = ImageFile(source=FileUrl(url="https://example.com/image.gif"))
resolved = resolver.resolve(file, "azure")
assert isinstance(resolved, UrlReference)
assert resolved.url == "https://example.com/image.gif"
def test_resolve_url_source_bedrock_fetches_content(self):
"""Test URL source fetches content for Bedrock (unsupported URLs)."""
resolver = FileResolver()
file_url = FileUrl(url="https://example.com/image.png")
file = ImageFile(source=file_url)
mock_response = MagicMock()
mock_response.content = b"\x89PNG\r\n\x1a\n" + b"\x00" * 50
mock_response.headers = {"content-type": "image/png"}
with patch("httpx.get", return_value=mock_response):
resolved = resolver.resolve(file, "bedrock")
assert not isinstance(resolved, UrlReference)
def test_resolve_bytes_source_still_works(self):
"""Test that bytes source still resolves normally."""
resolver = FileResolver()
minimal_png = (
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
)
file = ImageFile(source=FileBytes(data=minimal_png, filename="test.png"))
resolved = resolver.resolve(file, "anthropic")
assert isinstance(resolved, InlineBase64)
@pytest.mark.asyncio
async def test_aresolve_url_source(self):
"""Test async URL resolution for supported provider."""
resolver = FileResolver()
file = ImageFile(source=FileUrl(url="https://example.com/image.png"))
resolved = await resolver.aresolve(file, "anthropic")
assert isinstance(resolved, UrlReference)
assert resolved.url == "https://example.com/image.png"
class TestImageFileWithUrl:
"""Tests for creating ImageFile with URL source."""
def test_image_file_from_url_string(self):
"""Test creating ImageFile from URL string."""
file = ImageFile(source="https://example.com/image.png")
assert isinstance(file.source, FileUrl)
assert file.source.url == "https://example.com/image.png"
def test_image_file_from_file_url(self):
"""Test creating ImageFile from FileUrl instance."""
url = FileUrl(url="https://example.com/photo.jpg")
file = ImageFile(source=url)
assert file.source is url
assert file.content_type == "image/jpeg"

View File

@@ -1,135 +0,0 @@
"""Tests for resolved file types."""
from datetime import datetime, timezone
import pytest
from crewai.files.resolved import (
FileReference,
InlineBase64,
InlineBytes,
ResolvedFile,
UrlReference,
)
class TestInlineBase64:
"""Tests for InlineBase64 resolved type."""
def test_create_inline_base64(self):
"""Test creating InlineBase64 instance."""
resolved = InlineBase64(
content_type="image/png",
data="iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
)
assert resolved.content_type == "image/png"
assert len(resolved.data) > 0
def test_inline_base64_is_resolved_file(self):
"""Test InlineBase64 is a ResolvedFile."""
resolved = InlineBase64(content_type="image/png", data="abc123")
assert isinstance(resolved, ResolvedFile)
def test_inline_base64_frozen(self):
"""Test InlineBase64 is immutable."""
resolved = InlineBase64(content_type="image/png", data="abc123")
with pytest.raises(Exception):
resolved.data = "xyz789"
class TestInlineBytes:
"""Tests for InlineBytes resolved type."""
def test_create_inline_bytes(self):
"""Test creating InlineBytes instance."""
data = b"\x89PNG\r\n\x1a\n"
resolved = InlineBytes(
content_type="image/png",
data=data,
)
assert resolved.content_type == "image/png"
assert resolved.data == data
def test_inline_bytes_is_resolved_file(self):
"""Test InlineBytes is a ResolvedFile."""
resolved = InlineBytes(content_type="image/png", data=b"test")
assert isinstance(resolved, ResolvedFile)
class TestFileReference:
"""Tests for FileReference resolved type."""
def test_create_file_reference(self):
"""Test creating FileReference instance."""
resolved = FileReference(
content_type="image/png",
file_id="file-abc123",
provider="gemini",
)
assert resolved.content_type == "image/png"
assert resolved.file_id == "file-abc123"
assert resolved.provider == "gemini"
assert resolved.expires_at is None
assert resolved.file_uri is None
def test_file_reference_with_expiry(self):
"""Test FileReference with expiry time."""
expiry = datetime.now(timezone.utc)
resolved = FileReference(
content_type="application/pdf",
file_id="file-xyz789",
provider="gemini",
expires_at=expiry,
)
assert resolved.expires_at == expiry
def test_file_reference_with_uri(self):
"""Test FileReference with URI."""
resolved = FileReference(
content_type="video/mp4",
file_id="file-video123",
provider="gemini",
file_uri="https://generativelanguage.googleapis.com/v1/files/file-video123",
)
assert resolved.file_uri is not None
def test_file_reference_is_resolved_file(self):
"""Test FileReference is a ResolvedFile."""
resolved = FileReference(
content_type="image/png",
file_id="file-123",
provider="anthropic",
)
assert isinstance(resolved, ResolvedFile)
class TestUrlReference:
"""Tests for UrlReference resolved type."""
def test_create_url_reference(self):
"""Test creating UrlReference instance."""
resolved = UrlReference(
content_type="image/png",
url="https://storage.googleapis.com/bucket/image.png",
)
assert resolved.content_type == "image/png"
assert resolved.url == "https://storage.googleapis.com/bucket/image.png"
def test_url_reference_is_resolved_file(self):
"""Test UrlReference is a ResolvedFile."""
resolved = UrlReference(
content_type="image/jpeg",
url="https://example.com/photo.jpg",
)
assert isinstance(resolved, ResolvedFile)

View File

@@ -1,174 +0,0 @@
"""Tests for FileResolver."""
import pytest
from crewai.files import FileBytes, ImageFile
from crewai.files.resolved import InlineBase64, InlineBytes
from crewai.files.resolver import (
FileResolver,
FileResolverConfig,
create_resolver,
)
from crewai.files.upload_cache import UploadCache
# Minimal valid PNG
MINIMAL_PNG = (
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
)
class TestFileResolverConfig:
"""Tests for FileResolverConfig."""
def test_default_config(self):
"""Test default configuration values."""
config = FileResolverConfig()
assert config.prefer_upload is False
assert config.upload_threshold_bytes is None
assert config.use_bytes_for_bedrock is True
def test_custom_config(self):
"""Test custom configuration values."""
config = FileResolverConfig(
prefer_upload=True,
upload_threshold_bytes=1024 * 1024,
use_bytes_for_bedrock=False,
)
assert config.prefer_upload is True
assert config.upload_threshold_bytes == 1024 * 1024
assert config.use_bytes_for_bedrock is False
class TestFileResolver:
"""Tests for FileResolver class."""
def test_resolve_inline_base64(self):
"""Test resolving file as inline base64."""
resolver = FileResolver()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
resolved = resolver.resolve(file, "openai")
assert isinstance(resolved, InlineBase64)
assert resolved.content_type == "image/png"
assert len(resolved.data) > 0
def test_resolve_inline_bytes_for_bedrock(self):
"""Test resolving file as inline bytes for Bedrock."""
config = FileResolverConfig(use_bytes_for_bedrock=True)
resolver = FileResolver(config=config)
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
resolved = resolver.resolve(file, "bedrock")
assert isinstance(resolved, InlineBytes)
assert resolved.content_type == "image/png"
assert resolved.data == MINIMAL_PNG
def test_resolve_files_multiple(self):
"""Test resolving multiple files."""
resolver = FileResolver()
files = {
"image1": ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png")),
"image2": ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test2.png")),
}
resolved = resolver.resolve_files(files, "openai")
assert len(resolved) == 2
assert "image1" in resolved
assert "image2" in resolved
assert all(isinstance(r, InlineBase64) for r in resolved.values())
def test_resolve_with_cache(self):
"""Test resolver uses cache."""
cache = UploadCache()
resolver = FileResolver(upload_cache=cache)
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
# First resolution
resolved1 = resolver.resolve(file, "openai")
# Second resolution (should use same base64 encoding)
resolved2 = resolver.resolve(file, "openai")
assert isinstance(resolved1, InlineBase64)
assert isinstance(resolved2, InlineBase64)
# Data should be identical
assert resolved1.data == resolved2.data
def test_clear_cache(self):
"""Test clearing resolver cache."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
# Add something to cache manually
cache.set(file=file, provider="gemini", file_id="test")
resolver = FileResolver(upload_cache=cache)
resolver.clear_cache()
assert len(cache) == 0
def test_get_cached_uploads(self):
"""Test getting cached uploads from resolver."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
cache.set(file=file, provider="gemini", file_id="test-1")
cache.set(file=file, provider="anthropic", file_id="test-2")
resolver = FileResolver(upload_cache=cache)
gemini_uploads = resolver.get_cached_uploads("gemini")
anthropic_uploads = resolver.get_cached_uploads("anthropic")
assert len(gemini_uploads) == 1
assert len(anthropic_uploads) == 1
def test_get_cached_uploads_empty(self):
"""Test getting cached uploads when no cache."""
resolver = FileResolver() # No cache
uploads = resolver.get_cached_uploads("gemini")
assert uploads == []
class TestCreateResolver:
"""Tests for create_resolver factory function."""
def test_create_default_resolver(self):
"""Test creating resolver with default settings."""
resolver = create_resolver()
assert resolver.config.prefer_upload is False
assert resolver.upload_cache is not None
def test_create_resolver_with_options(self):
"""Test creating resolver with custom options."""
resolver = create_resolver(
prefer_upload=True,
upload_threshold_bytes=5 * 1024 * 1024,
enable_cache=False,
)
assert resolver.config.prefer_upload is True
assert resolver.config.upload_threshold_bytes == 5 * 1024 * 1024
assert resolver.upload_cache is None
def test_create_resolver_cache_enabled(self):
"""Test resolver has cache when enabled."""
resolver = create_resolver(enable_cache=True)
assert resolver.upload_cache is not None
def test_create_resolver_cache_disabled(self):
"""Test resolver has no cache when disabled."""
resolver = create_resolver(enable_cache=False)
assert resolver.upload_cache is None

View File

@@ -1,206 +0,0 @@
"""Tests for upload cache."""
from datetime import datetime, timedelta, timezone
import pytest
from crewai.files import FileBytes, ImageFile
from crewai.files.upload_cache import CachedUpload, UploadCache
# Minimal valid PNG
MINIMAL_PNG = (
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
)
class TestCachedUpload:
"""Tests for CachedUpload dataclass."""
def test_cached_upload_creation(self):
"""Test creating a cached upload."""
now = datetime.now(timezone.utc)
cached = CachedUpload(
file_id="file-123",
provider="gemini",
file_uri="files/file-123",
content_type="image/png",
uploaded_at=now,
expires_at=now + timedelta(hours=48),
)
assert cached.file_id == "file-123"
assert cached.provider == "gemini"
assert cached.file_uri == "files/file-123"
assert cached.content_type == "image/png"
def test_is_expired_false(self):
"""Test is_expired returns False for non-expired upload."""
future = datetime.now(timezone.utc) + timedelta(hours=24)
cached = CachedUpload(
file_id="file-123",
provider="gemini",
file_uri=None,
content_type="image/png",
uploaded_at=datetime.now(timezone.utc),
expires_at=future,
)
assert cached.is_expired() is False
def test_is_expired_true(self):
"""Test is_expired returns True for expired upload."""
past = datetime.now(timezone.utc) - timedelta(hours=1)
cached = CachedUpload(
file_id="file-123",
provider="gemini",
file_uri=None,
content_type="image/png",
uploaded_at=datetime.now(timezone.utc) - timedelta(hours=2),
expires_at=past,
)
assert cached.is_expired() is True
def test_is_expired_no_expiry(self):
"""Test is_expired returns False when no expiry set."""
cached = CachedUpload(
file_id="file-123",
provider="anthropic",
file_uri=None,
content_type="image/png",
uploaded_at=datetime.now(timezone.utc),
expires_at=None,
)
assert cached.is_expired() is False
class TestUploadCache:
"""Tests for UploadCache class."""
def test_cache_creation(self):
"""Test creating an empty cache."""
cache = UploadCache()
assert len(cache) == 0
def test_set_and_get(self):
"""Test setting and getting cached uploads."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
cached = cache.set(
file=file,
provider="gemini",
file_id="file-123",
file_uri="files/file-123",
)
result = cache.get(file, "gemini")
assert result is not None
assert result.file_id == "file-123"
assert result.provider == "gemini"
def test_get_missing(self):
"""Test getting non-existent entry returns None."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
result = cache.get(file, "gemini")
assert result is None
def test_get_different_provider(self):
"""Test getting with different provider returns None."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
cache.set(file=file, provider="gemini", file_id="file-123")
result = cache.get(file, "anthropic") # Different provider
assert result is None
def test_remove(self):
"""Test removing cached entry."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
cache.set(file=file, provider="gemini", file_id="file-123")
removed = cache.remove(file, "gemini")
assert removed is True
assert cache.get(file, "gemini") is None
def test_remove_missing(self):
"""Test removing non-existent entry returns False."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
removed = cache.remove(file, "gemini")
assert removed is False
def test_remove_by_file_id(self):
"""Test removing by file ID."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
cache.set(file=file, provider="gemini", file_id="file-123")
removed = cache.remove_by_file_id("file-123", "gemini")
assert removed is True
assert len(cache) == 0
def test_clear_expired(self):
"""Test clearing expired entries."""
cache = UploadCache()
file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png"))
# Add one expired and one valid entry
past = datetime.now(timezone.utc) - timedelta(hours=1)
future = datetime.now(timezone.utc) + timedelta(hours=24)
cache.set(file=file1, provider="gemini", file_id="expired", expires_at=past)
cache.set(file=file2, provider="gemini", file_id="valid", expires_at=future)
removed = cache.clear_expired()
assert removed == 1
assert len(cache) == 1
assert cache.get(file2, "gemini") is not None
def test_clear(self):
"""Test clearing all entries."""
cache = UploadCache()
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
cache.set(file=file, provider="gemini", file_id="file-123")
cache.set(file=file, provider="anthropic", file_id="file-456")
cleared = cache.clear()
assert cleared == 2
assert len(cache) == 0
def test_get_all_for_provider(self):
"""Test getting all cached uploads for a provider."""
cache = UploadCache()
file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png"))
file3 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"xx", filename="test3.png"))
cache.set(file=file1, provider="gemini", file_id="file-1")
cache.set(file=file2, provider="gemini", file_id="file-2")
cache.set(file=file3, provider="anthropic", file_id="file-3")
gemini_uploads = cache.get_all_for_provider("gemini")
anthropic_uploads = cache.get_all_for_provider("anthropic")
assert len(gemini_uploads) == 2
assert len(anthropic_uploads) == 1