feat: add core file types and content detection

This commit is contained in:
Greyson LaLonde
2026-01-21 18:23:36 -05:00
parent 741bf12bf4
commit 22f1e21d69
4 changed files with 468 additions and 2 deletions

View File

@@ -0,0 +1,214 @@
"""Content-type specific file classes."""
from __future__ import annotations
from abc import ABC
from pathlib import Path
from typing import Literal, Self
from pydantic import BaseModel, Field, field_validator
from crewai.utilities.files.file import (
FileBytes,
FilePath,
FileSource,
FileStream,
)
FileMode = Literal["strict", "auto", "warn", "chunk"]
ImageExtension = Literal[
".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".tiff", ".tif", ".svg"
]
ImageContentType = Literal[
"image/png",
"image/jpeg",
"image/gif",
"image/webp",
"image/bmp",
"image/tiff",
"image/svg+xml",
]
PDFExtension = Literal[".pdf"]
PDFContentType = Literal["application/pdf"]
TextExtension = Literal[
".txt",
".md",
".rst",
".csv",
".json",
".xml",
".yaml",
".yml",
".html",
".htm",
".log",
".ini",
".cfg",
".conf",
]
TextContentType = Literal[
"text/plain",
"text/markdown",
"text/csv",
"application/json",
"application/xml",
"text/xml",
"application/x-yaml",
"text/yaml",
"text/html",
]
AudioExtension = Literal[
".mp3", ".wav", ".ogg", ".flac", ".aac", ".m4a", ".wma", ".aiff", ".opus"
]
AudioContentType = Literal[
"audio/mpeg",
"audio/wav",
"audio/x-wav",
"audio/ogg",
"audio/flac",
"audio/aac",
"audio/mp4",
"audio/x-ms-wma",
"audio/aiff",
"audio/opus",
]
VideoExtension = Literal[
".mp4", ".avi", ".mkv", ".mov", ".webm", ".flv", ".wmv", ".m4v", ".mpeg", ".mpg"
]
VideoContentType = Literal[
"video/mp4",
"video/x-msvideo",
"video/x-matroska",
"video/quicktime",
"video/webm",
"video/x-flv",
"video/x-ms-wmv",
"video/mpeg",
]
class BaseFile(ABC, BaseModel):
"""Abstract base class for typed file wrappers.
Provides common functionality for all file types including:
- File source management
- Content reading
- Dict unpacking support (`**` syntax)
- Per-file mode mode
Can be unpacked with ** syntax: `{**ImageFile(source="./chart.png")}`
which unpacks to: `{"chart": <ImageFile instance>}` using filename stem as key.
Attributes:
source: The underlying file source (path, bytes, or stream).
mode: How to handle this file if it exceeds provider limits.
"""
source: FileSource = Field(description="The underlying file source.")
mode: FileMode = Field(
default="auto",
description="How to handle if file exceeds limits: strict, auto, warn, chunk.",
)
@field_validator("source", mode="before")
@classmethod
def _normalize_source(cls, v: str | Path | bytes | FileSource) -> FileSource:
"""Convert raw input to appropriate source type."""
if isinstance(v, (FilePath, FileBytes, FileStream)):
return v
if isinstance(v, Path):
return FilePath(path=v)
if isinstance(v, str):
return FilePath(path=Path(v))
if isinstance(v, bytes):
return FileBytes(data=v)
if hasattr(v, "read") and hasattr(v, "seek"):
return FileStream(stream=v)
raise ValueError(f"Cannot convert {type(v).__name__} to file source")
@property
def filename(self) -> str | None:
"""Get the filename from the source."""
return self.source.filename
@property
def content_type(self) -> str:
"""Get the content type from the source."""
return self.source.content_type
def read(self) -> bytes:
"""Read the file content as bytes."""
return self.source.read()
def read_text(self, encoding: str = "utf-8") -> str:
"""Read the file content as string."""
return self.read().decode(encoding)
@property
def _unpack_key(self) -> str:
"""Get the key to use when unpacking (filename stem)."""
if self.source.filename:
return Path(self.source.filename).stem
return "file"
def keys(self) -> list[str]:
"""Return keys for dict unpacking."""
return [self._unpack_key]
def __getitem__(self, key: str) -> Self:
"""Return self for dict unpacking."""
if key == self._unpack_key:
return self
raise KeyError(key)
class ImageFile(BaseFile):
"""File representing an image.
Supports common image formats: PNG, JPEG, GIF, WebP, BMP, TIFF, SVG.
"""
class PDFFile(BaseFile):
"""File representing a PDF document."""
class TextFile(BaseFile):
"""File representing a text document.
Supports common text formats: TXT, MD, RST, CSV, JSON, XML, YAML, HTML.
"""
class AudioFile(BaseFile):
"""File representing an audio file.
Supports common audio formats: MP3, WAV, OGG, FLAC, AAC, M4A, WMA.
"""
class VideoFile(BaseFile):
"""File representing a video file.
Supports common video formats: MP4, AVI, MKV, MOV, WebM, FLV, WMV.
"""
class File(BaseFile):
"""Generic file that auto-detects the appropriate type.
Use this when you don't want to specify the exact file type.
The content type is automatically detected from the file contents.
Example:
>>> file = File(source="./document.pdf")
>>> file = File(source="./image.png")
>>> file = File(source=some_bytes)
"""

