refactor: fix IDE warnings and add Literal types to constraints

- Add Literal types for ImageFormat, AudioFormat, VideoFormat, ProviderName
- Convert methods to @staticmethod where appropriate
- Remove redundant default parameter values
- Fix variable shadowing in nested functions
- Make magic import optional with mimetypes fallback
- Add docstrings to inner functions
Greyson LaLonde
2026-01-22 02:54:29 -05:00
parent 1353cb2a33
commit 0a250a45ce
9 changed files with 239 additions and 115 deletions

View File

@@ -237,9 +237,12 @@ async def acleanup_uploaded_files(
     if delete_from_provider:
         semaphore = asyncio.Semaphore(max_concurrency)

-        async def delete_one(uploader: FileUploader, upload: CachedUpload) -> bool:
+        async def delete_one(file_uploader: FileUploader, cached: CachedUpload) -> bool:
+            """Delete a single file with semaphore limiting."""
             async with semaphore:
-                return await _asafe_delete(uploader, upload.file_id, upload.provider)
+                return await _asafe_delete(
+                    file_uploader, cached.file_id, cached.provider
+                )

         tasks: list[asyncio.Task[bool]] = []
         for provider, uploads in provider_uploads.items():
@@ -251,7 +254,7 @@ async def acleanup_uploaded_files(
                 continue
             tasks.extend(
-                asyncio.create_task(delete_one(uploader, upload)) for upload in uploads
+                asyncio.create_task(delete_one(uploader, cached)) for cached in uploads
             )

     results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -291,19 +294,20 @@ async def acleanup_expired_files(
     if delete_from_provider and expired_entries:
         semaphore = asyncio.Semaphore(max_concurrency)

-        async def delete_expired(upload: CachedUpload) -> None:
+        async def delete_expired(cached: CachedUpload) -> None:
+            """Delete an expired file with semaphore limiting."""
             async with semaphore:
-                uploader = get_uploader(upload.provider)
-                if uploader is not None:
+                file_uploader = get_uploader(cached.provider)
+                if file_uploader is not None:
                     try:
-                        await uploader.adelete(upload.file_id)
+                        await file_uploader.adelete(cached.file_id)
                     except Exception as e:
                         logger.debug(
-                            f"Could not delete expired file {upload.file_id}: {e}"
+                            f"Could not delete expired file {cached.file_id}: {e}"
                         )

         await asyncio.gather(
-            *[delete_expired(upload) for upload in expired_entries],
+            *[delete_expired(cached) for cached in expired_entries],
             return_exceptions=True,
         )
@@ -337,18 +341,19 @@ async def acleanup_provider_files(
     semaphore = asyncio.Semaphore(max_concurrency)

-    async def delete_file(file_id: str) -> bool:
+    async def delete_single(target_file_id: str) -> bool:
+        """Delete a single file with semaphore limiting."""
         async with semaphore:
-            return await uploader.adelete(file_id)
+            return await uploader.adelete(target_file_id)

     if delete_all_from_provider:
         try:
             files = uploader.list_files()
             tasks = []
             for file_info in files:
-                file_id = file_info.get("id") or file_info.get("name")
-                if file_id:
-                    tasks.append(delete_file(file_id))
+                fid = file_info.get("id") or file_info.get("name")
+                if fid:
+                    tasks.append(delete_single(fid))
             results = await asyncio.gather(*tasks, return_exceptions=True)
             deleted = sum(1 for r in results if r is True)
         except Exception as e:
@@ -357,7 +362,7 @@ async def acleanup_provider_files(
         uploads = await cache.aget_all_for_provider(provider)
         tasks = []
         for upload in uploads:
-            tasks.append(delete_file(upload.file_id))
+            tasks.append(delete_single(upload.file_id))
         results = await asyncio.gather(*tasks, return_exceptions=True)
         for upload, result in zip(uploads, results, strict=False):
             if result is True:
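
Every hunk above applies the same bounded-concurrency shape: a semaphore caps in-flight deletes, and gather with return_exceptions collects failures without cancelling the rest. A minimal standalone sketch of that pattern; the file IDs and the sleep standing in for the provider delete call are hypothetical:

import asyncio

async def delete_many(file_ids: list[str], max_concurrency: int = 5) -> int:
    semaphore = asyncio.Semaphore(max_concurrency)

    async def delete_single(target_file_id: str) -> bool:
        async with semaphore:  # at most max_concurrency deletes in flight
            await asyncio.sleep(0.01)  # stand-in for the provider API call
            return True

    results = await asyncio.gather(
        *(delete_single(fid) for fid in file_ids),
        return_exceptions=True,  # one failure does not cancel the rest
    )
    return sum(1 for r in results if r is True)

print(asyncio.run(delete_many([f"file-{i}" for i in range(20)])))  # 20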

View File

@@ -3,11 +3,11 @@
 from __future__ import annotations

 from collections.abc import AsyncIterator, Iterator
+import mimetypes
 from pathlib import Path
 from typing import Annotated, Any, BinaryIO, Protocol, cast, runtime_checkable

 import aiofiles
-import magic
 from pydantic import (
     BaseModel,
     BeforeValidator,
@@ -52,17 +52,30 @@ ValidatedAsyncReadable = Annotated[AsyncReadable, _AsyncReadableValidator()]
 DEFAULT_MAX_FILE_SIZE_BYTES = 500 * 1024 * 1024  # 500MB


-def detect_content_type(data: bytes) -> str:
+def detect_content_type(data: bytes, filename: str | None = None) -> str:
     """Detect MIME type from file content.

+    Uses python-magic if available for accurate content-based detection,
+    falls back to mimetypes module using filename extension.
+
     Args:
         data: Raw bytes to analyze.
+        filename: Optional filename for extension-based fallback.

     Returns:
         The detected MIME type.
     """
-    result: str = magic.from_buffer(data, mime=True)
-    return result
+    try:
+        import magic
+
+        result: str = magic.from_buffer(data, mime=True)
+        return result
+    except ImportError:
+        if filename:
+            mime_type, _ = mimetypes.guess_type(filename)
+            if mime_type:
+                return mime_type
+        return "application/octet-stream"


 class _BinaryIOValidator:
@@ -139,7 +152,7 @@ class FilePath(BaseModel):
     @property
     def content_type(self) -> str:
         """Get the content type by reading file content."""
-        return detect_content_type(self.read())
+        return detect_content_type(self.read(), self.filename)

     def read(self) -> bytes:
         """Read the file content from disk."""
@@ -190,7 +203,7 @@ class FileBytes(BaseModel):
     @property
     def content_type(self) -> str:
         """Get the content type from the data."""
-        return detect_content_type(self.data)
+        return detect_content_type(self.data, self.filename)

     def read(self) -> bytes:
         """Return the bytes content."""
@@ -242,7 +255,7 @@ class FileStream(BaseModel):
     @property
     def content_type(self) -> str:
         """Get the content type from stream content."""
-        return detect_content_type(self.read())
+        return detect_content_type(self.read(), self.filename)

     def read(self) -> bytes:
         """Read the stream content. Content is cached after first read."""
@@ -310,7 +323,7 @@ class AsyncFileStream(BaseModel):
         """Get the content type from stream content. Requires aread() first."""
         if self._content is None:
             raise RuntimeError("Call aread() first to load content")
-        return detect_content_type(self._content)
+        return detect_content_type(self._content, self.filename)

     async def aread(self) -> bytes:
         """Async read the stream content. Content is cached after first read."""

View File

@@ -1,6 +1,99 @@
"""Provider-specific file constraints for multimodal content."""
from dataclasses import dataclass
from typing import Literal
ImageFormat = Literal[
"image/png",
"image/jpeg",
"image/gif",
"image/webp",
"image/heic",
"image/heif",
]
AudioFormat = Literal[
"audio/mp3",
"audio/mpeg",
"audio/wav",
"audio/ogg",
"audio/flac",
"audio/aac",
"audio/m4a",
"audio/opus",
]
VideoFormat = Literal[
"video/mp4",
"video/mpeg",
"video/webm",
"video/quicktime",
"video/x-msvideo",
"video/x-flv",
]
ProviderName = Literal[
"anthropic",
"openai",
"gemini",
"bedrock",
"azure",
]
# Pre-typed format tuples for common combinations
DEFAULT_IMAGE_FORMATS: tuple[ImageFormat, ...] = (
"image/png",
"image/jpeg",
"image/gif",
"image/webp",
)
GEMINI_IMAGE_FORMATS: tuple[ImageFormat, ...] = (
"image/png",
"image/jpeg",
"image/gif",
"image/webp",
"image/heic",
"image/heif",
)
DEFAULT_AUDIO_FORMATS: tuple[AudioFormat, ...] = (
"audio/mp3",
"audio/mpeg",
"audio/wav",
"audio/ogg",
"audio/flac",
"audio/aac",
"audio/m4a",
)
GEMINI_AUDIO_FORMATS: tuple[AudioFormat, ...] = (
"audio/mp3",
"audio/mpeg",
"audio/wav",
"audio/ogg",
"audio/flac",
"audio/aac",
"audio/m4a",
"audio/opus",
)
DEFAULT_VIDEO_FORMATS: tuple[VideoFormat, ...] = (
"video/mp4",
"video/mpeg",
"video/webm",
"video/quicktime",
)
GEMINI_VIDEO_FORMATS: tuple[VideoFormat, ...] = (
"video/mp4",
"video/mpeg",
"video/webm",
"video/quicktime",
"video/x-msvideo",
"video/x-flv",
)
@dataclass(frozen=True)
@@ -19,12 +112,7 @@ class ImageConstraints:
max_width: int | None = None
max_height: int | None = None
max_images_per_request: int | None = None
supported_formats: tuple[str, ...] = (
"image/png",
"image/jpeg",
"image/gif",
"image/webp",
)
supported_formats: tuple[ImageFormat, ...] = DEFAULT_IMAGE_FORMATS
@dataclass(frozen=True)
@@ -52,15 +140,7 @@ class AudioConstraints:
max_size_bytes: int
max_duration_seconds: int | None = None
supported_formats: tuple[str, ...] = (
"audio/mp3",
"audio/mpeg",
"audio/wav",
"audio/ogg",
"audio/flac",
"audio/aac",
"audio/m4a",
)
supported_formats: tuple[AudioFormat, ...] = DEFAULT_AUDIO_FORMATS
@dataclass(frozen=True)
@@ -75,12 +155,7 @@ class VideoConstraints:
max_size_bytes: int
max_duration_seconds: int | None = None
supported_formats: tuple[str, ...] = (
"video/mp4",
"video/mpeg",
"video/webm",
"video/quicktime",
)
supported_formats: tuple[VideoFormat, ...] = DEFAULT_VIDEO_FORMATS
@dataclass(frozen=True)
@@ -98,7 +173,7 @@ class ProviderConstraints:
file_upload_threshold_bytes: Size threshold above which to use file upload.
"""
name: str
name: ProviderName
image: ImageConstraints | None = None
pdf: PDFConstraints | None = None
audio: AudioConstraints | None = None
@@ -114,7 +189,6 @@ ANTHROPIC_CONSTRAINTS = ProviderConstraints(
max_size_bytes=5 * 1024 * 1024,
max_width=8000,
max_height=8000,
supported_formats=("image/png", "image/jpeg", "image/gif", "image/webp"),
),
pdf=PDFConstraints(
max_size_bytes=30 * 1024 * 1024,
@@ -129,9 +203,7 @@ OPENAI_CONSTRAINTS = ProviderConstraints(
image=ImageConstraints(
max_size_bytes=20 * 1024 * 1024,
max_images_per_request=10,
supported_formats=("image/png", "image/jpeg", "image/gif", "image/webp"),
),
pdf=None,
supports_file_upload=True,
file_upload_threshold_bytes=5 * 1024 * 1024,
)
@@ -140,41 +212,18 @@ GEMINI_CONSTRAINTS = ProviderConstraints(
name="gemini",
image=ImageConstraints(
max_size_bytes=100 * 1024 * 1024,
supported_formats=(
"image/png",
"image/jpeg",
"image/gif",
"image/webp",
"image/heic",
"image/heif",
),
supported_formats=GEMINI_IMAGE_FORMATS,
),
pdf=PDFConstraints(
max_size_bytes=50 * 1024 * 1024,
),
audio=AudioConstraints(
max_size_bytes=100 * 1024 * 1024,
supported_formats=(
"audio/mp3",
"audio/mpeg",
"audio/wav",
"audio/ogg",
"audio/flac",
"audio/aac",
"audio/m4a",
"audio/opus",
),
supported_formats=GEMINI_AUDIO_FORMATS,
),
video=VideoConstraints(
max_size_bytes=2 * 1024 * 1024 * 1024,
supported_formats=(
"video/mp4",
"video/mpeg",
"video/webm",
"video/quicktime",
"video/x-msvideo",
"video/x-flv",
),
supported_formats=GEMINI_VIDEO_FORMATS,
),
supports_file_upload=True,
file_upload_threshold_bytes=20 * 1024 * 1024,
@@ -186,7 +235,6 @@ BEDROCK_CONSTRAINTS = ProviderConstraints(
max_size_bytes=4_608_000,
max_width=8000,
max_height=8000,
supported_formats=("image/png", "image/jpeg", "image/gif", "image/webp"),
),
pdf=PDFConstraints(
max_size_bytes=3_840_000,
@@ -199,9 +247,7 @@ AZURE_CONSTRAINTS = ProviderConstraints(
image=ImageConstraints(
max_size_bytes=20 * 1024 * 1024,
max_images_per_request=10,
supported_formats=("image/png", "image/jpeg", "image/gif", "image/webp"),
),
pdf=None,
)
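
The payoff of the Literal aliases is static checking: a misspelled MIME type in a constraints tuple now fails mypy/pyright instead of passing silently as a plain str. A minimal sketch:

from typing import Literal

ImageFormat = Literal["image/png", "image/jpeg", "image/gif", "image/webp"]

ok: tuple[ImageFormat, ...] = ("image/png", "image/jpeg")
bad: tuple[ImageFormat, ...] = ("image/jpg",)  # type checker error: not a valid ImageFormat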

View File

@@ -89,7 +89,8 @@ class FileProcessor:
         raise_on_error = mode == FileHandling.STRICT
         return validate_file(file, self.constraints, raise_on_error=raise_on_error)

-    def _get_mode(self, file: FileInput) -> FileHandling:
+    @staticmethod
+    def _get_mode(file: FileInput) -> FileHandling:
         """Get the handling mode for a file.

         Args:
@@ -201,32 +202,33 @@ class FileProcessor:
         """
         semaphore = asyncio.Semaphore(max_concurrency)

-        async def process_one(
-            name: str, file: FileInput
+        async def process_single(
+            key: str, input_file: FileInput
         ) -> tuple[str, FileInput | Sequence[FileInput]]:
+            """Process a single file with semaphore limiting."""
             async with semaphore:
                 loop = asyncio.get_running_loop()
-                processed = await loop.run_in_executor(None, self.process, file)
-                return name, processed
+                result = await loop.run_in_executor(None, self.process, input_file)
+                return key, result

-        tasks = [process_one(n, f) for n, f in files.items()]
-        results = await asyncio.gather(*tasks, return_exceptions=True)
+        tasks = [process_single(n, f) for n, f in files.items()]
+        gather_results = await asyncio.gather(*tasks, return_exceptions=True)

         output: dict[str, FileInput] = {}
-        for result in results:
-            if isinstance(result, BaseException):
-                logger.error(f"Processing failed: {result}")
+        for item in gather_results:
+            if isinstance(item, BaseException):
+                logger.error(f"Processing failed: {item}")
                 continue
-            name, processed = result
+            entry_name, processed = item
             if isinstance(processed, Sequence) and not isinstance(
                 processed, (str, bytes)
             ):
                 for i, chunk in enumerate(processed):
-                    output[f"{name}_chunk_{i}"] = chunk
+                    output[f"{entry_name}_chunk_{i}"] = chunk
             elif isinstance(
                 processed, (AudioFile, File, ImageFile, PDFFile, TextFile, VideoFile)
             ):
-                output[name] = processed
+                output[entry_name] = processed

         return output
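
process_single hands the synchronous self.process call to the default thread pool so the event loop is never blocked. A self-contained sketch of that hand-off; the process function here is a stand-in for any blocking work:

import asyncio

def process(data: str) -> str:
    return data.upper()  # stand-in for blocking CPU/IO work

async def main() -> None:
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, process, "report.pdf")  # runs in a worker thread
    print(result)

asyncio.run(main())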

View File

@@ -305,7 +305,7 @@ def get_image_dimensions(file: ImageFile) -> tuple[int, int] | None:
     try:
         with Image.open(io.BytesIO(content)) as img:
             width, height = img.size
-            return (width, height)
+            return width, height
     except Exception as e:
         logger.warning(f"Failed to get image dimensions: {e}")
         return None

View File

@@ -16,7 +16,11 @@ from crewai.files.content_types import (
 )
 from crewai.files.metrics import measure_operation
 from crewai.files.processing.constraints import (
+    AudioConstraints,
+    ImageConstraints,
+    PDFConstraints,
+    ProviderConstraints,
+    VideoConstraints,
     get_constraints_for_provider,
 )
 from crewai.files.resolved import (
@@ -91,7 +95,8 @@ class FileResolver:
     upload_cache: UploadCache | None = None
     _uploaders: dict[str, FileUploader] = field(default_factory=dict)

-    def _build_file_context(self, file: FileInput) -> FileContext:
+    @staticmethod
+    def _build_file_context(file: FileInput) -> FileContext:
         """Build context by reading file once.

         Args:
@@ -149,6 +154,30 @@ class FileResolver:
         """
         return {name: self.resolve(file, provider) for name, file in files.items()}

+    @staticmethod
+    def _get_type_constraint(
+        content_type: str,
+        constraints: ProviderConstraints,
+    ) -> ImageConstraints | PDFConstraints | AudioConstraints | VideoConstraints | None:
+        """Get type-specific constraint based on content type.
+
+        Args:
+            content_type: MIME type of the file.
+            constraints: Provider constraints.
+
+        Returns:
+            Type-specific constraint or None if not found.
+        """
+        if content_type.startswith("image/"):
+            return constraints.image
+        if content_type == "application/pdf":
+            return constraints.pdf
+        if content_type.startswith("audio/"):
+            return constraints.audio
+        if content_type.startswith("video/"):
+            return constraints.video
+        return None
+
     def _should_upload(
         self,
         file: FileInput,
@@ -158,6 +187,10 @@ class FileResolver:
     ) -> bool:
         """Determine if a file should be uploaded rather than inlined.

+        Uses type-specific constraints to make smarter decisions:
+        - Checks if file exceeds type-specific inline size limits
+        - Falls back to general threshold if no type-specific constraint
+
         Args:
             file: The file to check.
             provider: Provider name.
@@ -173,8 +206,21 @@ class FileResolver:
         if self.config.prefer_upload:
             return True

+        content_type = file.content_type
+        type_constraint = self._get_type_constraint(content_type, constraints)
+
+        if type_constraint is not None:
+            # Check if file exceeds type-specific inline limit
+            if file_size > type_constraint.max_size_bytes:
+                logger.debug(
+                    f"File {file.filename} ({file_size}B) exceeds {content_type} "
+                    f"inline limit ({type_constraint.max_size_bytes}B) for {provider}"
+                )
+                return True
+
+        # Fall back to general threshold
         threshold = self.config.upload_threshold_bytes
-        if threshold is None and constraints is not None:
+        if threshold is None:
             threshold = constraints.file_upload_threshold_bytes

         if threshold is not None and file_size > threshold:
@@ -239,8 +285,8 @@ class FileResolver:
             file_uri=result.file_uri,
         )

+    @staticmethod
     def _upload_with_retry(
-        self,
         uploader: FileUploader,
         file: FileInput,
         provider: str,
@@ -312,13 +358,14 @@ class FileResolver:
         """Resolve a file as inline content.

         Args:
-            file: The file to resolve.
+            file: The file to resolve (used for logging).
             provider: Provider name.
             context: Pre-computed file context.

         Returns:
             InlineBase64 or InlineBytes depending on provider.
         """
+        logger.debug(f"Resolving {file.filename} as inline for {provider}")
         if self.config.use_bytes_for_bedrock and "bedrock" in provider:
             return InlineBytes(
                 content_type=context.content_type,
@@ -374,21 +421,24 @@ class FileResolver:
         """
         semaphore = asyncio.Semaphore(max_concurrency)

-        async def resolve_one(name: str, file: FileInput) -> tuple[str, ResolvedFile]:
+        async def resolve_single(
+            entry_key: str, input_file: FileInput
+        ) -> tuple[str, ResolvedFile]:
+            """Resolve a single file with semaphore limiting."""
             async with semaphore:
-                resolved = await self.aresolve(file, provider)
-                return name, resolved
+                entry_resolved = await self.aresolve(input_file, provider)
+                return entry_key, entry_resolved

-        tasks = [resolve_one(n, f) for n, f in files.items()]
-        results = await asyncio.gather(*tasks, return_exceptions=True)
+        tasks = [resolve_single(n, f) for n, f in files.items()]
+        gather_results = await asyncio.gather(*tasks, return_exceptions=True)

         output: dict[str, ResolvedFile] = {}
-        for result in results:
-            if isinstance(result, BaseException):
-                logger.error(f"Resolution failed: {result}")
+        for item in gather_results:
+            if isinstance(item, BaseException):
+                logger.error(f"Resolution failed: {item}")
                 continue
-            name, resolved = result
-            output[name] = resolved
+            key, resolved = item
+            output[key] = resolved

         return output
@@ -451,8 +501,8 @@ class FileResolver:
             file_uri=result.file_uri,
         )

+    @staticmethod
     async def _aupload_with_retry(
-        self,
         uploader: FileUploader,
         file: FileInput,
         provider: str,
@@ -559,17 +609,24 @@ def create_resolver(
     """Create a configured FileResolver.

     Args:
-        provider: Optional provider name for provider-specific configuration.
+        provider: Optional provider name to load default threshold from constraints.
         prefer_upload: Whether to prefer upload over inline.
-        upload_threshold_bytes: Size threshold for using upload.
+        upload_threshold_bytes: Size threshold for using upload. If None and
+            provider is specified, uses provider's default threshold.
         enable_cache: Whether to enable upload caching.

     Returns:
         Configured FileResolver instance.
     """
+    threshold = upload_threshold_bytes
+    if threshold is None and provider is not None:
+        constraints = get_constraints_for_provider(provider)
+        if constraints is not None:
+            threshold = constraints.file_upload_threshold_bytes
+
     config = FileResolverConfig(
         prefer_upload=prefer_upload,
-        upload_threshold_bytes=upload_threshold_bytes,
+        upload_threshold_bytes=threshold,
     )
     cache = UploadCache() if enable_cache else None
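
With threshold resolution moved into create_resolver, callers can lean on provider defaults. A hedged usage sketch; the import path is assumed, since it is not shown in the diff:

from crewai.files.resolver import create_resolver  # path assumed

resolver = create_resolver(provider="gemini", enable_cache=True)
# upload_threshold_bytes was omitted, so the Gemini default
# (file_upload_threshold_bytes=20 * 1024 * 1024 above) applies;
# pass upload_threshold_bytes=... to override it.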

View File

@@ -28,7 +28,6 @@ if TYPE_CHECKING:
 FileInput = AudioFile | File | ImageFile | PDFFile | TextFile | VideoFile

 logger = logging.getLogger(__name__)
-
 DEFAULT_TTL_SECONDS = 24 * 60 * 60  # 24 hours
@@ -139,7 +138,6 @@ class UploadCache:
             )
         else:
             self._cache = Cache(
-                Cache.MEMORY,
                 serializer=PickleSerializer(),
                 namespace=namespace,
             )
@@ -406,7 +404,8 @@ class UploadCache:
             results.append(cached)
         return results

-    def _run_sync(self, coro: Any) -> Any:
+    @staticmethod
+    def _run_sync(coro: Any) -> Any:
         """Run an async coroutine from sync context without blocking event loop."""
         try:
             loop = asyncio.get_running_loop()
@@ -549,7 +548,7 @@ def _cleanup_on_exit() -> None:
     from crewai.files.cleanup import cleanup_uploaded_files

     try:
-        cleanup_uploaded_files(_default_cache, delete_from_provider=True)
+        cleanup_uploaded_files(_default_cache)
     except Exception as e:
         logger.debug(f"Error during exit cleanup: {e}")

View File

@@ -204,7 +204,8 @@ class BedrockFileUploader(FileUploader):
         """
         return f"s3://{self.bucket_name}/{key}"

-    def _get_transfer_config(self) -> Any:
+    @staticmethod
+    def _get_transfer_config() -> Any:
         """Get boto3 TransferConfig for multipart uploads."""
         from boto3.s3.transfer import TransferConfig
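
The returned config is cut off above; for orientation, a typical boto3 TransferConfig looks like the following. The values are illustrative, not the ones this method actually uses:

from boto3.s3.transfer import TransferConfig

config = TransferConfig(
    multipart_threshold=8 * 1024 * 1024,  # switch to multipart above 8 MB
    multipart_chunksize=8 * 1024 * 1024,  # size of each uploaded part
    max_concurrency=10,                   # parallel part uploads
    use_threads=True,
)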

View File

@@ -388,7 +388,8 @@ class OpenAIFileUploader(FileUploader):
                     logger.debug(f"Failed to cancel upload: {cancel_err}")
                 raise

-    def _classify_error(self, e: Exception, filename: str | None) -> Exception:
+    @staticmethod
+    def _classify_error(e: Exception, filename: str | None) -> Exception:
         """Classify an exception as transient or permanent.

         Args:
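
The method body is truncated here; as a rough sketch of the transient-versus-permanent idea, with heuristics and exception types that are hypothetical rather than this uploader's actual logic:

def classify_error(e: Exception, filename: str | None) -> Exception:
    transient_markers = ("timeout", "rate limit", "429", "503")
    message = str(e).lower()
    if any(marker in message for marker in transient_markers):
        return TimeoutError(f"Transient upload failure for {filename}: {e}")  # retry
    return ValueError(f"Permanent upload failure for {filename}: {e}")  # do not retry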