refactor: improve factory typing with specific provider and uploader types

This commit is contained in:
Greyson LaLonde
2026-01-22 12:22:28 -05:00
parent 6147d4eb2e
commit 9fec81f976

View File

@@ -3,86 +3,121 @@
from __future__ import annotations
import logging
from typing import Literal, TypedDict, overload, reveal_type
from typing import Literal, TypeAlias, TypedDict, overload
from typing_extensions import Unpack
from typing_extensions import NotRequired, Unpack
from crewai.files.uploaders.base import FileUploader
from crewai.files.uploaders.anthropic import AnthropicFileUploader
from crewai.files.uploaders.bedrock import BedrockFileUploader
from crewai.files.uploaders.gemini import GeminiFileUploader
from crewai.files.uploaders.openai import OpenAIFileUploader
logger = logging.getLogger(__name__)
ProviderType = Literal[
"gemini", "google", "anthropic", "claude", "openai", "gpt", "bedrock", "aws"
]
UnknownProvider = str
FileUploaderType: TypeAlias = (
GeminiFileUploader
| AnthropicFileUploader
| BedrockFileUploader
| OpenAIFileUploader
)
GeminiProviderType = Literal["gemini", "google"]
AnthropicProviderType = Literal["anthropic", "claude"]
OpenAIProviderType = Literal["openai", "gpt"]
BedrockProviderType = Literal["bedrock", "aws"]
ProviderType: TypeAlias = (
GeminiProviderType
| AnthropicProviderType
| OpenAIProviderType
| BedrockProviderType
)
class _BaseOpts(TypedDict):
"""Kwargs for uploader factory."""
api_key: NotRequired[str | None]
class OpenAIOpts(_BaseOpts):
"""Kwargs for openai uploader factory."""
chunk_size: NotRequired[int]
class GeminiOpts(_BaseOpts):
"""Kwargs for gemini uploader factory."""
class AnthropicOpts(_BaseOpts):
"""Kwargs for anthropic uploader factory."""
class BedrockOpts(TypedDict):
"""Kwargs for bedrock uploader factory."""
bucket_name: NotRequired[str | None]
bucket_owner: NotRequired[str | None]
prefix: NotRequired[str]
region: NotRequired[str | None]
class AllOptions(TypedDict):
"""Kwargs for uploader factory."""
api_key: str | None
chunk_size: int
bucket_name: str
bucket_owner: str
prefix: str
region: str
@overload
def get_uploader(provider: UnknownProvider, /) -> None:
"""Get file uploader for unknown provider."""
api_key: NotRequired[str | None]
chunk_size: NotRequired[int]
bucket_name: NotRequired[str | None]
bucket_owner: NotRequired[str | None]
prefix: NotRequired[str]
region: NotRequired[str | None]
@overload
def get_uploader(
provider: Literal["gemini", "google"],
*,
api_key: str | None = ...,
) -> FileUploader:
provider: GeminiProviderType,
**kwargs: Unpack[GeminiOpts],
) -> GeminiFileUploader:
"""Get Gemini file uploader."""
@overload
def get_uploader(
provider: Literal["anthropic", "claude"],
*,
api_key: str | None = ...,
) -> FileUploader:
provider: AnthropicProviderType,
**kwargs: Unpack[AnthropicOpts],
) -> AnthropicFileUploader:
"""Get Anthropic file uploader."""
@overload
def get_uploader(
provider: Literal["openai", "gpt"],
*,
api_key: str | None = ...,
chunk_size: int = ...,
) -> FileUploader | None:
provider: OpenAIProviderType,
**kwargs: Unpack[OpenAIOpts],
) -> OpenAIFileUploader:
"""Get OpenAI file uploader."""
@overload
def get_uploader(
provider: Literal["bedrock", "aws"],
*,
bucket_name: str | None = ...,
bucket_owner: str | None = ...,
prefix: str = ...,
region: str | None = ...,
) -> FileUploader | None:
provider: BedrockProviderType,
**kwargs: Unpack[BedrockOpts],
) -> BedrockFileUploader:
"""Get Bedrock file uploader."""
@overload
def get_uploader(
provider: ProviderType | UnknownProvider, **kwargs: Unpack[AllOptions]
) -> FileUploader | None:
provider: ProviderType, **kwargs: Unpack[AllOptions]
) -> FileUploaderType:
"""Get any file uploader."""
def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
def get_uploader(
provider: ProviderType, **kwargs: Unpack[AllOptions]
) -> FileUploaderType:
"""Get a file uploader for a specific provider.
Args:
@@ -103,7 +138,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
logger.warning(
"google-genai not installed. Install with: pip install google-genai"
)
return None
raise
if "anthropic" in provider_lower or "claude" in provider_lower:
try:
@@ -114,7 +149,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
logger.warning(
"anthropic not installed. Install with: pip install anthropic"
)
return None
raise
if "openai" in provider_lower or "gpt" in provider_lower:
try:
@@ -126,7 +161,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
)
except ImportError:
logger.warning("openai not installed. Install with: pip install openai")
return None
raise
if "bedrock" in provider_lower or "aws" in provider_lower:
import os
@@ -139,7 +174,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
"Bedrock S3 uploader not configured. "
"Set CREWAI_BEDROCK_S3_BUCKET environment variable to enable."
)
return None
raise
try:
from crewai.files.uploaders.bedrock import BedrockFileUploader
@@ -151,11 +186,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
)
except ImportError:
logger.warning("boto3 not installed. Install with: pip install boto3")
return None
raise
logger.debug(f"No file uploader available for provider: {provider}")
return None
t = get_uploader("openai")
reveal_type(t)
raise