View File

@@ -0,0 +1,158 @@
"""Base file class for handling file inputs in tasks."""
from __future__ import annotations
from pathlib import Path
from typing import Annotated, Any, BinaryIO, cast
import magic
from pydantic import (
BaseModel,
BeforeValidator,
Field,
GetCoreSchemaHandler,
PrivateAttr,
model_validator,
)
from pydantic_core import CoreSchema, core_schema
def detect_content_type(data: bytes) -> str:
"""Detect MIME type from file content.
Args:
data: Raw bytes to analyze.
Returns:
The detected MIME type.
"""
return magic.from_buffer(data, mime=True)
class _BinaryIOValidator:
"""Pydantic validator for BinaryIO types."""
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_plain_validator_function(
cls._validate,
serialization=core_schema.plain_serializer_function_ser_schema(
lambda x: None, info_arg=False
),
)
@staticmethod
def _validate(value: Any) -> BinaryIO:
if hasattr(value, "read") and hasattr(value, "seek"):
return cast(BinaryIO, value)
raise ValueError("Expected a binary file-like object with read() and seek()")
ValidatedBinaryIO = Annotated[BinaryIO, _BinaryIOValidator()]
class FilePath(BaseModel):
"""File loaded from a filesystem path."""
path: Path = Field(description="Path to the file on the filesystem.")
_content: bytes | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _validate_file_exists(self) -> FilePath:
"""Validate that the file exists."""
if not self.path.exists():
raise ValueError(f"File not found: {self.path}")
if not self.path.is_file():
raise ValueError(f"Path is not a file: {self.path}")
return self
@property
def filename(self) -> str:
"""Get the filename from the path."""
return self.path.name
@property
def content_type(self) -> str:
"""Get the content type by reading file content."""
return detect_content_type(self.read())
def read(self) -> bytes:
"""Read the file content from disk."""
if self._content is None:
self._content = self.path.read_bytes()
return self._content
class FileBytes(BaseModel):
"""File created from raw bytes content."""
data: bytes = Field(description="Raw bytes content of the file.")
filename: str | None = Field(default=None, description="Optional filename.")
@property
def content_type(self) -> str:
"""Get the content type from the data."""
return detect_content_type(self.data)
def read(self) -> bytes:
"""Return the bytes content."""
return self.data
class FileStream(BaseModel):
"""File loaded from a file-like stream."""
stream: ValidatedBinaryIO = Field(description="Binary file stream.")
filename: str | None = Field(default=None, description="Optional filename.")
_content: bytes | None = PrivateAttr(default=None)
def model_post_init(self, __context: object) -> None:
"""Extract filename from stream if not provided."""
if self.filename is None:
name = getattr(self.stream, "name", None)
if name is not None:
object.__setattr__(self, "filename", Path(name).name)
@property
def content_type(self) -> str:
"""Get the content type from stream content."""
return detect_content_type(self.read())
def read(self) -> bytes:
"""Read the stream content. Content is cached after first read."""
if self._content is None:
position = self.stream.tell()
self.stream.seek(0)
self._content = self.stream.read()
self.stream.seek(position)
return self._content
def close(self) -> None:
"""Close the underlying stream."""
self.stream.close()
FileSource = FilePath | FileBytes | FileStream
def _normalize_source(value: Any) -> FileSource:
"""Convert raw input to appropriate source type."""
if isinstance(value, (FilePath, FileBytes, FileStream)):
return value
if isinstance(value, Path):
return FilePath(path=value)
if isinstance(value, str):
return FilePath(path=Path(value))
if isinstance(value, bytes):
return FileBytes(data=value)
if hasattr(value, "read") and hasattr(value, "seek"):
return FileStream(stream=value)
raise ValueError(f"Cannot convert {type(value).__name__} to file source")
RawFileInput = str | Path | bytes
FileSourceInput = Annotated[
RawFileInput | FileSource, BeforeValidator(_normalize_source)
]

