mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-23 15:18:14 +00:00
feat: add format hints to audio/video duration detection
This commit is contained in:
@@ -104,6 +104,8 @@ file-processing = [
|
||||
"python-magic>=0.4.27",
|
||||
"aiocache~=0.12.3",
|
||||
"aiofiles~=24.1.0",
|
||||
"tinytag~=1.10.0",
|
||||
"av~=13.0.0",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from crewai.files.file import (
|
||||
FileSource,
|
||||
FileSourceInput,
|
||||
FileStream,
|
||||
FileUrl,
|
||||
RawFileInput,
|
||||
)
|
||||
from crewai.files.processing import (
|
||||
@@ -103,6 +104,7 @@ __all__ = [
|
||||
"FileStream",
|
||||
"FileTooLargeError",
|
||||
"FileUploader",
|
||||
"FileUrl",
|
||||
"FileValidationError",
|
||||
"ImageConstraints",
|
||||
"ImageExtension",
|
||||
|
||||
@@ -16,6 +16,7 @@ from crewai.files.file import (
|
||||
FilePath,
|
||||
FileSource,
|
||||
FileStream,
|
||||
FileUrl,
|
||||
)
|
||||
from crewai.files.utils import is_file_source
|
||||
|
||||
@@ -29,12 +30,14 @@ class _FileSourceCoercer:
|
||||
@classmethod
|
||||
def _coerce(cls, v: Any) -> FileSource:
|
||||
"""Convert raw input to appropriate FileSource type."""
|
||||
if isinstance(v, (FilePath, FileBytes, FileStream)):
|
||||
if isinstance(v, (FilePath, FileBytes, FileStream, FileUrl)):
|
||||
return v
|
||||
if isinstance(v, str):
|
||||
if v.startswith(("http://", "https://")):
|
||||
return FileUrl(url=v)
|
||||
return FilePath(path=Path(v))
|
||||
if isinstance(v, Path):
|
||||
return FilePath(path=v)
|
||||
if isinstance(v, str):
|
||||
return FilePath(path=Path(v))
|
||||
if isinstance(v, bytes):
|
||||
return FileBytes(data=v)
|
||||
if isinstance(v, (IOBase, BinaryIO)):
|
||||
@@ -203,7 +206,7 @@ class BaseFile(ABC, BaseModel):
|
||||
TypeError: If the underlying source doesn't support async read.
|
||||
"""
|
||||
source = self._file_source
|
||||
if isinstance(source, (FilePath, FileBytes, AsyncFileStream)):
|
||||
if isinstance(source, (FilePath, FileBytes, AsyncFileStream, FileUrl)):
|
||||
return await source.aread()
|
||||
raise TypeError(f"{type(source).__name__} does not support async read")
|
||||
|
||||
|
||||
@@ -414,17 +414,84 @@ class AsyncFileStream(BaseModel):
|
||||
yield chunk
|
||||
|
||||
|
||||
FileSource = FilePath | FileBytes | FileStream | AsyncFileStream
|
||||
class FileUrl(BaseModel):
|
||||
"""File referenced by URL.
|
||||
|
||||
For providers that support URL references, the URL is passed directly.
|
||||
For providers that don't, content is fetched on demand.
|
||||
|
||||
Attributes:
|
||||
url: URL where the file can be accessed.
|
||||
filename: Optional filename (extracted from URL if not provided).
|
||||
"""
|
||||
|
||||
url: str = Field(description="URL where the file can be accessed.")
|
||||
filename: str | None = Field(default=None, description="Optional filename.")
|
||||
_content_type: str | None = PrivateAttr(default=None)
|
||||
_content: bytes | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_url(self) -> FileUrl:
|
||||
"""Validate URL format."""
|
||||
if not self.url.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Invalid URL scheme: {self.url}")
|
||||
return self
|
||||
|
||||
@property
|
||||
def content_type(self) -> str:
|
||||
"""Get the content type, guessing from URL extension if not set."""
|
||||
if self._content_type is None:
|
||||
self._content_type = self._guess_content_type()
|
||||
return self._content_type
|
||||
|
||||
def _guess_content_type(self) -> str:
|
||||
"""Guess content type from URL extension."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(self.url)
|
||||
path = parsed.path
|
||||
guessed, _ = mimetypes.guess_type(path)
|
||||
return guessed or "application/octet-stream"
|
||||
|
||||
def read(self) -> bytes:
|
||||
"""Fetch content from URL (for providers that don't support URL references)."""
|
||||
if self._content is None:
|
||||
import httpx
|
||||
|
||||
response = httpx.get(self.url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
self._content = response.content
|
||||
if "content-type" in response.headers:
|
||||
self._content_type = response.headers["content-type"].split(";")[0]
|
||||
return self._content
|
||||
|
||||
async def aread(self) -> bytes:
|
||||
"""Async fetch content from URL."""
|
||||
if self._content is None:
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(self.url, follow_redirects=True)
|
||||
response.raise_for_status()
|
||||
self._content = response.content
|
||||
if "content-type" in response.headers:
|
||||
self._content_type = response.headers["content-type"].split(";")[0]
|
||||
return self._content
|
||||
|
||||
|
||||
FileSource = FilePath | FileBytes | FileStream | AsyncFileStream | FileUrl
|
||||
|
||||
|
||||
def _normalize_source(value: Any) -> FileSource:
|
||||
"""Convert raw input to appropriate source type."""
|
||||
if isinstance(value, (FilePath, FileBytes, FileStream, AsyncFileStream)):
|
||||
if isinstance(value, (FilePath, FileBytes, FileStream, AsyncFileStream, FileUrl)):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
if value.startswith(("http://", "https://")):
|
||||
return FileUrl(url=value)
|
||||
return FilePath(path=Path(value))
|
||||
if isinstance(value, Path):
|
||||
return FilePath(path=value)
|
||||
if isinstance(value, str):
|
||||
return FilePath(path=Path(value))
|
||||
if isinstance(value, bytes):
|
||||
return FileBytes(data=value)
|
||||
if isinstance(value, AsyncReadable):
|
||||
|
||||
@@ -148,6 +148,7 @@ class ProviderConstraints:
|
||||
general_max_size_bytes: Maximum size for any file type.
|
||||
supports_file_upload: Whether the provider supports file upload APIs.
|
||||
file_upload_threshold_bytes: Size threshold above which to use file upload.
|
||||
supports_url_references: Whether the provider supports URL-based file references.
|
||||
"""
|
||||
|
||||
name: ProviderName
|
||||
@@ -158,21 +159,24 @@ class ProviderConstraints:
|
||||
general_max_size_bytes: int | None = None
|
||||
supports_file_upload: bool = False
|
||||
file_upload_threshold_bytes: int | None = None
|
||||
supports_url_references: bool = False
|
||||
|
||||
|
||||
ANTHROPIC_CONSTRAINTS = ProviderConstraints(
|
||||
name="anthropic",
|
||||
image=ImageConstraints(
|
||||
max_size_bytes=5_242_880,
|
||||
max_size_bytes=5_242_880, # 5 MB per image
|
||||
max_width=8000,
|
||||
max_height=8000,
|
||||
max_images_per_request=100,
|
||||
),
|
||||
pdf=PDFConstraints(
|
||||
max_size_bytes=31_457_280,
|
||||
max_size_bytes=33_554_432, # 32 MB request size limit
|
||||
max_pages=100,
|
||||
),
|
||||
supports_file_upload=True,
|
||||
file_upload_threshold_bytes=5_242_880,
|
||||
supports_url_references=True,
|
||||
)
|
||||
|
||||
OPENAI_CONSTRAINTS = ProviderConstraints(
|
||||
@@ -181,8 +185,13 @@ OPENAI_CONSTRAINTS = ProviderConstraints(
|
||||
max_size_bytes=20_971_520,
|
||||
max_images_per_request=10,
|
||||
),
|
||||
audio=AudioConstraints(
|
||||
max_size_bytes=26_214_400, # 25 MB - whisper limit
|
||||
max_duration_seconds=1500, # 25 minutes, arbitrary-ish, this is from the transcriptions limit
|
||||
),
|
||||
supports_file_upload=True,
|
||||
file_upload_threshold_bytes=5_242_880,
|
||||
supports_url_references=True,
|
||||
)
|
||||
|
||||
GEMINI_CONSTRAINTS = ProviderConstraints(
|
||||
@@ -196,14 +205,17 @@ GEMINI_CONSTRAINTS = ProviderConstraints(
|
||||
),
|
||||
audio=AudioConstraints(
|
||||
max_size_bytes=104_857_600,
|
||||
max_duration_seconds=34200, # 9.5 hours
|
||||
supported_formats=GEMINI_AUDIO_FORMATS,
|
||||
),
|
||||
video=VideoConstraints(
|
||||
max_size_bytes=2_147_483_648,
|
||||
max_duration_seconds=3600, # 1 hour at default resolution
|
||||
supported_formats=GEMINI_VIDEO_FORMATS,
|
||||
),
|
||||
supports_file_upload=True,
|
||||
file_upload_threshold_bytes=20_971_520,
|
||||
supports_url_references=True,
|
||||
)
|
||||
|
||||
BEDROCK_CONSTRAINTS = ProviderConstraints(
|
||||
@@ -225,6 +237,11 @@ AZURE_CONSTRAINTS = ProviderConstraints(
|
||||
max_size_bytes=20_971_520,
|
||||
max_images_per_request=10,
|
||||
),
|
||||
audio=AudioConstraints(
|
||||
max_size_bytes=26_214_400, # 25 MB - same as openai
|
||||
max_duration_seconds=1500, # 25 minutes - same as openai
|
||||
),
|
||||
supports_url_references=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ def _get_image_dimensions(content: bytes) -> tuple[int, int] | None:
|
||||
|
||||
with Image.open(io.BytesIO(content)) as img:
|
||||
width, height = img.size
|
||||
return (int(width), int(height))
|
||||
return int(width), int(height)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"Pillow not installed - cannot validate image dimensions. "
|
||||
@@ -74,6 +74,81 @@ def _get_pdf_page_count(content: bytes) -> int | None:
|
||||
return None
|
||||
|
||||
|
||||
def _get_audio_duration(content: bytes, filename: str | None = None) -> float | None:
|
||||
"""Get audio duration in seconds using tinytag if available.
|
||||
|
||||
Args:
|
||||
content: Raw audio bytes.
|
||||
filename: Optional filename for format detection hint.
|
||||
|
||||
Returns:
|
||||
Duration in seconds or None if tinytag unavailable.
|
||||
"""
|
||||
try:
|
||||
from tinytag import TinyTag # type: ignore[import-untyped]
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"tinytag not installed - cannot validate audio duration. "
|
||||
"Install with: pip install tinytag"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
tag = TinyTag.get(file_obj=io.BytesIO(content), filename=filename)
|
||||
duration: float | None = tag.duration
|
||||
return duration
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not determine audio duration: {e}")
|
||||
return None
|
||||
|
||||
|
||||
_VIDEO_FORMAT_MAP: dict[str, str] = {
|
||||
"video/mp4": "mp4",
|
||||
"video/webm": "webm",
|
||||
"video/x-matroska": "matroska",
|
||||
"video/quicktime": "mov",
|
||||
"video/x-msvideo": "avi",
|
||||
"video/x-flv": "flv",
|
||||
}
|
||||
|
||||
|
||||
def _get_video_duration(
|
||||
content: bytes, content_type: str | None = None
|
||||
) -> float | None:
|
||||
"""Get video duration in seconds using av if available.
|
||||
|
||||
Args:
|
||||
content: Raw video bytes.
|
||||
content_type: Optional MIME type for format detection hint.
|
||||
|
||||
Returns:
|
||||
Duration in seconds or None if av unavailable.
|
||||
"""
|
||||
try:
|
||||
import av
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"av (PyAV) not installed - cannot validate video duration. "
|
||||
"Install with: pip install av"
|
||||
)
|
||||
return None
|
||||
|
||||
format_hint = _VIDEO_FORMAT_MAP.get(content_type) if content_type else None
|
||||
|
||||
try:
|
||||
container = av.open(io.BytesIO(content), format=format_hint) # type: ignore[attr-defined]
|
||||
try:
|
||||
duration = getattr(container, "duration", None)
|
||||
if duration is None:
|
||||
return None
|
||||
return float(duration) / 1_000_000
|
||||
finally:
|
||||
container.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not determine video duration: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _format_size(size_bytes: int) -> str:
|
||||
"""Format byte size to human-readable string."""
|
||||
if size_bytes >= 1024 * 1024 * 1024:
|
||||
@@ -273,14 +348,17 @@ def validate_audio(
|
||||
|
||||
Raises:
|
||||
FileTooLargeError: If the file exceeds size limits.
|
||||
FileValidationError: If the file exceeds duration limits.
|
||||
UnsupportedFileTypeError: If the format is not supported.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
file_size = len(file.read())
|
||||
content = file.read()
|
||||
file_size = len(content)
|
||||
filename = file.filename
|
||||
|
||||
_validate_size(
|
||||
"Audio",
|
||||
file.filename,
|
||||
filename,
|
||||
file_size,
|
||||
constraints.max_size_bytes,
|
||||
errors,
|
||||
@@ -288,13 +366,24 @@ def validate_audio(
|
||||
)
|
||||
_validate_format(
|
||||
"Audio",
|
||||
file.filename,
|
||||
filename,
|
||||
file.content_type,
|
||||
constraints.supported_formats,
|
||||
errors,
|
||||
raise_on_error,
|
||||
)
|
||||
|
||||
if constraints.max_duration_seconds is not None:
|
||||
duration = _get_audio_duration(content, filename)
|
||||
if duration is not None and duration > constraints.max_duration_seconds:
|
||||
msg = (
|
||||
f"Audio '{filename}' duration ({duration:.1f}s) exceeds "
|
||||
f"maximum ({constraints.max_duration_seconds}s)"
|
||||
)
|
||||
errors.append(msg)
|
||||
if raise_on_error:
|
||||
raise FileValidationError(msg, file_name=filename)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
@@ -316,14 +405,17 @@ def validate_video(
|
||||
|
||||
Raises:
|
||||
FileTooLargeError: If the file exceeds size limits.
|
||||
FileValidationError: If the file exceeds duration limits.
|
||||
UnsupportedFileTypeError: If the format is not supported.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
file_size = len(file.read())
|
||||
content = file.read()
|
||||
file_size = len(content)
|
||||
filename = file.filename
|
||||
|
||||
_validate_size(
|
||||
"Video",
|
||||
file.filename,
|
||||
filename,
|
||||
file_size,
|
||||
constraints.max_size_bytes,
|
||||
errors,
|
||||
@@ -331,13 +423,24 @@ def validate_video(
|
||||
)
|
||||
_validate_format(
|
||||
"Video",
|
||||
file.filename,
|
||||
filename,
|
||||
file.content_type,
|
||||
constraints.supported_formats,
|
||||
errors,
|
||||
raise_on_error,
|
||||
)
|
||||
|
||||
if constraints.max_duration_seconds is not None:
|
||||
duration = _get_video_duration(content, file.content_type)
|
||||
if duration is not None and duration > constraints.max_duration_seconds:
|
||||
msg = (
|
||||
f"Video '{filename}' duration ({duration:.1f}s) exceeds "
|
||||
f"maximum ({constraints.max_duration_seconds}s)"
|
||||
)
|
||||
errors.append(msg)
|
||||
if raise_on_error:
|
||||
raise FileValidationError(msg, file_name=filename)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import logging
|
||||
|
||||
from crewai.files.constants import UPLOAD_MAX_RETRIES, UPLOAD_RETRY_DELAY_BASE
|
||||
from crewai.files.content_types import FileInput
|
||||
from crewai.files.file import FileUrl
|
||||
from crewai.files.metrics import measure_operation
|
||||
from crewai.files.processing.constraints import (
|
||||
AudioConstraints,
|
||||
@@ -22,10 +23,12 @@ from crewai.files.resolved import (
|
||||
InlineBase64,
|
||||
InlineBytes,
|
||||
ResolvedFile,
|
||||
UrlReference,
|
||||
)
|
||||
from crewai.files.upload_cache import CachedUpload, UploadCache
|
||||
from crewai.files.uploaders import UploadResult, get_uploader
|
||||
from crewai.files.uploaders.base import FileUploader
|
||||
from crewai.files.uploaders.factory import ProviderType
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -102,7 +105,49 @@ class FileResolver:
|
||||
content_type=file.content_type,
|
||||
)
|
||||
|
||||
def resolve(self, file: FileInput, provider: str) -> ResolvedFile:
|
||||
@staticmethod
|
||||
def _is_url_source(file: FileInput) -> bool:
|
||||
"""Check if file source is a URL.
|
||||
|
||||
Args:
|
||||
file: The file to check.
|
||||
|
||||
Returns:
|
||||
True if the file source is a FileUrl, False otherwise.
|
||||
"""
|
||||
return isinstance(file._file_source, FileUrl)
|
||||
|
||||
@staticmethod
|
||||
def _supports_url(constraints: ProviderConstraints | None) -> bool:
|
||||
"""Check if provider supports URL references.
|
||||
|
||||
Args:
|
||||
constraints: Provider constraints.
|
||||
|
||||
Returns:
|
||||
True if the provider supports URL references, False otherwise.
|
||||
"""
|
||||
return constraints is not None and constraints.supports_url_references
|
||||
|
||||
@staticmethod
|
||||
def _resolve_as_url(file: FileInput) -> UrlReference:
|
||||
"""Resolve a URL source as UrlReference.
|
||||
|
||||
Args:
|
||||
file: The file with URL source.
|
||||
|
||||
Returns:
|
||||
UrlReference with the URL and content type.
|
||||
"""
|
||||
source = file._file_source
|
||||
if not isinstance(source, FileUrl):
|
||||
raise TypeError(f"Expected FileUrl source, got {type(source).__name__}")
|
||||
return UrlReference(
|
||||
content_type=file.content_type,
|
||||
url=source.url,
|
||||
)
|
||||
|
||||
def resolve(self, file: FileInput, provider: ProviderType) -> ResolvedFile:
|
||||
"""Resolve a file to its delivery format for a provider.
|
||||
|
||||
Args:
|
||||
@@ -112,25 +157,26 @@ class FileResolver:
|
||||
Returns:
|
||||
ResolvedFile representing the appropriate delivery format.
|
||||
"""
|
||||
provider_lower = provider.lower()
|
||||
constraints = get_constraints_for_provider(provider)
|
||||
|
||||
if self._is_url_source(file) and self._supports_url(constraints):
|
||||
return self._resolve_as_url(file)
|
||||
|
||||
context = self._build_file_context(file)
|
||||
|
||||
should_upload = self._should_upload(
|
||||
file, provider_lower, constraints, context.size
|
||||
)
|
||||
should_upload = self._should_upload(file, provider, constraints, context.size)
|
||||
|
||||
if should_upload:
|
||||
resolved = self._resolve_via_upload(file, provider_lower, context)
|
||||
resolved = self._resolve_via_upload(file, provider, context)
|
||||
if resolved is not None:
|
||||
return resolved
|
||||
|
||||
return self._resolve_inline(file, provider_lower, context)
|
||||
return self._resolve_inline(file, provider, context)
|
||||
|
||||
def resolve_files(
|
||||
self,
|
||||
files: dict[str, FileInput],
|
||||
provider: str,
|
||||
provider: ProviderType,
|
||||
) -> dict[str, ResolvedFile]:
|
||||
"""Resolve multiple files for a provider.
|
||||
|
||||
@@ -220,7 +266,7 @@ class FileResolver:
|
||||
def _resolve_via_upload(
|
||||
self,
|
||||
file: FileInput,
|
||||
provider: str,
|
||||
provider: ProviderType,
|
||||
context: FileContext,
|
||||
) -> ResolvedFile | None:
|
||||
"""Resolve a file by uploading it.
|
||||
@@ -367,7 +413,7 @@ class FileResolver:
|
||||
data=encoded,
|
||||
)
|
||||
|
||||
async def aresolve(self, file: FileInput, provider: str) -> ResolvedFile:
|
||||
async def aresolve(self, file: FileInput, provider: ProviderType) -> ResolvedFile:
|
||||
"""Async resolve a file to its delivery format for a provider.
|
||||
|
||||
Args:
|
||||
@@ -377,25 +423,26 @@ class FileResolver:
|
||||
Returns:
|
||||
ResolvedFile representing the appropriate delivery format.
|
||||
"""
|
||||
provider_lower = provider.lower()
|
||||
constraints = get_constraints_for_provider(provider)
|
||||
|
||||
if self._is_url_source(file) and self._supports_url(constraints):
|
||||
return self._resolve_as_url(file)
|
||||
|
||||
context = self._build_file_context(file)
|
||||
|
||||
should_upload = self._should_upload(
|
||||
file, provider_lower, constraints, context.size
|
||||
)
|
||||
should_upload = self._should_upload(file, provider, constraints, context.size)
|
||||
|
||||
if should_upload:
|
||||
resolved = await self._aresolve_via_upload(file, provider_lower, context)
|
||||
resolved = await self._aresolve_via_upload(file, provider, context)
|
||||
if resolved is not None:
|
||||
return resolved
|
||||
|
||||
return self._resolve_inline(file, provider_lower, context)
|
||||
return self._resolve_inline(file, provider, context)
|
||||
|
||||
async def aresolve_files(
|
||||
self,
|
||||
files: dict[str, FileInput],
|
||||
provider: str,
|
||||
provider: ProviderType,
|
||||
max_concurrency: int = 10,
|
||||
) -> dict[str, ResolvedFile]:
|
||||
"""Async resolve multiple files in parallel.
|
||||
@@ -434,7 +481,7 @@ class FileResolver:
|
||||
async def _aresolve_via_upload(
|
||||
self,
|
||||
file: FileInput,
|
||||
provider: str,
|
||||
provider: ProviderType,
|
||||
context: FileContext,
|
||||
) -> ResolvedFile | None:
|
||||
"""Async resolve a file by uploading it.
|
||||
@@ -552,7 +599,7 @@ class FileResolver:
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_uploader(self, provider: str) -> FileUploader | None:
|
||||
def _get_uploader(self, provider: ProviderType) -> FileUploader | None:
|
||||
"""Get or create an uploader for a provider.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -15,9 +15,9 @@ if TYPE_CHECKING:
|
||||
|
||||
def is_file_source(v: object) -> TypeIs[FileSource]:
|
||||
"""Type guard to narrow input to FileSource."""
|
||||
from crewai.files.file import FileBytes, FilePath, FileStream
|
||||
from crewai.files.file import FileBytes, FilePath, FileStream, FileUrl
|
||||
|
||||
return isinstance(v, (FilePath, FileBytes, FileStream))
|
||||
return isinstance(v, (FilePath, FileBytes, FileStream, FileUrl))
|
||||
|
||||
|
||||
def wrap_file_source(source: FileSource) -> FileInput:
|
||||
@@ -62,7 +62,7 @@ def normalize_input_files(
|
||||
Dictionary mapping names to FileInput wrappers.
|
||||
"""
|
||||
from crewai.files.content_types import BaseFile
|
||||
from crewai.files.file import FileBytes, FilePath, FileStream
|
||||
from crewai.files.file import FileBytes, FilePath, FileStream, FileUrl
|
||||
|
||||
result: dict[str, FileInput] = {}
|
||||
|
||||
@@ -74,13 +74,16 @@ def normalize_input_files(
|
||||
result[name] = item
|
||||
continue
|
||||
|
||||
file_source: FilePath | FileBytes | FileStream
|
||||
if isinstance(item, (FilePath, FileBytes, FileStream)):
|
||||
file_source: FilePath | FileBytes | FileStream | FileUrl
|
||||
if isinstance(item, (FilePath, FileBytes, FileStream, FileUrl)):
|
||||
file_source = item
|
||||
elif isinstance(item, Path):
|
||||
file_source = FilePath(path=item)
|
||||
elif isinstance(item, str):
|
||||
file_source = FilePath(path=Path(item))
|
||||
if item.startswith(("http://", "https://")):
|
||||
file_source = FileUrl(url=item)
|
||||
else:
|
||||
file_source = FilePath(path=Path(item))
|
||||
elif isinstance(item, (bytes, memoryview)):
|
||||
file_source = FileBytes(data=bytes(item))
|
||||
else:
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
"""Tests for file validators."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.files import FileBytes, ImageFile, PDFFile, TextFile
|
||||
from crewai.files import AudioFile, FileBytes, ImageFile, PDFFile, TextFile, VideoFile
|
||||
from crewai.files.processing.constraints import (
|
||||
ANTHROPIC_CONSTRAINTS,
|
||||
AudioConstraints,
|
||||
ImageConstraints,
|
||||
PDFConstraints,
|
||||
ProviderConstraints,
|
||||
VideoConstraints,
|
||||
)
|
||||
from crewai.files.processing.exceptions import (
|
||||
FileTooLargeError,
|
||||
@@ -15,10 +19,14 @@ from crewai.files.processing.exceptions import (
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from crewai.files.processing.validators import (
|
||||
_get_audio_duration,
|
||||
_get_video_duration,
|
||||
validate_audio,
|
||||
validate_file,
|
||||
validate_image,
|
||||
validate_pdf,
|
||||
validate_text,
|
||||
validate_video,
|
||||
)
|
||||
|
||||
|
||||
@@ -206,3 +214,281 @@ class TestValidateFile:
|
||||
validate_file(file, constraints)
|
||||
|
||||
assert "does not support PDFs" in str(exc_info.value)
|
||||
|
||||
|
||||
# Minimal audio bytes for testing (not a valid audio file, used for mocked tests)
|
||||
MINIMAL_AUDIO = b"\x00" * 100
|
||||
|
||||
# Minimal video bytes for testing (not a valid video file, used for mocked tests)
|
||||
MINIMAL_VIDEO = b"\x00" * 100
|
||||
|
||||
# Fallback content type when python-magic cannot detect
|
||||
FALLBACK_CONTENT_TYPE = "application/octet-stream"
|
||||
|
||||
|
||||
class TestValidateAudio:
|
||||
"""Tests for validate_audio function and audio duration validation."""
|
||||
|
||||
def test_validate_valid_audio(self):
|
||||
"""Test validating a valid audio file within constraints."""
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
errors = validate_audio(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_audio_too_large(self):
|
||||
"""Test validating an audio file that exceeds size limit."""
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10, # Very small limit
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
with pytest.raises(FileTooLargeError) as exc_info:
|
||||
validate_audio(file, constraints)
|
||||
|
||||
assert "exceeds" in str(exc_info.value)
|
||||
assert exc_info.value.file_name == "test.mp3"
|
||||
|
||||
def test_validate_audio_unsupported_format(self):
|
||||
"""Test validating an audio file with unsupported format."""
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
supported_formats=("audio/wav",), # Only WAV
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError) as exc_info:
|
||||
validate_audio(file, constraints)
|
||||
|
||||
assert "not supported" in str(exc_info.value)
|
||||
|
||||
@patch("crewai.files.processing.validators._get_audio_duration")
|
||||
def test_validate_audio_duration_passes(self, mock_get_duration):
|
||||
"""Test validating audio when duration is under limit."""
|
||||
mock_get_duration.return_value = 30.0
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
errors = validate_audio(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
mock_get_duration.assert_called_once()
|
||||
|
||||
@patch("crewai.files.processing.validators._get_audio_duration")
|
||||
def test_validate_audio_duration_fails(self, mock_get_duration):
|
||||
"""Test validating audio when duration exceeds limit."""
|
||||
mock_get_duration.return_value = 120.5
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
with pytest.raises(FileValidationError) as exc_info:
|
||||
validate_audio(file, constraints)
|
||||
|
||||
assert "duration" in str(exc_info.value).lower()
|
||||
assert "120.5s" in str(exc_info.value)
|
||||
assert "60s" in str(exc_info.value)
|
||||
|
||||
@patch("crewai.files.processing.validators._get_audio_duration")
|
||||
def test_validate_audio_duration_no_raise(self, mock_get_duration):
|
||||
"""Test audio duration validation with raise_on_error=False."""
|
||||
mock_get_duration.return_value = 120.5
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
errors = validate_audio(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert "duration" in errors[0].lower()
|
||||
|
||||
@patch("crewai.files.processing.validators._get_audio_duration")
|
||||
def test_validate_audio_duration_none_skips(self, mock_get_duration):
|
||||
"""Test that duration validation is skipped when max_duration_seconds is None."""
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=None,
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
errors = validate_audio(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
mock_get_duration.assert_not_called()
|
||||
|
||||
@patch("crewai.files.processing.validators._get_audio_duration")
|
||||
def test_validate_audio_duration_detection_returns_none(self, mock_get_duration):
|
||||
"""Test that validation passes when duration detection returns None."""
|
||||
mock_get_duration.return_value = None
|
||||
constraints = AudioConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("audio/mp3", "audio/mpeg", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = AudioFile(source=FileBytes(data=MINIMAL_AUDIO, filename="test.mp3"))
|
||||
|
||||
errors = validate_audio(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
|
||||
class TestValidateVideo:
|
||||
"""Tests for validate_video function and video duration validation."""
|
||||
|
||||
def test_validate_valid_video(self):
|
||||
"""Test validating a valid video file within constraints."""
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
errors = validate_video(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_validate_video_too_large(self):
|
||||
"""Test validating a video file that exceeds size limit."""
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10, # Very small limit
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
with pytest.raises(FileTooLargeError) as exc_info:
|
||||
validate_video(file, constraints)
|
||||
|
||||
assert "exceeds" in str(exc_info.value)
|
||||
assert exc_info.value.file_name == "test.mp4"
|
||||
|
||||
def test_validate_video_unsupported_format(self):
|
||||
"""Test validating a video file with unsupported format."""
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
supported_formats=("video/webm",), # Only WebM
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError) as exc_info:
|
||||
validate_video(file, constraints)
|
||||
|
||||
assert "not supported" in str(exc_info.value)
|
||||
|
||||
@patch("crewai.files.processing.validators._get_video_duration")
|
||||
def test_validate_video_duration_passes(self, mock_get_duration):
|
||||
"""Test validating video when duration is under limit."""
|
||||
mock_get_duration.return_value = 30.0
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
errors = validate_video(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
mock_get_duration.assert_called_once()
|
||||
|
||||
@patch("crewai.files.processing.validators._get_video_duration")
|
||||
def test_validate_video_duration_fails(self, mock_get_duration):
|
||||
"""Test validating video when duration exceeds limit."""
|
||||
mock_get_duration.return_value = 180.0
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
with pytest.raises(FileValidationError) as exc_info:
|
||||
validate_video(file, constraints)
|
||||
|
||||
assert "duration" in str(exc_info.value).lower()
|
||||
assert "180.0s" in str(exc_info.value)
|
||||
assert "60s" in str(exc_info.value)
|
||||
|
||||
@patch("crewai.files.processing.validators._get_video_duration")
|
||||
def test_validate_video_duration_no_raise(self, mock_get_duration):
|
||||
"""Test video duration validation with raise_on_error=False."""
|
||||
mock_get_duration.return_value = 180.0
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
errors = validate_video(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert "duration" in errors[0].lower()
|
||||
|
||||
@patch("crewai.files.processing.validators._get_video_duration")
|
||||
def test_validate_video_duration_none_skips(self, mock_get_duration):
|
||||
"""Test that duration validation is skipped when max_duration_seconds is None."""
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=None,
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
errors = validate_video(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
mock_get_duration.assert_not_called()
|
||||
|
||||
@patch("crewai.files.processing.validators._get_video_duration")
|
||||
def test_validate_video_duration_detection_returns_none(self, mock_get_duration):
|
||||
"""Test that validation passes when duration detection returns None."""
|
||||
mock_get_duration.return_value = None
|
||||
constraints = VideoConstraints(
|
||||
max_size_bytes=10 * 1024 * 1024,
|
||||
max_duration_seconds=60,
|
||||
supported_formats=("video/mp4", FALLBACK_CONTENT_TYPE),
|
||||
)
|
||||
file = VideoFile(source=FileBytes(data=MINIMAL_VIDEO, filename="test.mp4"))
|
||||
|
||||
errors = validate_video(file, constraints, raise_on_error=False)
|
||||
|
||||
assert len(errors) == 0
|
||||
|
||||
|
||||
class TestGetAudioDuration:
|
||||
"""Tests for _get_audio_duration helper function."""
|
||||
|
||||
def test_get_audio_duration_corrupt_file(self):
|
||||
"""Test handling of corrupt audio data."""
|
||||
corrupt_data = b"not valid audio data at all"
|
||||
result = _get_audio_duration(corrupt_data)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetVideoDuration:
|
||||
"""Tests for _get_video_duration helper function."""
|
||||
|
||||
def test_get_video_duration_corrupt_file(self):
|
||||
"""Test handling of corrupt video data."""
|
||||
corrupt_data = b"not valid video data at all"
|
||||
result = _get_video_duration(corrupt_data)
|
||||
|
||||
assert result is None
|
||||
|
||||
312
lib/crewai/tests/files/test_file_url.py
Normal file
312
lib/crewai/tests/files/test_file_url.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""Tests for FileUrl source type and URL resolution."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.files import FileBytes, FileUrl, ImageFile
|
||||
from crewai.files.file import _normalize_source, FilePath
|
||||
from crewai.files.resolved import InlineBase64, UrlReference
|
||||
from crewai.files.resolver import FileResolver
|
||||
|
||||
|
||||
class TestFileUrl:
|
||||
"""Tests for FileUrl source type."""
|
||||
|
||||
def test_create_file_url(self):
|
||||
"""Test creating FileUrl with valid URL."""
|
||||
url = FileUrl(url="https://example.com/image.png")
|
||||
|
||||
assert url.url == "https://example.com/image.png"
|
||||
assert url.filename is None
|
||||
|
||||
def test_create_file_url_with_filename(self):
|
||||
"""Test creating FileUrl with custom filename."""
|
||||
url = FileUrl(url="https://example.com/image.png", filename="custom.png")
|
||||
|
||||
assert url.url == "https://example.com/image.png"
|
||||
assert url.filename == "custom.png"
|
||||
|
||||
def test_invalid_url_scheme_raises(self):
|
||||
"""Test that non-http(s) URLs raise ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid URL scheme"):
|
||||
FileUrl(url="ftp://example.com/file.txt")
|
||||
|
||||
def test_invalid_url_scheme_file_raises(self):
|
||||
"""Test that file:// URLs raise ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid URL scheme"):
|
||||
FileUrl(url="file:///path/to/file.txt")
|
||||
|
||||
def test_http_url_valid(self):
|
||||
"""Test that HTTP URLs are valid."""
|
||||
url = FileUrl(url="http://example.com/image.jpg")
|
||||
|
||||
assert url.url == "http://example.com/image.jpg"
|
||||
|
||||
def test_https_url_valid(self):
|
||||
"""Test that HTTPS URLs are valid."""
|
||||
url = FileUrl(url="https://example.com/image.jpg")
|
||||
|
||||
assert url.url == "https://example.com/image.jpg"
|
||||
|
||||
def test_content_type_guessing_png(self):
|
||||
"""Test content type guessing for PNG files."""
|
||||
url = FileUrl(url="https://example.com/image.png")
|
||||
|
||||
assert url.content_type == "image/png"
|
||||
|
||||
def test_content_type_guessing_jpeg(self):
|
||||
"""Test content type guessing for JPEG files."""
|
||||
url = FileUrl(url="https://example.com/photo.jpg")
|
||||
|
||||
assert url.content_type == "image/jpeg"
|
||||
|
||||
def test_content_type_guessing_pdf(self):
|
||||
"""Test content type guessing for PDF files."""
|
||||
url = FileUrl(url="https://example.com/document.pdf")
|
||||
|
||||
assert url.content_type == "application/pdf"
|
||||
|
||||
def test_content_type_guessing_with_query_params(self):
|
||||
"""Test content type guessing with URL query parameters."""
|
||||
url = FileUrl(url="https://example.com/image.png?v=123&token=abc")
|
||||
|
||||
assert url.content_type == "image/png"
|
||||
|
||||
def test_content_type_fallback_unknown(self):
|
||||
"""Test content type falls back to octet-stream for unknown extensions."""
|
||||
url = FileUrl(url="https://example.com/file.unknownext123")
|
||||
|
||||
assert url.content_type == "application/octet-stream"
|
||||
|
||||
def test_content_type_no_extension(self):
|
||||
"""Test content type for URL without extension."""
|
||||
url = FileUrl(url="https://example.com/file")
|
||||
|
||||
assert url.content_type == "application/octet-stream"
|
||||
|
||||
def test_read_fetches_content(self):
|
||||
"""Test that read() fetches content from URL."""
|
||||
url = FileUrl(url="https://example.com/image.png")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"fake image content"
|
||||
mock_response.headers = {"content-type": "image/png"}
|
||||
|
||||
with patch("httpx.get", return_value=mock_response) as mock_get:
|
||||
content = url.read()
|
||||
|
||||
mock_get.assert_called_once_with(
|
||||
"https://example.com/image.png", follow_redirects=True
|
||||
)
|
||||
assert content == b"fake image content"
|
||||
|
||||
def test_read_caches_content(self):
|
||||
"""Test that read() caches content."""
|
||||
url = FileUrl(url="https://example.com/image.png")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"fake content"
|
||||
mock_response.headers = {}
|
||||
|
||||
with patch("httpx.get", return_value=mock_response) as mock_get:
|
||||
content1 = url.read()
|
||||
content2 = url.read()
|
||||
|
||||
mock_get.assert_called_once()
|
||||
assert content1 == content2
|
||||
|
||||
def test_read_updates_content_type_from_response(self):
|
||||
"""Test that read() updates content type from response headers."""
|
||||
url = FileUrl(url="https://example.com/file")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"fake content"
|
||||
mock_response.headers = {"content-type": "image/webp; charset=utf-8"}
|
||||
|
||||
with patch("httpx.get", return_value=mock_response):
|
||||
url.read()
|
||||
|
||||
assert url.content_type == "image/webp"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aread_fetches_content(self):
|
||||
"""Test that aread() fetches content from URL asynchronously."""
|
||||
url = FileUrl(url="https://example.com/image.png")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"async fake content"
|
||||
mock_response.headers = {"content-type": "image/png"}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
content = await url.aread()
|
||||
|
||||
assert content == b"async fake content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aread_caches_content(self):
|
||||
"""Test that aread() caches content."""
|
||||
url = FileUrl(url="https://example.com/image.png")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"cached content"
|
||||
mock_response.headers = {}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
content1 = await url.aread()
|
||||
content2 = await url.aread()
|
||||
|
||||
mock_client.get.assert_called_once()
|
||||
assert content1 == content2
|
||||
|
||||
|
||||
class TestNormalizeSource:
|
||||
"""Tests for _normalize_source with URL detection."""
|
||||
|
||||
def test_normalize_url_string(self):
|
||||
"""Test that URL strings are converted to FileUrl."""
|
||||
result = _normalize_source("https://example.com/image.png")
|
||||
|
||||
assert isinstance(result, FileUrl)
|
||||
assert result.url == "https://example.com/image.png"
|
||||
|
||||
def test_normalize_http_url_string(self):
|
||||
"""Test that HTTP URL strings are converted to FileUrl."""
|
||||
result = _normalize_source("http://example.com/file.pdf")
|
||||
|
||||
assert isinstance(result, FileUrl)
|
||||
assert result.url == "http://example.com/file.pdf"
|
||||
|
||||
def test_normalize_file_path_string(self, tmp_path):
|
||||
"""Test that file path strings are converted to FilePath."""
|
||||
test_file = tmp_path / "test.png"
|
||||
test_file.write_bytes(b"test content")
|
||||
|
||||
result = _normalize_source(str(test_file))
|
||||
|
||||
assert isinstance(result, FilePath)
|
||||
|
||||
def test_normalize_relative_path_is_not_url(self):
|
||||
"""Test that relative path strings are not treated as URLs."""
|
||||
result = _normalize_source("https://example.com/file.png")
|
||||
|
||||
assert isinstance(result, FileUrl)
|
||||
assert not isinstance(result, FilePath)
|
||||
|
||||
def test_normalize_file_url_passthrough(self):
|
||||
"""Test that FileUrl instances pass through unchanged."""
|
||||
original = FileUrl(url="https://example.com/image.png")
|
||||
result = _normalize_source(original)
|
||||
|
||||
assert result is original
|
||||
|
||||
|
||||
class TestResolverUrlHandling:
|
||||
"""Tests for FileResolver URL handling."""
|
||||
|
||||
def test_resolve_url_source_for_supported_provider(self):
|
||||
"""Test URL source resolves to UrlReference for supported providers."""
|
||||
resolver = FileResolver()
|
||||
file = ImageFile(source=FileUrl(url="https://example.com/image.png"))
|
||||
|
||||
resolved = resolver.resolve(file, "anthropic")
|
||||
|
||||
assert isinstance(resolved, UrlReference)
|
||||
assert resolved.url == "https://example.com/image.png"
|
||||
assert resolved.content_type == "image/png"
|
||||
|
||||
def test_resolve_url_source_openai(self):
|
||||
"""Test URL source resolves to UrlReference for OpenAI."""
|
||||
resolver = FileResolver()
|
||||
file = ImageFile(source=FileUrl(url="https://example.com/photo.jpg"))
|
||||
|
||||
resolved = resolver.resolve(file, "openai")
|
||||
|
||||
assert isinstance(resolved, UrlReference)
|
||||
assert resolved.url == "https://example.com/photo.jpg"
|
||||
|
||||
def test_resolve_url_source_gemini(self):
|
||||
"""Test URL source resolves to UrlReference for Gemini."""
|
||||
resolver = FileResolver()
|
||||
file = ImageFile(source=FileUrl(url="https://example.com/image.webp"))
|
||||
|
||||
resolved = resolver.resolve(file, "gemini")
|
||||
|
||||
assert isinstance(resolved, UrlReference)
|
||||
assert resolved.url == "https://example.com/image.webp"
|
||||
|
||||
def test_resolve_url_source_azure(self):
|
||||
"""Test URL source resolves to UrlReference for Azure."""
|
||||
resolver = FileResolver()
|
||||
file = ImageFile(source=FileUrl(url="https://example.com/image.gif"))
|
||||
|
||||
resolved = resolver.resolve(file, "azure")
|
||||
|
||||
assert isinstance(resolved, UrlReference)
|
||||
assert resolved.url == "https://example.com/image.gif"
|
||||
|
||||
def test_resolve_url_source_bedrock_fetches_content(self):
|
||||
"""Test URL source fetches content for Bedrock (unsupported URLs)."""
|
||||
resolver = FileResolver()
|
||||
file_url = FileUrl(url="https://example.com/image.png")
|
||||
file = ImageFile(source=file_url)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"\x89PNG\r\n\x1a\n" + b"\x00" * 50
|
||||
mock_response.headers = {"content-type": "image/png"}
|
||||
|
||||
with patch("httpx.get", return_value=mock_response):
|
||||
resolved = resolver.resolve(file, "bedrock")
|
||||
|
||||
assert not isinstance(resolved, UrlReference)
|
||||
|
||||
def test_resolve_bytes_source_still_works(self):
|
||||
"""Test that bytes source still resolves normally."""
|
||||
resolver = FileResolver()
|
||||
minimal_png = (
|
||||
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x08\x00\x00\x00\x08"
|
||||
b"\x01\x00\x00\x00\x00\xf9Y\xab\xcd\x00\x00\x00\nIDATx\x9cc`\x00\x00"
|
||||
b"\x00\x02\x00\x01\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82"
|
||||
)
|
||||
file = ImageFile(source=FileBytes(data=minimal_png, filename="test.png"))
|
||||
|
||||
resolved = resolver.resolve(file, "anthropic")
|
||||
|
||||
assert isinstance(resolved, InlineBase64)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aresolve_url_source(self):
|
||||
"""Test async URL resolution for supported provider."""
|
||||
resolver = FileResolver()
|
||||
file = ImageFile(source=FileUrl(url="https://example.com/image.png"))
|
||||
|
||||
resolved = await resolver.aresolve(file, "anthropic")
|
||||
|
||||
assert isinstance(resolved, UrlReference)
|
||||
assert resolved.url == "https://example.com/image.png"
|
||||
|
||||
|
||||
class TestImageFileWithUrl:
|
||||
"""Tests for creating ImageFile with URL source."""
|
||||
|
||||
def test_image_file_from_url_string(self):
|
||||
"""Test creating ImageFile from URL string."""
|
||||
file = ImageFile(source="https://example.com/image.png")
|
||||
|
||||
assert isinstance(file.source, FileUrl)
|
||||
assert file.source.url == "https://example.com/image.png"
|
||||
|
||||
def test_image_file_from_file_url(self):
|
||||
"""Test creating ImageFile from FileUrl instance."""
|
||||
url = FileUrl(url="https://example.com/photo.jpg")
|
||||
file = ImageFile(source=url)
|
||||
|
||||
assert file.source is url
|
||||
assert file.content_type == "image/jpeg"
|
||||
Reference in New Issue
Block a user