mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 15:52:34 +00:00
refactor: extract files module to standalone crewai-files package
This commit is contained in:
@@ -1 +0,0 @@
|
||||
"""Tests for file processing utilities."""
|
||||
@@ -1 +0,0 @@
|
||||
"""Tests for file processing module."""
|
||||
@@ -1,226 +0,0 @@
|
||||
"""Tests for provider constraints."""
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.files.processing.constraints import (
|
||||
ANTHROPIC_CONSTRAINTS,
|
||||
BEDROCK_CONSTRAINTS,
|
||||
GEMINI_CONSTRAINTS,
|
||||
OPENAI_CONSTRAINTS,
|
||||
AudioConstraints,
|
||||
ImageConstraints,
|
||||
PDFConstraints,
|
||||
ProviderConstraints,
|
||||
VideoConstraints,
|
||||
get_constraints_for_provider,
|
||||
)
|
||||
|
||||
|
||||
class TestImageConstraints:
|
||||
"""Tests for ImageConstraints dataclass."""
|
||||
|
||||
def test_image_constraints_creation(self):
|
||||
"""Test creating image constraints with all fields."""
|
||||
constraints = ImageConstraints(
|
||||
max_size_bytes=5 * 1024 * 1024,
|
||||
max_width=8000,
|
||||
max_height=8000,
|
||||
max_images_per_request=10,
|
||||
)
|
||||
|
||||
assert constraints.max_size_bytes == 5 * 1024 * 1024
|
||||
assert constraints.max_width == 8000
|
||||
assert constraints.max_height == 8000
|
||||
assert constraints.max_images_per_request == 10
|
||||
|
||||
def test_image_constraints_defaults(self):
|
||||
"""Test image constraints with default values."""
|
||||
constraints = ImageConstraints(max_size_bytes=1000)
|
||||
|
||||
assert constraints.max_size_bytes == 1000
|
||||
assert constraints.max_width is None
|
||||
assert constraints.max_height is None
|
||||
assert constraints.max_images_per_request is None
|
||||
assert "image/png" in constraints.supported_formats
|
||||
|
||||
def test_image_constraints_frozen(self):
|
||||
"""Test that image constraints are immutable."""
|
||||
constraints = ImageConstraints(max_size_bytes=1000)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
constraints.max_size_bytes = 2000
|
||||
|
||||
|
||||
class TestPDFConstraints:
|
||||
"""Tests for PDFConstraints dataclass."""
|
||||
|
||||
def test_pdf_constraints_creation(self):
|
||||
"""Test creating PDF constraints."""
|
||||
constraints = PDFConstraints(
|
||||
max_size_bytes=30 * 1024 * 1024,
|
||||
max_pages=100,
|
||||
)
|
||||
|
||||
assert constraints.max_size_bytes == 30 * 1024 * 1024
|
||||
assert constraints.max_pages == 100
|
||||
|
||||
def test_pdf_constraints_defaults(self):
|
||||
"""Test PDF constraints with default values."""
|
||||
constraints = PDFConstraints(max_size_bytes=1000)
|
||||
|
||||
assert constraints.max_size_bytes == 1000
|
||||
assert constraints.max_pages is None
|
||||
|
||||
|
||||
class TestAudioConstraints:
|
||||
"""Tests for AudioConstraints dataclass."""
|
||||
|
||||
def test_audio_constraints_creation(self):
|
||||
"""Test creating audio constraints."""
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=100 * 1024 * 1024,
|
||||
max_duration_seconds=3600,
|
||||
)
|
||||
|
||||
assert constraints.max_size_bytes == 100 * 1024 * 1024
|
||||
assert constraints.max_duration_seconds == 3600
|
||||
assert "audio/mp3" in constraints.supported_formats
|
||||
|
||||
|
||||
class TestVideoConstraints:
|
||||
"""Tests for VideoConstraints dataclass."""
|
||||
|
||||
def test_video_constraints_creation(self):
|
||||
"""Test creating video constraints."""
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=2 * 1024 * 1024 * 1024,
|
||||
max_duration_seconds=7200,
|
||||
)
|
||||
|
||||
assert constraints.max_size_bytes == 2 * 1024 * 1024 * 1024
|
||||
assert constraints.max_duration_seconds == 7200
|
||||
assert "video/mp4" in constraints.supported_formats
|
||||
|
||||
|
||||
class TestProviderConstraints:
|
||||
"""Tests for ProviderConstraints dataclass."""
|
||||
|
||||
def test_provider_constraints_creation(self):
|
||||
"""Test creating full provider constraints."""
|
||||
constraints = ProviderConstraints(
|
||||
name="test-provider",
|
||||
image=ImageConstraints(max_size_bytes=5 * 1024 * 1024),
|
||||
pdf=PDFConstraints(max_size_bytes=30 * 1024 * 1024),
|
||||
supports_file_upload=True,
|
||||
file_upload_threshold_bytes=10 * 1024 * 1024,
|
||||
)
|
||||
|
||||
assert constraints.name == "test-provider"
|
||||
assert constraints.image is not None
|
||||
assert constraints.pdf is not None
|
||||
assert constraints.supports_file_upload is True
|
||||
|
||||
def test_provider_constraints_defaults(self):
|
||||
"""Test provider constraints with default values."""
|
||||
constraints = ProviderConstraints(name="test")
|
||||
|
||||
assert constraints.name == "test"
|
||||
assert constraints.image is None
|
||||
assert constraints.pdf is None
|
||||
assert constraints.audio is None
|
||||
assert constraints.video is None
|
||||
assert constraints.supports_file_upload is False
|
||||
|
||||
|
||||
class TestPredefinedConstraints:
|
||||
"""Tests for predefined provider constraints."""
|
||||
|
||||
def test_anthropic_constraints(self):
|
||||
"""Test Anthropic constraints are properly defined."""
|
||||
assert ANTHROPIC_CONSTRAINTS.name == "anthropic"
|
||||
assert ANTHROPIC_CONSTRAINTS.image is not None
|
||||
assert ANTHROPIC_CONSTRAINTS.image.max_size_bytes == 5 * 1024 * 1024
|
||||
assert ANTHROPIC_CONSTRAINTS.image.max_width == 8000
|
||||
assert ANTHROPIC_CONSTRAINTS.pdf is not None
|
||||
assert ANTHROPIC_CONSTRAINTS.pdf.max_pages == 100
|
||||
assert ANTHROPIC_CONSTRAINTS.supports_file_upload is True
|
||||
|
||||
def test_openai_constraints(self):
|
||||
"""Test OpenAI constraints are properly defined."""
|
||||
assert OPENAI_CONSTRAINTS.name == "openai"
|
||||
assert OPENAI_CONSTRAINTS.image is not None
|
||||
assert OPENAI_CONSTRAINTS.image.max_size_bytes == 20 * 1024 * 1024
|
||||
assert OPENAI_CONSTRAINTS.pdf is None # OpenAI doesn't support PDFs
|
||||
|
||||
def test_gemini_constraints(self):
|
||||
"""Test Gemini constraints are properly defined."""
|
||||
assert GEMINI_CONSTRAINTS.name == "gemini"
|
||||
assert GEMINI_CONSTRAINTS.image is not None
|
||||
assert GEMINI_CONSTRAINTS.pdf is not None
|
||||
assert GEMINI_CONSTRAINTS.audio is not None
|
||||
assert GEMINI_CONSTRAINTS.video is not None
|
||||
assert GEMINI_CONSTRAINTS.supports_file_upload is True
|
||||
|
||||
def test_bedrock_constraints(self):
|
||||
"""Test Bedrock constraints are properly defined."""
|
||||
assert BEDROCK_CONSTRAINTS.name == "bedrock"
|
||||
assert BEDROCK_CONSTRAINTS.image is not None
|
||||
assert BEDROCK_CONSTRAINTS.image.max_size_bytes == 4_608_000
|
||||
assert BEDROCK_CONSTRAINTS.pdf is not None
|
||||
assert BEDROCK_CONSTRAINTS.supports_file_upload is False
|
||||
|
||||
|
||||
class TestGetConstraintsForProvider:
|
||||
"""Tests for get_constraints_for_provider function."""
|
||||
|
||||
def test_get_by_exact_name(self):
|
||||
"""Test getting constraints by exact provider name."""
|
||||
result = get_constraints_for_provider("anthropic")
|
||||
assert result == ANTHROPIC_CONSTRAINTS
|
||||
|
||||
result = get_constraints_for_provider("openai")
|
||||
assert result == OPENAI_CONSTRAINTS
|
||||
|
||||
result = get_constraints_for_provider("gemini")
|
||||
assert result == GEMINI_CONSTRAINTS
|
||||
|
||||
def test_get_by_alias(self):
|
||||
"""Test getting constraints by alias name."""
|
||||
result = get_constraints_for_provider("claude")
|
||||
assert result == ANTHROPIC_CONSTRAINTS
|
||||
|
||||
result = get_constraints_for_provider("gpt")
|
||||
assert result == OPENAI_CONSTRAINTS
|
||||
|
||||
result = get_constraints_for_provider("google")
|
||||
assert result == GEMINI_CONSTRAINTS
|
||||
|
||||
def test_get_case_insensitive(self):
|
||||
"""Test case-insensitive lookup."""
|
||||
result = get_constraints_for_provider("ANTHROPIC")
|
||||
assert result == ANTHROPIC_CONSTRAINTS
|
||||
|
||||
result = get_constraints_for_provider("OpenAI")
|
||||
assert result == OPENAI_CONSTRAINTS
|
||||
|
||||
def test_get_with_provider_constraints_object(self):
|
||||
"""Test passing ProviderConstraints object returns it unchanged."""
|
||||
custom = ProviderConstraints(name="custom")
|
||||
result = get_constraints_for_provider(custom)
|
||||
assert result is custom
|
||||
|
||||
def test_get_unknown_provider(self):
|
||||
"""Test unknown provider returns None."""
|
||||
result = get_constraints_for_provider("unknown-provider")
|
||||
assert result is None
|
||||
|
||||
def test_get_by_partial_match(self):
|
||||
"""Test partial match in provider string."""
|
||||
result = get_constraints_for_provider("claude-3-sonnet")
|
||||
assert result == ANTHROPIC_CONSTRAINTS
|
||||
|
||||
result = get_constraints_for_provider("gpt-4o")
|
||||
assert result == OPENAI_CONSTRAINTS
|
||||
|
||||
result = get_constraints_for_provider("gemini-pro")
|
||||
assert result == GEMINI_CONSTRAINTS
|
||||
@@ -1,220 +0,0 @@
|
||||
"""Tests for FileProcessor class."""
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.files import FileBytes, ImageFile, PDFFile, TextFile
|
||||
from crewai.files.processing.constraints import (
|
||||
ANTHROPIC_CONSTRAINTS,
|
||||
ImageConstraints,
|
||||
PDFConstraints,
|
||||
ProviderConstraints,
|
||||
)
|
||||
from crewai.files.processing.enums import FileHandling
|
||||
from crewai.files.processing.exceptions import (
|
||||
FileTooLargeError,
|
||||
FileValidationError,
|
||||
)
|
||||
from crewai.files.processing.processor import FileProcessor
|
||||
|
||||
|
||||
# Minimal valid PNG: 8x8 pixel RGB image (valid for PIL)
|
||||
MINIMAL_PNG = bytes([
|
||||
0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, 0x00, 0x00, 0x00, 0x0d,
|
||||
0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08,
|
||||
0x08, 0x02, 0x00, 0x00, 0x00, 0x4b, 0x6d, 0x29, 0xdc, 0x00, 0x00, 0x00,
|
||||
0x12, 0x49, 0x44, 0x41, 0x54, 0x78, 0x9c, 0x63, 0xfc, 0xcf, 0x80, 0x1d,
|
||||
0x30, 0xe1, 0x10, 0x1f, 0xa4, 0x12, 0x00, 0xcd, 0x41, 0x01, 0x0f, 0xe8,
|
||||
0x41, 0xe2, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, 0x44, 0xae,
|
||||
0x42, 0x60, 0x82,
|
||||
])
|
||||
|
||||
# Minimal valid PDF
|
||||
MINIMAL_PDF = (
|
||||
b"%PDF-1.4\n1 0 obj<</Type/Catalog/Pages 2 0 R>>endobj "
|
||||
b"2 0 obj<</Type/Pages/Kids[3 0 R]/Count 1>>endobj "
|
||||
b"3 0 obj<</Type/Page/MediaBox[0 0 612 792]/Parent 2 0 R>>endobj "
|
||||
b"xref\n0 4\n0000000000 65535 f \n0000000009 00000 n \n"
|
||||
b"0000000052 00000 n \n0000000101 00000 n \n"
|
||||
b"trailer<</Size 4/Root 1 0 R>>\nstartxref\n178\n%%EOF"
|
||||
)
|
||||
|
||||
|
||||
class TestFileProcessorInit:
|
||||
"""Tests for FileProcessor initialization."""
|
||||
|
||||
def test_init_with_constraints(self):
|
||||
"""Test initialization with ProviderConstraints."""
|
||||
processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)
|
||||
|
||||
assert processor.constraints == ANTHROPIC_CONSTRAINTS
|
||||
|
||||
def test_init_with_provider_string(self):
|
||||
"""Test initialization with provider name string."""
|
||||
processor = FileProcessor(constraints="anthropic")
|
||||
|
||||
assert processor.constraints == ANTHROPIC_CONSTRAINTS
|
||||
|
||||
def test_init_with_unknown_provider(self):
|
||||
"""Test initialization with unknown provider sets constraints to None."""
|
||||
processor = FileProcessor(constraints="unknown")
|
||||
|
||||
assert processor.constraints is None
|
||||
|
||||
def test_init_with_none_constraints(self):
|
||||
"""Test initialization with None constraints."""
|
||||
processor = FileProcessor(constraints=None)
|
||||
|
||||
assert processor.constraints is None
|
||||
|
||||
|
||||
class TestFileProcessorValidate:
|
||||
"""Tests for FileProcessor.validate method."""
|
||||
|
||||
def test_validate_valid_file(self):
|
||||
"""Test validating a valid file returns no errors."""
|
||||
processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
errors = processor.validate(file)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_without_constraints(self):
|
||||
"""Test validating without constraints returns empty list."""
|
||||
processor = FileProcessor(constraints=None)
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
errors = processor.validate(file)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_strict_raises_on_error(self):
|
||||
"""Test STRICT mode raises on validation error."""
|
||||
constraints = ProviderConstraints(
|
||||
name="test",
|
||||
image=ImageConstraints(max_size_bytes=10),
|
||||
)
|
||||
processor = FileProcessor(constraints=constraints)
|
||||
# Set mode to strict on the file
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict")
|
||||
|
||||
with pytest.raises(FileTooLargeError):
|
||||
processor.validate(file)
|
||||
|
||||
|
||||
class TestFileProcessorProcess:
|
||||
"""Tests for FileProcessor.process method."""
|
||||
|
||||
def test_process_valid_file(self):
|
||||
"""Test processing a valid file returns it unchanged."""
|
||||
processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
result = processor.process(file)
|
||||
|
||||
assert result == file
|
||||
|
||||
def test_process_without_constraints(self):
|
||||
"""Test processing without constraints returns file unchanged."""
|
||||
processor = FileProcessor(constraints=None)
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
result = processor.process(file)
|
||||
|
||||
assert result == file
|
||||
|
||||
def test_process_strict_raises_on_error(self):
|
||||
"""Test STRICT mode raises on processing error."""
|
||||
constraints = ProviderConstraints(
|
||||
name="test",
|
||||
image=ImageConstraints(max_size_bytes=10),
|
||||
)
|
||||
processor = FileProcessor(constraints=constraints)
|
||||
# Set mode to strict on the file
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict")
|
||||
|
||||
with pytest.raises(FileTooLargeError):
|
||||
processor.process(file)
|
||||
|
||||
def test_process_warn_returns_file(self):
|
||||
"""Test WARN mode returns file with warning."""
|
||||
constraints = ProviderConstraints(
|
||||
name="test",
|
||||
image=ImageConstraints(max_size_bytes=10),
|
||||
)
|
||||
processor = FileProcessor(constraints=constraints)
|
||||
# Set mode to warn on the file
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="warn")
|
||||
|
||||
result = processor.process(file)
|
||||
|
||||
assert result == file
|
||||
|
||||
|
||||
class TestFileProcessorProcessFiles:
|
||||
"""Tests for FileProcessor.process_files method."""
|
||||
|
||||
def test_process_files_multiple(self):
|
||||
"""Test processing multiple files."""
|
||||
processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)
|
||||
files = {
|
||||
"image1": ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png")),
|
||||
"image2": ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test2.png")),
|
||||
}
|
||||
|
||||
result = processor.process_files(files)
|
||||
|
||||
assert len(result) == 2
|
||||
assert "image1" in result
|
||||
assert "image2" in result
|
||||
|
||||
def test_process_files_empty(self):
|
||||
"""Test processing empty files dict."""
|
||||
processor = FileProcessor(constraints=ANTHROPIC_CONSTRAINTS)
|
||||
|
||||
result = processor.process_files({})
|
||||
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestFileHandlingEnum:
|
||||
"""Tests for FileHandling enum."""
|
||||
|
||||
def test_enum_values(self):
|
||||
"""Test all enum values are accessible."""
|
||||
assert FileHandling.STRICT.value == "strict"
|
||||
assert FileHandling.AUTO.value == "auto"
|
||||
assert FileHandling.WARN.value == "warn"
|
||||
assert FileHandling.CHUNK.value == "chunk"
|
||||
|
||||
|
||||
class TestFileProcessorPerFileMode:
|
||||
"""Tests for per-file mode handling."""
|
||||
|
||||
def test_file_default_mode_is_auto(self):
|
||||
"""Test that files default to auto mode."""
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
assert file.mode == "auto"
|
||||
|
||||
def test_file_custom_mode(self):
|
||||
"""Test setting custom mode on file."""
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict")
|
||||
assert file.mode == "strict"
|
||||
|
||||
def test_processor_respects_file_mode(self):
|
||||
"""Test processor uses each file's mode setting."""
|
||||
constraints = ProviderConstraints(
|
||||
name="test",
|
||||
image=ImageConstraints(max_size_bytes=10),
|
||||
)
|
||||
processor = FileProcessor(constraints=constraints)
|
||||
|
||||
# File with strict mode should raise
|
||||
strict_file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="strict")
|
||||
with pytest.raises(FileTooLargeError):
|
||||
processor.process(strict_file)
|
||||
|
||||
# File with warn mode should not raise
|
||||
warn_file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"), mode="warn")
|
||||
result = processor.process(warn_file)
|
||||
assert result == warn_file
|
||||
@@ -1,359 +0,0 @@
|
||||
"""Unit tests for file transformers."""
|
||||
|
||||
import io
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.files import ImageFile, PDFFile, TextFile
|
||||
from crewai.files.file import FileBytes
|
||||
from crewai.files.processing.exceptions import ProcessingDependencyError
|
||||
from crewai.files.processing.transformers import (
|
||||
chunk_pdf,
|
||||
chunk_text,
|
||||
get_image_dimensions,
|
||||
get_pdf_page_count,
|
||||
optimize_image,
|
||||
resize_image,
|
||||
)
|
||||
|
||||
|
||||
def create_test_png(width: int = 100, height: int = 100) -> bytes:
|
||||
"""Create a minimal valid PNG for testing."""
|
||||
from PIL import Image
|
||||
|
||||
img = Image.new("RGB", (width, height), color="red")
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format="PNG")
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
def create_test_pdf(num_pages: int = 1) -> bytes:
|
||||
"""Create a minimal valid PDF for testing."""
|
||||
from pypdf import PdfWriter
|
||||
|
||||
writer = PdfWriter()
|
||||
for _ in range(num_pages):
|
||||
writer.add_blank_page(width=612, height=792)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
writer.write(buffer)
|
||||
return buffer.getvalue()
|
||||
|
||||
|
||||
class TestResizeImage:
|
||||
"""Tests for resize_image function."""
|
||||
|
||||
def test_resize_larger_image(self) -> None:
|
||||
"""Test resizing an image larger than max dimensions."""
|
||||
png_bytes = create_test_png(200, 150)
|
||||
img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))
|
||||
|
||||
result = resize_image(img, max_width=100, max_height=100)
|
||||
|
||||
dims = get_image_dimensions(result)
|
||||
assert dims is not None
|
||||
width, height = dims
|
||||
assert width <= 100
|
||||
assert height <= 100
|
||||
|
||||
def test_no_resize_if_within_bounds(self) -> None:
|
||||
"""Test that small images are returned unchanged."""
|
||||
png_bytes = create_test_png(50, 50)
|
||||
img = ImageFile(source=FileBytes(data=png_bytes, filename="small.png"))
|
||||
|
||||
result = resize_image(img, max_width=100, max_height=100)
|
||||
|
||||
assert result is img
|
||||
|
||||
def test_preserve_aspect_ratio(self) -> None:
|
||||
"""Test that aspect ratio is preserved during resize."""
|
||||
png_bytes = create_test_png(200, 100)
|
||||
img = ImageFile(source=FileBytes(data=png_bytes, filename="wide.png"))
|
||||
|
||||
result = resize_image(img, max_width=100, max_height=100)
|
||||
|
||||
dims = get_image_dimensions(result)
|
||||
assert dims is not None
|
||||
width, height = dims
|
||||
assert width == 100
|
||||
assert height == 50
|
||||
|
||||
def test_resize_without_aspect_ratio(self) -> None:
|
||||
"""Test resizing without preserving aspect ratio."""
|
||||
png_bytes = create_test_png(200, 100)
|
||||
img = ImageFile(source=FileBytes(data=png_bytes, filename="wide.png"))
|
||||
|
||||
result = resize_image(
|
||||
img, max_width=50, max_height=50, preserve_aspect_ratio=False
|
||||
)
|
||||
|
||||
dims = get_image_dimensions(result)
|
||||
assert dims is not None
|
||||
width, height = dims
|
||||
assert width == 50
|
||||
assert height == 50
|
||||
|
||||
def test_resize_returns_image_file(self) -> None:
|
||||
"""Test that resize returns an ImageFile instance."""
|
||||
png_bytes = create_test_png(200, 200)
|
||||
img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))
|
||||
|
||||
result = resize_image(img, max_width=100, max_height=100)
|
||||
|
||||
assert isinstance(result, ImageFile)
|
||||
|
||||
def test_raises_without_pillow(self) -> None:
|
||||
"""Test that ProcessingDependencyError is raised without Pillow."""
|
||||
img = ImageFile(source=FileBytes(data=b"fake", filename="test.png"))
|
||||
|
||||
with patch.dict("sys.modules", {"PIL": None, "PIL.Image": None}):
|
||||
with pytest.raises(ProcessingDependencyError) as exc_info:
|
||||
# Force reimport to trigger ImportError
|
||||
import importlib
|
||||
|
||||
import crewai.files.processing.transformers as t
|
||||
|
||||
importlib.reload(t)
|
||||
t.resize_image(img, 100, 100)
|
||||
|
||||
assert "Pillow" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestOptimizeImage:
|
||||
"""Tests for optimize_image function."""
|
||||
|
||||
def test_optimize_reduces_size(self) -> None:
|
||||
"""Test that optimization reduces file size."""
|
||||
png_bytes = create_test_png(500, 500)
|
||||
original_size = len(png_bytes)
|
||||
img = ImageFile(source=FileBytes(data=png_bytes, filename="large.png"))
|
||||
|
||||
result = optimize_image(img, target_size_bytes=original_size // 2)
|
||||
|
||||
result_size = len(result.read())
|
||||
assert result_size < original_size
|
||||
|
||||
def test_no_optimize_if_under_target(self) -> None:
|
||||
"""Test that small images are returned unchanged."""
|
||||
png_bytes = create_test_png(50, 50)
|
||||
img = ImageFile(source=FileBytes(data=png_bytes, filename="small.png"))
|
||||
|
||||
result = optimize_image(img, target_size_bytes=1024 * 1024)
|
||||
|
||||
assert result is img
|
||||
|
||||
def test_optimize_returns_image_file(self) -> None:
|
||||
"""Test that optimize returns an ImageFile instance."""
|
||||
png_bytes = create_test_png(200, 200)
|
||||
img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))
|
||||
|
||||
result = optimize_image(img, target_size_bytes=100)
|
||||
|
||||
assert isinstance(result, ImageFile)
|
||||
|
||||
def test_optimize_respects_min_quality(self) -> None:
|
||||
"""Test that optimization stops at minimum quality."""
|
||||
png_bytes = create_test_png(100, 100)
|
||||
img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))
|
||||
|
||||
# Request impossibly small size - should stop at min quality
|
||||
result = optimize_image(img, target_size_bytes=10, min_quality=50)
|
||||
|
||||
assert isinstance(result, ImageFile)
|
||||
assert len(result.read()) > 10
|
||||
|
||||
|
||||
class TestChunkPdf:
|
||||
"""Tests for chunk_pdf function."""
|
||||
|
||||
def test_chunk_splits_large_pdf(self) -> None:
|
||||
"""Test that large PDFs are split into chunks."""
|
||||
pdf_bytes = create_test_pdf(num_pages=10)
|
||||
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="large.pdf"))
|
||||
|
||||
result = list(chunk_pdf(pdf, max_pages=3))
|
||||
|
||||
assert len(result) == 4
|
||||
assert all(isinstance(chunk, PDFFile) for chunk in result)
|
||||
|
||||
def test_no_chunk_if_within_limit(self) -> None:
|
||||
"""Test that small PDFs are returned unchanged."""
|
||||
pdf_bytes = create_test_pdf(num_pages=3)
|
||||
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="small.pdf"))
|
||||
|
||||
result = list(chunk_pdf(pdf, max_pages=5))
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0] is pdf
|
||||
|
||||
def test_chunk_filenames(self) -> None:
|
||||
"""Test that chunked files have indexed filenames."""
|
||||
pdf_bytes = create_test_pdf(num_pages=6)
|
||||
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="document.pdf"))
|
||||
|
||||
result = list(chunk_pdf(pdf, max_pages=2))
|
||||
|
||||
assert result[0].filename == "document_chunk_0.pdf"
|
||||
assert result[1].filename == "document_chunk_1.pdf"
|
||||
assert result[2].filename == "document_chunk_2.pdf"
|
||||
|
||||
def test_chunk_with_overlap(self) -> None:
|
||||
"""Test chunking with overlapping pages."""
|
||||
pdf_bytes = create_test_pdf(num_pages=10)
|
||||
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="doc.pdf"))
|
||||
|
||||
result = list(chunk_pdf(pdf, max_pages=4, overlap_pages=1))
|
||||
|
||||
# With overlap, we get more chunks
|
||||
assert len(result) >= 3
|
||||
|
||||
def test_chunk_page_counts(self) -> None:
|
||||
"""Test that each chunk has correct page count."""
|
||||
pdf_bytes = create_test_pdf(num_pages=7)
|
||||
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="doc.pdf"))
|
||||
|
||||
result = list(chunk_pdf(pdf, max_pages=3))
|
||||
|
||||
page_counts = [get_pdf_page_count(chunk) for chunk in result]
|
||||
assert page_counts == [3, 3, 1]
|
||||
|
||||
|
||||
class TestChunkText:
|
||||
"""Tests for chunk_text function."""
|
||||
|
||||
def test_chunk_splits_large_text(self) -> None:
|
||||
"""Test that large text files are split into chunks."""
|
||||
content = "Hello world. " * 100
|
||||
text = TextFile(source=content.encode(), filename="large.txt")
|
||||
|
||||
result = list(chunk_text(text, max_chars=200, overlap_chars=0))
|
||||
|
||||
assert len(result) > 1
|
||||
assert all(isinstance(chunk, TextFile) for chunk in result)
|
||||
|
||||
def test_no_chunk_if_within_limit(self) -> None:
|
||||
"""Test that small text files are returned unchanged."""
|
||||
content = "Short text"
|
||||
text = TextFile(source=content.encode(), filename="small.txt")
|
||||
|
||||
result = list(chunk_text(text, max_chars=1000, overlap_chars=0))
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0] is text
|
||||
|
||||
def test_chunk_filenames(self) -> None:
|
||||
"""Test that chunked files have indexed filenames."""
|
||||
content = "A" * 500
|
||||
text = TextFile(source=FileBytes(data=content.encode(), filename="data.txt"))
|
||||
|
||||
result = list(chunk_text(text, max_chars=200, overlap_chars=0))
|
||||
|
||||
assert result[0].filename == "data_chunk_0.txt"
|
||||
assert result[1].filename == "data_chunk_1.txt"
|
||||
assert len(result) == 3
|
||||
|
||||
def test_chunk_preserves_extension(self) -> None:
|
||||
"""Test that file extension is preserved in chunks."""
|
||||
content = "A" * 500
|
||||
text = TextFile(source=FileBytes(data=content.encode(), filename="script.py"))
|
||||
|
||||
result = list(chunk_text(text, max_chars=200, overlap_chars=0))
|
||||
|
||||
assert all(chunk.filename.endswith(".py") for chunk in result)
|
||||
|
||||
def test_chunk_prefers_newline_boundaries(self) -> None:
|
||||
"""Test that chunking prefers to split at newlines."""
|
||||
content = "Line one\nLine two\nLine three\nLine four\nLine five"
|
||||
text = TextFile(source=content.encode(), filename="lines.txt")
|
||||
|
||||
result = list(chunk_text(text, max_chars=25, overlap_chars=0, split_on_newlines=True))
|
||||
|
||||
# Should split at newline boundaries
|
||||
for chunk in result:
|
||||
chunk_text_content = chunk.read().decode()
|
||||
# Chunks should end at newlines (except possibly the last)
|
||||
if chunk != result[-1]:
|
||||
assert chunk_text_content.endswith("\n") or len(chunk_text_content) <= 25
|
||||
|
||||
def test_chunk_with_overlap(self) -> None:
|
||||
"""Test chunking with overlapping characters."""
|
||||
content = "ABCDEFGHIJ" * 10
|
||||
text = TextFile(source=content.encode(), filename="data.txt")
|
||||
|
||||
result = list(chunk_text(text, max_chars=30, overlap_chars=5))
|
||||
|
||||
# With overlap, chunks should share some content
|
||||
assert len(result) >= 3
|
||||
|
||||
def test_chunk_overlap_larger_than_max_chars(self) -> None:
|
||||
"""Test that overlap > max_chars doesn't cause infinite loop."""
|
||||
content = "A" * 100
|
||||
text = TextFile(source=content.encode(), filename="data.txt")
|
||||
|
||||
# overlap_chars > max_chars should still work (just with max overlap)
|
||||
result = list(chunk_text(text, max_chars=20, overlap_chars=50))
|
||||
|
||||
assert len(result) > 1
|
||||
# Should still complete without hanging
|
||||
|
||||
|
||||
class TestGetImageDimensions:
|
||||
"""Tests for get_image_dimensions function."""
|
||||
|
||||
def test_get_dimensions(self) -> None:
|
||||
"""Test getting image dimensions."""
|
||||
png_bytes = create_test_png(150, 100)
|
||||
img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))
|
||||
|
||||
dims = get_image_dimensions(img)
|
||||
|
||||
assert dims == (150, 100)
|
||||
|
||||
def test_returns_none_for_invalid_image(self) -> None:
|
||||
"""Test that None is returned for invalid image data."""
|
||||
img = ImageFile(source=FileBytes(data=b"not an image", filename="bad.png"))
|
||||
|
||||
dims = get_image_dimensions(img)
|
||||
|
||||
assert dims is None
|
||||
|
||||
def test_returns_none_without_pillow(self) -> None:
|
||||
"""Test that None is returned when Pillow is not installed."""
|
||||
png_bytes = create_test_png(100, 100)
|
||||
img = ImageFile(source=FileBytes(data=png_bytes, filename="test.png"))
|
||||
|
||||
with patch.dict("sys.modules", {"PIL": None}):
|
||||
# Can't easily test this without unloading module
|
||||
# Just verify the function handles the case gracefully
|
||||
pass
|
||||
|
||||
|
||||
class TestGetPdfPageCount:
|
||||
"""Tests for get_pdf_page_count function."""
|
||||
|
||||
def test_get_page_count(self) -> None:
|
||||
"""Test getting PDF page count."""
|
||||
pdf_bytes = create_test_pdf(num_pages=5)
|
||||
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="test.pdf"))
|
||||
|
||||
count = get_pdf_page_count(pdf)
|
||||
|
||||
assert count == 5
|
||||
|
||||
def test_single_page(self) -> None:
|
||||
"""Test page count for single page PDF."""
|
||||
pdf_bytes = create_test_pdf(num_pages=1)
|
||||
pdf = PDFFile(source=FileBytes(data=pdf_bytes, filename="single.pdf"))
|
||||
|
||||
count = get_pdf_page_count(pdf)
|
||||
|
||||
assert count == 1
|
||||
|
||||
def test_returns_none_for_invalid_pdf(self) -> None:
|
||||
"""Test that None is returned for invalid PDF data."""
|
||||
pdf = PDFFile(source=FileBytes(data=b"not a pdf", filename="bad.pdf"))
|
||||
|
||||
count = get_pdf_page_count(pdf)
|
||||
|
||||
assert count is None
|
||||
@@ -1,575 +0,0 @@
|
||||
"""Tests for file validators."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.files import AudioFile, FileBytes, ImageFile, PDFFile, TextFile, VideoFile
|
||||
from crewai.files.processing.constraints import (
|
||||
ANTHROPIC_CONSTRAINTS,
|
||||
AudioConstraints,
|
||||
ImageConstraints,
|
||||
PDFConstraints,
|
||||
ProviderConstraints,
|
||||
VideoConstraints,
|
||||
)
|
||||
from crewai.files.processing.exceptions import (
|
||||
FileTooLargeError,
|
||||
FileValidationError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from crewai.files.processing.validators import (
|
||||
_get_audio_duration,
|
||||
_get_video_duration,
|
||||
validate_audio,
|
||||
validate_file,
|
||||
validate_image,
|
||||
validate_pdf,
|
||||
validate_text,
|
||||
validate_video,
|
||||
)
|
||||
|
||||
|
||||
# Minimal valid PNG: 8x8 pixel RGB image (valid for PIL)
|
||||
MINIMAL_PNG = bytes([
|
||||
0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, 0x00, 0x00, 0x00, 0x0d,
|
||||
0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08,
|
||||
0x08, 0x02, 0x00, 0x00, 0x00, 0x4b, 0x6d, 0x29, 0xdc, 0x00, 0x00, 0x00,
|
||||
0x12, 0x49, 0x44, 0x41, 0x54, 0x78, 0x9c, 0x63, 0xfc, 0xcf, 0x80, 0x1d,
|
||||
0x30, 0xe1, 0x10, 0x1f, 0xa4, 0x12, 0x00, 0xcd, 0x41, 0x01, 0x0f, 0xe8,
|
||||
0x41, 0xe2, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, 0x44, 0xae,
|
||||
0x42, 0x60, 0x82,
|
||||
])
|
||||
|
||||
# Minimal valid PDF
|
||||
MINIMAL_PDF = (
|
||||
b"%PDF-1.4\n1 0 obj<</Type/Catalog/Pages 2 0 R>>endobj "
|
||||
b"2 0 obj<</Type/Pages/Kids[3 0 R]/Count 1>>endobj "
|
||||
b"3 0 obj<</Type/Page/MediaBox[0 0 612 792]/Parent 2 0 R>>endobj "
|
||||
b"xref\n0 4\n0000000000 65535 f \n0000000009 00000 n \n"
|
||||
b"0000000052 00000 n \n0000000101 00000 n \n"
|
||||
b"trailer<</Size 4/Root 1 0 R>>\nstartxref\n178\n%%EOF"
|
||||
)
|
||||
|
||||
|
||||
class TestValidateImage:
|
||||
"""Tests for validate_image function."""
|
||||
|
||||
def test_validate_valid_image(self):
|
||||
"""Test validating a valid image within constraints."""
|
||||
constraints = ImageConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
supported_formats=("image/png",),
|
||||
)
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
errors = validate_image(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_image_too_large(self):
|
||||
"""Test validating an image that exceeds size limit."""
|
||||
constraints = ImageConstraints(
|
||||
max_size_bytes=10, # Very small limit
|
||||
supported_formats=("image/png",),
|
||||
)
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
with pytest.raises(FileTooLargeError) as exc_info:
|
||||
validate_image(file, constraints)
|
||||
|
||||
assert "exceeds" in str(exc_info.value)
|
||||
assert exc_info.value.file_name == "test.png"
|
||||
|
||||
def test_validate_image_unsupported_format(self):
|
||||
"""Test validating an image with unsupported format."""
|
||||
constraints = ImageConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
supported_formats=("image/jpeg",), # Only JPEG
|
||||
)
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError) as exc_info:
|
||||
validate_image(file, constraints)
|
||||
|
||||
assert "not supported" in str(exc_info.value)
|
||||
|
||||
def test_validate_image_no_raise(self):
|
||||
"""Test validating with raise_on_error=False returns errors list."""
|
||||
constraints = ImageConstraints(
|
||||
max_size_bytes=10,
|
||||
supported_formats=("image/jpeg",),
|
||||
)
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
errors = validate_image(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 2 # Size error and format error
|
||||
|
||||
|
||||
class TestValidatePDF:
|
||||
"""Tests for validate_pdf function."""
|
||||
|
||||
def test_validate_valid_pdf(self):
|
||||
"""Test validating a valid PDF within constraints."""
|
||||
constraints = PDFConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
)
|
||||
file = PDFFile(source=FileBytes(data=MINIMAL_PDF, filename="test.pdf"))
|
||||
|
||||
errors = validate_pdf(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_pdf_too_large(self):
|
||||
"""Test validating a PDF that exceeds size limit."""
|
||||
constraints = PDFConstraints(
|
||||
max_size_bytes=10, # Very small limit
|
||||
)
|
||||
file = PDFFile(source=FileBytes(data=MINIMAL_PDF, filename="test.pdf"))
|
||||
|
||||
with pytest.raises(FileTooLargeError) as exc_info:
|
||||
validate_pdf(file, constraints)
|
||||
|
||||
assert "exceeds" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestValidateText:
|
||||
"""Tests for validate_text function."""
|
||||
|
||||
def test_validate_valid_text(self):
|
||||
"""Test validating a valid text file."""
|
||||
constraints = ProviderConstraints(
|
||||
name="test",
|
||||
general_max_size_bytes=10 * 1024 * 1024,
|
||||
)
|
||||
file = TextFile(source=FileBytes(data=b"Hello, World!", filename="test.txt"))
|
||||
|
||||
errors = validate_text(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_text_too_large(self):
|
||||
"""Test validating text that exceeds size limit."""
|
||||
constraints = ProviderConstraints(
|
||||
name="test",
|
||||
general_max_size_bytes=5,
|
||||
)
|
||||
file = TextFile(source=FileBytes(data=b"Hello, World!", filename="test.txt"))
|
||||
|
||||
with pytest.raises(FileTooLargeError):
|
||||
validate_text(file, constraints)
|
||||
|
||||
def test_validate_text_no_limit(self):
|
||||
"""Test validating text with no size limit."""
|
||||
constraints = ProviderConstraints(name="test")
|
||||
file = TextFile(source=FileBytes(data=b"Hello, World!", filename="test.txt"))
|
||||
|
||||
errors = validate_text(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
|
||||
class TestValidateFile:
|
||||
"""Tests for validate_file function."""
|
||||
|
||||
def test_validate_file_dispatches_to_image(self):
|
||||
"""Test validate_file dispatches to image validator."""
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
errors = validate_file(file, ANTHROPIC_CONSTRAINTS, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_file_dispatches_to_pdf(self):
|
||||
"""Test validate_file dispatches to PDF validator."""
|
||||
file = PDFFile(source=FileBytes(data=MINIMAL_PDF, filename="test.pdf"))
|
||||
|
||||
errors = validate_file(file, ANTHROPIC_CONSTRAINTS, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_file_unsupported_type(self):
|
||||
"""Test validating a file type not supported by provider."""
|
||||
constraints = ProviderConstraints(
|
||||
name="test",
|
||||
image=None, # No image support
|
||||
)
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError) as exc_info:
|
||||
validate_file(file, constraints)
|
||||
|
||||
assert "does not support images" in str(exc_info.value)
|
||||
|
||||
def test_validate_file_pdf_not_supported(self):
|
||||
"""Test validating PDF when provider doesn't support it."""
|
||||
constraints = ProviderConstraints(
|
||||
name="test",
|
||||
pdf=None, # No PDF support
|
||||
)
|
||||
file = PDFFile(source=FileBytes(data=MINIMAL_PDF, filename="test.pdf"))
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError) as exc_info:
|
||||
validate_file(file, constraints)
|
||||
|
||||
assert "does not support PDFs" in str(exc_info.value)
|
||||
|
||||
|
||||
# Minimal audio bytes for testing (not a valid audio file, used for mocked tests)
|
||||
MINIMAL_AUDIO = b"\x00" * 100
|
||||
|
||||
# Minimal video bytes for testing (not a valid video file, used for mocked tests)
|
||||
MINIMAL_VIDEO = b"\x00" * 100
|
||||
|
||||
# Fallback content type when python-magic cannot detect
|
||||
FALLBACK_CONTENT_TYPE = "application/octet-stream"
|
||||
|
||||
|
||||
class TestValidateAudio:
|
||||
"""Tests for validate_audio function and audio duration validation."""
|
||||
|
||||
def test_validate_valid_audio(self):
|
||||
"""Test validating a valid audio file within constraints."""
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
errors = validate_audio(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_audio_too_large(self):
|
||||
"""Test validating an audio file that exceeds size limit."""
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10, # Very small limit
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
with pytest.raises(FileTooLargeError) as exc_info:
|
||||
validate_audio(file, constraints)
|
||||
|
||||
assert "exceeds" in str(exc_info.value)
|
||||
assert exc_info.value.file_name == "test.mp3"
|
||||
|
||||
def test_validate_audio_unsupported_format(self):
|
||||
"""Test validating an audio file with unsupported format."""
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
supported_formats=("audio/wav",), # Only WAV
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError) as exc_info:
|
||||
validate_audio(file, constraints)
|
||||
|
||||
assert "not supported" in str(exc_info.value)
|
||||
|
||||
@patch("crewai.files.processing.validators._get_audio_duration")
|
||||
def test_validate_audio_duration_passes(self, mock_get_duration):
|
||||
"""Test validating audio when duration is under limit."""
|
||||
mock_get_duration.return_value = 30.0
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
errors = validate_audio(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
mock_get_duration.assert_called_once()
|
||||
|
||||
@patch("crewai.files.processing.validators._get_audio_duration")
|
||||
def test_validate_audio_duration_fails(self, mock_get_duration):
|
||||
"""Test validating audio when duration exceeds limit."""
|
||||
mock_get_duration.return_value = 120.5
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
with pytest.raises(FileValidationError) as exc_info:
|
||||
validate_audio(file, constraints)
|
||||
|
||||
assert "duration" in str(exc_info.value).lower()
|
||||
assert "120.5s" in str(exc_info.value)
|
||||
assert "60s" in str(exc_info.value)
|
||||
|
||||
@patch("crewai.files.processing.validators._get_audio_duration")
|
||||
def test_validate_audio_duration_no_raise(self, mock_get_duration):
|
||||
"""Test audio duration validation with raise_on_error=False."""
|
||||
mock_get_duration.return_value = 120.5
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
errors = validate_audio(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert "duration" in errors[0].lower()
|
||||
|
||||
@patch("crewai.files.processing.validators._get_audio_duration")
|
||||
def test_validate_audio_duration_none_skips(self, mock_get_duration):
|
||||
"""Test that duration validation is skipped when max_duration_seconds is None."""
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=None,
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
errors = validate_audio(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
mock_get_duration.assert_not_called()
|
||||
|
||||
@patch("crewai.files.processing.validators._get_audio_duration")
|
||||
def test_validate_audio_duration_detection_returns_none(self, mock_get_duration):
|
||||
"""Test that validation passes when duration detection returns None."""
|
||||
mock_get_duration.return_value = None
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
errors = validate_audio(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
|
||||
class TestValidateVideo:
|
||||
"""Tests for validate_video function and video duration validation."""
|
||||
|
||||
def test_validate_valid_video(self):
|
||||
"""Test validating a valid video file within constraints."""
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
errors = validate_video(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_video_too_large(self):
|
||||
"""Test validating a video file that exceeds size limit."""
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10, # Very small limit
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
with pytest.raises(FileTooLargeError) as exc_info:
|
||||
validate_video(file, constraints)
|
||||
|
||||
assert "exceeds" in str(exc_info.value)
|
||||
assert exc_info.value.file_name == "test.mp4"
|
||||
|
||||
def test_validate_video_unsupported_format(self):
|
||||
"""Test validating a video file with unsupported format."""
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
supported_formats=("video/webm",), # Only WebM
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError) as exc_info:
|
||||
validate_video(file, constraints)
|
||||
|
||||
assert "not supported" in str(exc_info.value)
|
||||
|
||||
@patch("crewai.files.processing.validators._get_video_duration")
|
||||
def test_validate_video_duration_passes(self, mock_get_duration):
|
||||
"""Test validating video when duration is under limit."""
|
||||
mock_get_duration.return_value = 30.0
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
errors = validate_video(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
mock_get_duration.assert_called_once()
|
||||
|
||||
@patch("crewai.files.processing.validators._get_video_duration")
|
||||
def test_validate_video_duration_fails(self, mock_get_duration):
|
||||
"""Test validating video when duration exceeds limit."""
|
||||
mock_get_duration.return_value = 180.0
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
with pytest.raises(FileValidationError) as exc_info:
|
||||
validate_video(file, constraints)
|
||||
|
||||
assert "duration" in str(exc_info.value).lower()
|
||||
assert "180.0s" in str(exc_info.value)
|
||||
assert "60s" in str(exc_info.value)
|
||||
|
||||
@patch("crewai.files.processing.validators._get_video_duration")
|
||||
def test_validate_video_duration_no_raise(self, mock_get_duration):
|
||||
"""Test video duration validation with raise_on_error=False."""
|
||||
mock_get_duration.return_value = 180.0
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
errors = validate_video(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert "duration" in errors[0].lower()
|
||||
|
||||
@patch("crewai.files.processing.validators._get_video_duration")
|
||||
def test_validate_video_duration_none_skips(self, mock_get_duration):
|
||||
"""Test that duration validation is skipped when max_duration_seconds is None."""
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=None,
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
errors = validate_video(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
mock_get_duration.assert_not_called()
|
||||
|
||||
@patch("crewai.files.processing.validators._get_video_duration")
|
||||
def test_validate_video_duration_detection_returns_none(self, mock_get_duration):
|
||||
"""Test that validation passes when duration detection returns None."""
|
||||
mock_get_duration.return_value = None
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
errors = validate_video(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
|
||||
class TestGetAudioDuration:
|
||||
"""Tests for _get_audio_duration helper function."""
|
||||
|
||||
def test_get_audio_duration_corrupt_file(self):
|
||||
"""Test handling of corrupt audio data."""
|
||||
corrupt_data = b"not valid audio data at all"
|
||||
result = _get_audio_duration(corrupt_data)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetVideoDuration:
|
||||
"""Tests for _get_video_duration helper function."""
|
||||
|
||||
def test_get_video_duration_corrupt_file(self):
|
||||
"""Test handling of corrupt video data."""
|
||||
corrupt_data = b"not valid video data at all"
|
||||
result = _get_video_duration(corrupt_data)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestRealVideoFile:
|
||||
"""Tests using real video fixture file."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_video_path(self):
|
||||
"""Path to sample video fixture."""
|
||||
from pathlib import Path
|
||||
|
||||
path = Path(__file__).parent.parent.parent / "fixtures" / "sample_video.mp4"
|
||||
if not path.exists():
|
||||
pytest.skip("sample_video.mp4 fixture not found")
|
||||
return path
|
||||
|
||||
@pytest.fixture
|
||||
def sample_video_content(self, sample_video_path):
|
||||
"""Read sample video content."""
|
||||
return sample_video_path.read_bytes()
|
||||
|
||||
def test_get_video_duration_real_file(self, sample_video_content):
|
||||
"""Test duration detection with real video file."""
|
||||
try:
|
||||
import av # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("PyAV not installed")
|
||||
|
||||
duration = _get_video_duration(sample_video_content, "video/mp4")
|
||||
|
||||
assert duration is not None
|
||||
assert 4.5 <= duration <= 5.5 # ~5 seconds with tolerance
|
||||
|
||||
def test_get_video_duration_real_file_no_format_hint(self, sample_video_content):
|
||||
"""Test duration detection without format hint."""
|
||||
try:
|
||||
import av # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("PyAV not installed")
|
||||
|
||||
duration = _get_video_duration(sample_video_content)
|
||||
|
||||
assert duration is not None
|
||||
assert 4.5 <= duration <= 5.5
|
||||
|
||||
def test_validate_video_real_file_passes(self, sample_video_path):
|
||||
"""Test validating real video file within constraints."""
|
||||
try:
|
||||
import av # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("PyAV not installed")
|
||||
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("video/mp4",),
|
||||
)
|
||||
file = VideoFile(source=str(sample_video_path))
|
||||
|
||||
errors = validate_video(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_video_real_file_duration_exceeded(self, sample_video_path):
|
||||
"""Test validating real video file that exceeds duration limit."""
|
||||
try:
|
||||
import av # noqa: F401
|
||||
except ImportError:
|
||||
pytest.skip("PyAV not installed")
|
||||
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=2, # Video is ~5 seconds
|
||||
supported_formats=("video/mp4",),
|
||||
)
|
||||
file = VideoFile(source=str(sample_video_path))
|
||||
|
||||
with pytest.raises(FileValidationError) as exc_info:
|
||||
validate_video(file, constraints)
|
||||
|
||||
assert "duration" in str(exc_info.value).lower()
|
||||
assert "2s" in str(exc_info.value)
|
||||
@@ -1,312 +0,0 @@
|
||||
"""Tests for FileUrl source type and URL resolution."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.files import FileBytes, FileUrl, ImageFile
|
||||
from crewai.files.file import _normalize_source, FilePath
|
||||
from crewai.files.resolved import InlineBase64, UrlReference
|
||||
from crewai.files.resolver import FileResolver
|
||||
|
||||
|
||||
class TestFileUrl:
|
||||
"""Tests for FileUrl source type."""
|
||||
|
||||
def test_create_file_url(self):
|
||||
"""Test creating FileUrl with valid URL."""
|
||||
url = FileUrl(url="https://example.com/image.png")
|
||||
|
||||
assert url.url == "https://example.com/image.png"
|
||||
assert url.filename is None
|
||||
|
||||
def test_create_file_url_with_filename(self):
|
||||
"""Test creating FileUrl with custom filename."""
|
||||
url = FileUrl(url="https://example.com/image.png", filename="custom.png")
|
||||
|
||||
assert url.url == "https://example.com/image.png"
|
||||
assert url.filename == "custom.png"
|
||||
|
||||
def test_invalid_url_scheme_raises(self):
|
||||
"""Test that non-http(s) URLs raise ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid URL scheme"):
|
||||
FileUrl(url="ftp://example.com/file.txt")
|
||||
|
||||
def test_invalid_url_scheme_file_raises(self):
|
||||
"""Test that file:// URLs raise ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid URL scheme"):
|
||||
FileUrl(url="file:///path/to/file.txt")
|
||||
|
||||
def test_http_url_valid(self):
|
||||
"""Test that HTTP URLs are valid."""
|
||||
url = FileUrl(url="http://example.com/image.jpg")
|
||||
|
||||
assert url.url == "http://example.com/image.jpg"
|
||||
|
||||
def test_https_url_valid(self):
|
||||
"""Test that HTTPS URLs are valid."""
|
||||
url = FileUrl(url="https://example.com/image.jpg")
|
||||
|
||||
assert url.url == "https://example.com/image.jpg"
|
||||
|
||||
def test_content_type_guessing_png(self):
|
||||
"""Test content type guessing for PNG files."""
|
||||
url = FileUrl(url="https://example.com/image.png")
|
||||
|
||||
assert url.content_type == "image/png"
|
||||
|
||||
def test_content_type_guessing_jpeg(self):
|
||||
"""Test content type guessing for JPEG files."""
|
||||
url = FileUrl(url="https://example.com/photo.jpg")
|
||||
|
||||
assert url.content_type == "image/jpeg"
|
||||
|
||||
def test_content_type_guessing_pdf(self):
|
||||
"""Test content type guessing for PDF files."""
|
||||
url = FileUrl(url="https://example.com/document.pdf")
|
||||
|
||||
assert url.content_type == "application/pdf"
|
||||
|
||||
def test_content_type_guessing_with_query_params(self):
|
||||
"""Test content type guessing with URL query parameters."""
|
||||
url = FileUrl(url="https://example.com/image.png?v=123&token=abc")
|
||||
|
||||
assert url.content_type == "image/png"
|
||||
|
||||
def test_content_type_fallback_unknown(self):
|
||||
"""Test content type falls back to octet-stream for unknown extensions."""
|
||||
url = FileUrl(url="https://example.com/file.unknownext123")
|
||||
|
||||
assert url.content_type == "application/octet-stream"
|
||||
|
||||
def test_content_type_no_extension(self):
|
||||
"""Test content type for URL without extension."""
|
||||
url = FileUrl(url="https://example.com/file")
|
||||
|
||||
assert url.content_type == "application/octet-stream"
|
||||
|
||||
def test_read_fetches_content(self):
|
||||
"""Test that read() fetches content from URL."""
|
||||
url = FileUrl(url="https://example.com/image.png")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"fake image content"
|
||||
mock_response.headers = {"content-type": "image/png"}
|
||||
|
||||
with patch("httpx.get", return_value=mock_response) as mock_get:
|
||||
content = url.read()
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
"https://example.com/image.png", follow_redirects=True
|
||||
)
|
||||
assert content == b"fake image content"
|
||||
|
||||
def test_read_caches_content(self):
|
||||
"""Test that read() caches content."""
|
||||
url = FileUrl(url="https://example.com/image.png")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"fake content"
|
||||
mock_response.headers = {}
|
||||
|
||||
with patch("httpx.get", return_value=mock_response) as mock_get:
|
||||
content1 = url.read()
|
||||
content2 = url.read()
|
||||
|
||||
mock_get.assert_called_once()
|
||||
assert content1 == content2
|
||||
|
||||
def test_read_updates_content_type_from_response(self):
|
||||
"""Test that read() updates content type from response headers."""
|
||||
url = FileUrl(url="https://example.com/file")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"fake content"
|
||||
mock_response.headers = {"content-type": "image/webp; charset=utf-8"}
|
||||
|
||||
with patch("httpx.get", return_value=mock_response):
|
||||
url.read()
|
||||
|
||||
assert url.content_type == "image/webp"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aread_fetches_content(self):
|
||||
"""Test that aread() fetches content from URL asynchronously."""
|
||||
url = FileUrl(url="https://example.com/image.png")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"async fake content"
|
||||
mock_response.headers = {"content-type": "image/png"}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
content = await url.aread()
|
||||
|
||||
assert content == b"async fake content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aread_caches_content(self):
|
||||
"""Test that aread() caches content."""
|
||||
url = FileUrl(url="https://example.com/image.png")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"cached content"
|
||||
mock_response.headers = {}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
content1 = await url.aread()
|
||||
content2 = await url.aread()
|
||||
|
||||
mock_client.get.assert_called_once()
|
||||
assert content1 == content2
|
||||
|
||||
|
||||
class TestNormalizeSource:
|
||||
"""Tests for _normalize_source with URL detection."""
|
||||
|
||||
def test_normalize_url_string(self):
|
||||
"""Test that URL strings are converted to FileUrl."""
|
||||
result = _normalize_source("https://example.com/image.png")
|
||||
|
||||
assert isinstance(result, FileUrl)
|
||||
assert result.url == "https://example.com/image.png"
|
||||
|
||||
def test_normalize_http_url_string(self):
|
||||
"""Test that HTTP URL strings are converted to FileUrl."""
|
||||
result = _normalize_source("http://example.com/file.pdf")
|
||||
|
||||
assert isinstance(result, FileUrl)
|
||||
assert result.url == "http://example.com/file.pdf"
|
||||
|
||||
def test_normalize_file_path_string(self, tmp_path):
|
||||
"""Test that file path strings are converted to FilePath."""
|
||||
test_file = tmp_path / "test.png"
|
||||
test_file.write_bytes(b"test content")
|
||||
|
||||
result = _normalize_source(str(test_file))
|
||||
|
||||
assert isinstance(result, FilePath)
|
||||
|
||||
def test_normalize_relative_path_is_not_url(self):
|
||||
"""Test that relative path strings are not treated as URLs."""
|
||||
result = _normalize_source("https://example.com/file.png")
|
||||
|
||||
assert isinstance(result, FileUrl)
|
||||
assert not isinstance(result, FilePath)
|
||||
|
||||
def test_normalize_file_url_passthrough(self):
|
||||
"""Test that FileUrl instances pass through unchanged."""
|
||||
original = FileUrl(url="https://example.com/image.png")
|
||||
result = _normalize_source(original)
|
||||
|
||||
assert result is original
|
||||
|
||||
|
||||
class TestResolverUrlHandling:
|
||||
"""Tests for FileResolver URL handling."""
|
||||
|
||||
def test_resolve_url_source_for_supported_provider(self):
|
||||
"""Test URL source resolves to UrlReference for supported providers."""
|
||||
resolver = FileResolver()
|
||||
file = ImageFile(source=FileUrl(url="https://example.com/image.png"))
|
||||
|
||||
resolved = resolver.resolve(file, "anthropic")
|
||||
|
||||
assert isinstance(resolved, UrlReference)
|
||||
assert resolved.url == "https://example.com/image.png"
|
||||
assert resolved.content_type == "image/png"
|
||||
|
||||
def test_resolve_url_source_openai(self):
|
||||
"""Test URL source resolves to UrlReference for OpenAI."""
|
||||
resolver = FileResolver()
|
||||
file = ImageFile(source=FileUrl(url="https://example.com/photo.jpg"))
|
||||
|
||||
resolved = resolver.resolve(file, "openai")
|
||||
|
||||
assert isinstance(resolved, UrlReference)
|
||||
assert resolved.url == "https://example.com/photo.jpg"
|
||||
|
||||
def test_resolve_url_source_gemini(self):
|
||||
"""Test URL source resolves to UrlReference for Gemini."""
|
||||
resolver = FileResolver()
|
||||
file = ImageFile(source=FileUrl(url="https://example.com/image.webp"))
|
||||
|
||||
resolved = resolver.resolve(file, "gemini")
|
||||
|
||||
assert isinstance(resolved, UrlReference)
|
||||
assert resolved.url == "https://example.com/image.webp"
|
||||
|
||||
def test_resolve_url_source_azure(self):
|
||||
"""Test URL source resolves to UrlReference for Azure."""
|
||||
resolver = FileResolver()
|
||||
file = ImageFile(source=FileUrl(url="https://example.com/image.gif"))
|
||||
|
||||
resolved = resolver.resolve(file, "azure")
|
||||
|
||||
assert isinstance(resolved, UrlReference)
|
||||
assert resolved.url == "https://example.com/image.gif"
|
||||
|
||||
def test_resolve_url_source_bedrock_fetches_content(self):
|
||||
"""Test URL source fetches content for Bedrock (unsupported URLs)."""
|
||||
resolver = FileResolver()
|
||||
file_url = FileUrl(url="https://example.com/image.png")
|
||||
file = ImageFile(source=file_url)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"\x89PNG\r\n\x1a\n" + b"\x00" * 50
|
||||
mock_response.headers = {"content-type": "image/png"}
|
||||
|
||||
with patch("httpx.get", return_value=mock_response):
|
||||
resolved = resolver.resolve(file, "bedrock")
|
||||
|
||||
assert not isinstance(resolved, UrlReference)
|
||||
|
||||
def test_resolve_bytes_source_still_works(self):
|
||||
"""Test that bytes source still resolves normally."""
|
||||
resolver = FileResolver()
|
||||
minimal_png = (
|
||||
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
|
||||
b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
|
||||
b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
|
||||
)
|
||||
file = ImageFile(source=FileBytes(data=minimal_png, filename="test.png"))
|
||||
|
||||
resolved = resolver.resolve(file, "anthropic")
|
||||
|
||||
assert isinstance(resolved, InlineBase64)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aresolve_url_source(self):
|
||||
"""Test async URL resolution for supported provider."""
|
||||
resolver = FileResolver()
|
||||
file = ImageFile(source=FileUrl(url="https://example.com/image.png"))
|
||||
|
||||
resolved = await resolver.aresolve(file, "anthropic")
|
||||
|
||||
assert isinstance(resolved, UrlReference)
|
||||
assert resolved.url == "https://example.com/image.png"
|
||||
|
||||
|
||||
class TestImageFileWithUrl:
|
||||
"""Tests for creating ImageFile with URL source."""
|
||||
|
||||
def test_image_file_from_url_string(self):
|
||||
"""Test creating ImageFile from URL string."""
|
||||
file = ImageFile(source="https://example.com/image.png")
|
||||
|
||||
assert isinstance(file.source, FileUrl)
|
||||
assert file.source.url == "https://example.com/image.png"
|
||||
|
||||
def test_image_file_from_file_url(self):
|
||||
"""Test creating ImageFile from FileUrl instance."""
|
||||
url = FileUrl(url="https://example.com/photo.jpg")
|
||||
file = ImageFile(source=url)
|
||||
|
||||
assert file.source is url
|
||||
assert file.content_type == "image/jpeg"
|
||||
@@ -1,135 +0,0 @@
|
||||
"""Tests for resolved file types."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.files.resolved import (
|
||||
FileReference,
|
||||
InlineBase64,
|
||||
InlineBytes,
|
||||
ResolvedFile,
|
||||
UrlReference,
|
||||
)
|
||||
|
||||
|
||||
class TestInlineBase64:
|
||||
"""Tests for InlineBase64 resolved type."""
|
||||
|
||||
def test_create_inline_base64(self):
|
||||
"""Test creating InlineBase64 instance."""
|
||||
resolved = InlineBase64(
|
||||
content_type="image/png",
|
||||
data="iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
|
||||
)
|
||||
|
||||
assert resolved.content_type == "image/png"
|
||||
assert len(resolved.data) > 0
|
||||
|
||||
def test_inline_base64_is_resolved_file(self):
|
||||
"""Test InlineBase64 is a ResolvedFile."""
|
||||
resolved = InlineBase64(content_type="image/png", data="abc123")
|
||||
|
||||
assert isinstance(resolved, ResolvedFile)
|
||||
|
||||
def test_inline_base64_frozen(self):
|
||||
"""Test InlineBase64 is immutable."""
|
||||
resolved = InlineBase64(content_type="image/png", data="abc123")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
resolved.data = "xyz789"
|
||||
|
||||
|
||||
class TestInlineBytes:
|
||||
"""Tests for InlineBytes resolved type."""
|
||||
|
||||
def test_create_inline_bytes(self):
|
||||
"""Test creating InlineBytes instance."""
|
||||
data = b"\x89PNG\r\n\x1a\n"
|
||||
resolved = InlineBytes(
|
||||
content_type="image/png",
|
||||
data=data,
|
||||
)
|
||||
|
||||
assert resolved.content_type == "image/png"
|
||||
assert resolved.data == data
|
||||
|
||||
def test_inline_bytes_is_resolved_file(self):
|
||||
"""Test InlineBytes is a ResolvedFile."""
|
||||
resolved = InlineBytes(content_type="image/png", data=b"test")
|
||||
|
||||
assert isinstance(resolved, ResolvedFile)
|
||||
|
||||
|
||||
class TestFileReference:
|
||||
"""Tests for FileReference resolved type."""
|
||||
|
||||
def test_create_file_reference(self):
|
||||
"""Test creating FileReference instance."""
|
||||
resolved = FileReference(
|
||||
content_type="image/png",
|
||||
file_id="file-abc123",
|
||||
provider="gemini",
|
||||
)
|
||||
|
||||
assert resolved.content_type == "image/png"
|
||||
assert resolved.file_id == "file-abc123"
|
||||
assert resolved.provider == "gemini"
|
||||
assert resolved.expires_at is None
|
||||
assert resolved.file_uri is None
|
||||
|
||||
def test_file_reference_with_expiry(self):
|
||||
"""Test FileReference with expiry time."""
|
||||
expiry = datetime.now(timezone.utc)
|
||||
resolved = FileReference(
|
||||
content_type="application/pdf",
|
||||
file_id="file-xyz789",
|
||||
provider="gemini",
|
||||
expires_at=expiry,
|
||||
)
|
||||
|
||||
assert resolved.expires_at == expiry
|
||||
|
||||
def test_file_reference_with_uri(self):
|
||||
"""Test FileReference with URI."""
|
||||
resolved = FileReference(
|
||||
content_type="video/mp4",
|
||||
file_id="file-video123",
|
||||
provider="gemini",
|
||||
file_uri="https://generativelanguage.googleapis.com/v1/files/file-video123",
|
||||
)
|
||||
|
||||
assert resolved.file_uri is not None
|
||||
|
||||
def test_file_reference_is_resolved_file(self):
|
||||
"""Test FileReference is a ResolvedFile."""
|
||||
resolved = FileReference(
|
||||
content_type="image/png",
|
||||
file_id="file-123",
|
||||
provider="anthropic",
|
||||
)
|
||||
|
||||
assert isinstance(resolved, ResolvedFile)
|
||||
|
||||
|
||||
class TestUrlReference:
|
||||
"""Tests for UrlReference resolved type."""
|
||||
|
||||
def test_create_url_reference(self):
|
||||
"""Test creating UrlReference instance."""
|
||||
resolved = UrlReference(
|
||||
content_type="image/png",
|
||||
url="https://storage.googleapis.com/bucket/image.png",
|
||||
)
|
||||
|
||||
assert resolved.content_type == "image/png"
|
||||
assert resolved.url == "https://storage.googleapis.com/bucket/image.png"
|
||||
|
||||
def test_url_reference_is_resolved_file(self):
|
||||
"""Test UrlReference is a ResolvedFile."""
|
||||
resolved = UrlReference(
|
||||
content_type="image/jpeg",
|
||||
url="https://example.com/photo.jpg",
|
||||
)
|
||||
|
||||
assert isinstance(resolved, ResolvedFile)
|
||||
@@ -1,174 +0,0 @@
|
||||
"""Tests for FileResolver."""
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.files import FileBytes, ImageFile
|
||||
from crewai.files.resolved import InlineBase64, InlineBytes
|
||||
from crewai.files.resolver import (
|
||||
FileResolver,
|
||||
FileResolverConfig,
|
||||
create_resolver,
|
||||
)
|
||||
from crewai.files.upload_cache import UploadCache
|
||||
|
||||
|
||||
# Minimal valid PNG
|
||||
MINIMAL_PNG = (
|
||||
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
|
||||
b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
|
||||
b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
|
||||
)
|
||||
|
||||
|
||||
class TestFileResolverConfig:
|
||||
"""Tests for FileResolverConfig."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test default configuration values."""
|
||||
config = FileResolverConfig()
|
||||
|
||||
assert config.prefer_upload is False
|
||||
assert config.upload_threshold_bytes is None
|
||||
assert config.use_bytes_for_bedrock is True
|
||||
|
||||
def test_custom_config(self):
|
||||
"""Test custom configuration values."""
|
||||
config = FileResolverConfig(
|
||||
prefer_upload=True,
|
||||
upload_threshold_bytes=1024 * 1024,
|
||||
use_bytes_for_bedrock=False,
|
||||
)
|
||||
|
||||
assert config.prefer_upload is True
|
||||
assert config.upload_threshold_bytes == 1024 * 1024
|
||||
assert config.use_bytes_for_bedrock is False
|
||||
|
||||
|
||||
class TestFileResolver:
|
||||
"""Tests for FileResolver class."""
|
||||
|
||||
def test_resolve_inline_base64(self):
|
||||
"""Test resolving file as inline base64."""
|
||||
resolver = FileResolver()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
resolved = resolver.resolve(file, "openai")
|
||||
|
||||
assert isinstance(resolved, InlineBase64)
|
||||
assert resolved.content_type == "image/png"
|
||||
assert len(resolved.data) > 0
|
||||
|
||||
def test_resolve_inline_bytes_for_bedrock(self):
|
||||
"""Test resolving file as inline bytes for Bedrock."""
|
||||
config = FileResolverConfig(use_bytes_for_bedrock=True)
|
||||
resolver = FileResolver(config=config)
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
resolved = resolver.resolve(file, "bedrock")
|
||||
|
||||
assert isinstance(resolved, InlineBytes)
|
||||
assert resolved.content_type == "image/png"
|
||||
assert resolved.data == MINIMAL_PNG
|
||||
|
||||
def test_resolve_files_multiple(self):
|
||||
"""Test resolving multiple files."""
|
||||
resolver = FileResolver()
|
||||
files = {
|
||||
"image1": ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png")),
|
||||
"image2": ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test2.png")),
|
||||
}
|
||||
|
||||
resolved = resolver.resolve_files(files, "openai")
|
||||
|
||||
assert len(resolved) == 2
|
||||
assert "image1" in resolved
|
||||
assert "image2" in resolved
|
||||
assert all(isinstance(r, InlineBase64) for r in resolved.values())
|
||||
|
||||
def test_resolve_with_cache(self):
|
||||
"""Test resolver uses cache."""
|
||||
cache = UploadCache()
|
||||
resolver = FileResolver(upload_cache=cache)
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
# First resolution
|
||||
resolved1 = resolver.resolve(file, "openai")
|
||||
# Second resolution (should use same base64 encoding)
|
||||
resolved2 = resolver.resolve(file, "openai")
|
||||
|
||||
assert isinstance(resolved1, InlineBase64)
|
||||
assert isinstance(resolved2, InlineBase64)
|
||||
# Data should be identical
|
||||
assert resolved1.data == resolved2.data
|
||||
|
||||
def test_clear_cache(self):
|
||||
"""Test clearing resolver cache."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
# Add something to cache manually
|
||||
cache.set(file=file, provider="gemini", file_id="test")
|
||||
|
||||
resolver = FileResolver(upload_cache=cache)
|
||||
resolver.clear_cache()
|
||||
|
||||
assert len(cache) == 0
|
||||
|
||||
def test_get_cached_uploads(self):
|
||||
"""Test getting cached uploads from resolver."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
cache.set(file=file, provider="gemini", file_id="test-1")
|
||||
cache.set(file=file, provider="anthropic", file_id="test-2")
|
||||
|
||||
resolver = FileResolver(upload_cache=cache)
|
||||
|
||||
gemini_uploads = resolver.get_cached_uploads("gemini")
|
||||
anthropic_uploads = resolver.get_cached_uploads("anthropic")
|
||||
|
||||
assert len(gemini_uploads) == 1
|
||||
assert len(anthropic_uploads) == 1
|
||||
|
||||
def test_get_cached_uploads_empty(self):
|
||||
"""Test getting cached uploads when no cache."""
|
||||
resolver = FileResolver() # No cache
|
||||
|
||||
uploads = resolver.get_cached_uploads("gemini")
|
||||
|
||||
assert uploads == []
|
||||
|
||||
|
||||
class TestCreateResolver:
|
||||
"""Tests for create_resolver factory function."""
|
||||
|
||||
def test_create_default_resolver(self):
|
||||
"""Test creating resolver with default settings."""
|
||||
resolver = create_resolver()
|
||||
|
||||
assert resolver.config.prefer_upload is False
|
||||
assert resolver.upload_cache is not None
|
||||
|
||||
def test_create_resolver_with_options(self):
|
||||
"""Test creating resolver with custom options."""
|
||||
resolver = create_resolver(
|
||||
prefer_upload=True,
|
||||
upload_threshold_bytes=5 * 1024 * 1024,
|
||||
enable_cache=False,
|
||||
)
|
||||
|
||||
assert resolver.config.prefer_upload is True
|
||||
assert resolver.config.upload_threshold_bytes == 5 * 1024 * 1024
|
||||
assert resolver.upload_cache is None
|
||||
|
||||
def test_create_resolver_cache_enabled(self):
|
||||
"""Test resolver has cache when enabled."""
|
||||
resolver = create_resolver(enable_cache=True)
|
||||
|
||||
assert resolver.upload_cache is not None
|
||||
|
||||
def test_create_resolver_cache_disabled(self):
|
||||
"""Test resolver has no cache when disabled."""
|
||||
resolver = create_resolver(enable_cache=False)
|
||||
|
||||
assert resolver.upload_cache is None
|
||||
@@ -1,206 +0,0 @@
|
||||
"""Tests for upload cache."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.files import FileBytes, ImageFile
|
||||
from crewai.files.upload_cache import CachedUpload, UploadCache
|
||||
|
||||
|
||||
# Minimal valid PNG
|
||||
MINIMAL_PNG = (
|
||||
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
|
||||
b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
|
||||
b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
|
||||
)
|
||||
|
||||
|
||||
class TestCachedUpload:
|
||||
"""Tests for CachedUpload dataclass."""
|
||||
|
||||
def test_cached_upload_creation(self):
|
||||
"""Test creating a cached upload."""
|
||||
now = datetime.now(timezone.utc)
|
||||
cached = CachedUpload(
|
||||
file_id="file-123",
|
||||
provider="gemini",
|
||||
file_uri="files/file-123",
|
||||
content_type="image/png",
|
||||
uploaded_at=now,
|
||||
expires_at=now + timedelta(hours=48),
|
||||
)
|
||||
|
||||
assert cached.file_id == "file-123"
|
||||
assert cached.provider == "gemini"
|
||||
assert cached.file_uri == "files/file-123"
|
||||
assert cached.content_type == "image/png"
|
||||
|
||||
def test_is_expired_false(self):
|
||||
"""Test is_expired returns False for non-expired upload."""
|
||||
future = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
cached = CachedUpload(
|
||||
file_id="file-123",
|
||||
provider="gemini",
|
||||
file_uri=None,
|
||||
content_type="image/png",
|
||||
uploaded_at=datetime.now(timezone.utc),
|
||||
expires_at=future,
|
||||
)
|
||||
|
||||
assert cached.is_expired() is False
|
||||
|
||||
def test_is_expired_true(self):
|
||||
"""Test is_expired returns True for expired upload."""
|
||||
past = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
cached = CachedUpload(
|
||||
file_id="file-123",
|
||||
provider="gemini",
|
||||
file_uri=None,
|
||||
content_type="image/png",
|
||||
uploaded_at=datetime.now(timezone.utc) - timedelta(hours=2),
|
||||
expires_at=past,
|
||||
)
|
||||
|
||||
assert cached.is_expired() is True
|
||||
|
||||
def test_is_expired_no_expiry(self):
|
||||
"""Test is_expired returns False when no expiry set."""
|
||||
cached = CachedUpload(
|
||||
file_id="file-123",
|
||||
provider="anthropic",
|
||||
file_uri=None,
|
||||
content_type="image/png",
|
||||
uploaded_at=datetime.now(timezone.utc),
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
assert cached.is_expired() is False
|
||||
|
||||
|
||||
class TestUploadCache:
|
||||
"""Tests for UploadCache class."""
|
||||
|
||||
def test_cache_creation(self):
|
||||
"""Test creating an empty cache."""
|
||||
cache = UploadCache()
|
||||
|
||||
assert len(cache) == 0
|
||||
|
||||
def test_set_and_get(self):
|
||||
"""Test setting and getting cached uploads."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
cached = cache.set(
|
||||
file=file,
|
||||
provider="gemini",
|
||||
file_id="file-123",
|
||||
file_uri="files/file-123",
|
||||
)
|
||||
|
||||
result = cache.get(file, "gemini")
|
||||
|
||||
assert result is not None
|
||||
assert result.file_id == "file-123"
|
||||
assert result.provider == "gemini"
|
||||
|
||||
def test_get_missing(self):
|
||||
"""Test getting non-existent entry returns None."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
result = cache.get(file, "gemini")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_different_provider(self):
|
||||
"""Test getting with different provider returns None."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
cache.set(file=file, provider="gemini", file_id="file-123")
|
||||
|
||||
result = cache.get(file, "anthropic") # Different provider
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_remove(self):
|
||||
"""Test removing cached entry."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
cache.set(file=file, provider="gemini", file_id="file-123")
|
||||
removed = cache.remove(file, "gemini")
|
||||
|
||||
assert removed is True
|
||||
assert cache.get(file, "gemini") is None
|
||||
|
||||
def test_remove_missing(self):
|
||||
"""Test removing non-existent entry returns False."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
removed = cache.remove(file, "gemini")
|
||||
|
||||
assert removed is False
|
||||
|
||||
def test_remove_by_file_id(self):
|
||||
"""Test removing by file ID."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
cache.set(file=file, provider="gemini", file_id="file-123")
|
||||
removed = cache.remove_by_file_id("file-123", "gemini")
|
||||
|
||||
assert removed is True
|
||||
assert len(cache) == 0
|
||||
|
||||
def test_clear_expired(self):
|
||||
"""Test clearing expired entries."""
|
||||
cache = UploadCache()
|
||||
file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
|
||||
file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png"))
|
||||
|
||||
# Add one expired and one valid entry
|
||||
past = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
future = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
|
||||
cache.set(file=file1, provider="gemini", file_id="expired", expires_at=past)
|
||||
cache.set(file=file2, provider="gemini", file_id="valid", expires_at=future)
|
||||
|
||||
removed = cache.clear_expired()
|
||||
|
||||
assert removed == 1
|
||||
assert len(cache) == 1
|
||||
assert cache.get(file2, "gemini") is not None
|
||||
|
||||
def test_clear(self):
|
||||
"""Test clearing all entries."""
|
||||
cache = UploadCache()
|
||||
file = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test.png"))
|
||||
|
||||
cache.set(file=file, provider="gemini", file_id="file-123")
|
||||
cache.set(file=file, provider="anthropic", file_id="file-456")
|
||||
|
||||
cleared = cache.clear()
|
||||
|
||||
assert cleared == 2
|
||||
assert len(cache) == 0
|
||||
|
||||
def test_get_all_for_provider(self):
|
||||
"""Test getting all cached uploads for a provider."""
|
||||
cache = UploadCache()
|
||||
file1 = ImageFile(source=FileBytes(data=MINIMAL_PNG, filename="test1.png"))
|
||||
file2 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"x", filename="test2.png"))
|
||||
file3 = ImageFile(source=FileBytes(data=MINIMAL_PNG + b"xx", filename="test3.png"))
|
||||
|
||||
cache.set(file=file1, provider="gemini", file_id="file-1")
|
||||
cache.set(file=file2, provider="gemini", file_id="file-2")
|
||||
cache.set(file=file3, provider="anthropic", file_id="file-3")
|
||||
|
||||
gemini_uploads = cache.get_all_for_provider("gemini")
|
||||
anthropic_uploads = cache.get_all_for_provider("anthropic")
|
||||
|
||||
assert len(gemini_uploads) == 2
|
||||
assert len(anthropic_uploads) == 1
|
||||
Reference in New Issue
Block a user