View File

@@ -0,0 +1,84 @@
"""Resolved file types representing different delivery methods for file content."""
from abc import ABC
from dataclasses import dataclass
from datetime import datetime
@dataclass(frozen=True)
class ResolvedFile(ABC):
"""Base class for resolved file representations.
A ResolvedFile represents the final form of a file ready for delivery
to an LLM provider, whether inline or via reference.
Attributes:
content_type: MIME type of the file content.
"""
content_type: str
@dataclass(frozen=True)
class InlineBase64(ResolvedFile):
"""File content encoded as base64 string.
Used by most providers for inline file content in messages.
Attributes:
content_type: MIME type of the file content.
data: Base64-encoded file content.
"""
data: str
@dataclass(frozen=True)
class InlineBytes(ResolvedFile):
"""File content as raw bytes.
Used by providers like Bedrock that accept raw bytes instead of base64.
Attributes:
content_type: MIME type of the file content.
data: Raw file bytes.
"""
data: bytes
@dataclass(frozen=True)
class FileReference(ResolvedFile):
"""Reference to an uploaded file.
Used when files are uploaded via provider File APIs.
Attributes:
content_type: MIME type of the file content.
file_id: Provider-specific file identifier.
provider: Name of the provider the file was uploaded to.
expires_at: When the uploaded file expires (if applicable).
file_uri: Optional URI for accessing the file (used by Gemini).
"""
file_id: str
provider: str
expires_at: datetime | None = None
file_uri: str | None = None
@dataclass(frozen=True)
class UrlReference(ResolvedFile):
"""Reference to a file accessible via URL.
Used by providers that support fetching files from URLs.
Attributes:
content_type: MIME type of the file content.
url: URL where the file can be accessed.
"""
url: str
ResolvedFileType = InlineBase64 | InlineBytes | FileReference | UrlReference

View File

@@ -1,8 +1,8 @@
"""Types for CrewAI utilities."""
from typing import Any, Literal
from typing import Any, Literal, TypedDict
from typing_extensions import TypedDict
from crewai.utilities.files import FileInput
class LLMMessage(TypedDict):
@@ -15,3 +15,13 @@ class LLMMessage(TypedDict):
role: Literal["user", "assistant", "system"]
content: str | list[dict[str, Any]]
class KickoffInputs(TypedDict, total=False):
"""Type for crew kickoff inputs.
Attributes:
files: Named file inputs accessible to tasks during execution.
"""
files: dict[str, FileInput]