refactor: improve factory typing with specific provider and uploader types

This commit is contained in:
Greyson LaLonde
2026-01-22 12:22:28 -05:00
parent 6147d4eb2e
commit 9fec81f976

View File

@@ -3,86 +3,121 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Literal, TypedDict, overload, reveal_type from typing import Literal, TypeAlias, TypedDict, overload
from typing_extensions import Unpack from typing_extensions import NotRequired, Unpack
from crewai.files.uploaders.base import FileUploader from crewai.files.uploaders.anthropic import AnthropicFileUploader
from crewai.files.uploaders.bedrock import BedrockFileUploader
from crewai.files.uploaders.gemini import GeminiFileUploader
from crewai.files.uploaders.openai import OpenAIFileUploader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ProviderType = Literal[ FileUploaderType: TypeAlias = (
"gemini", "google", "anthropic", "claude", "openai", "gpt", "bedrock", "aws" GeminiFileUploader
] | AnthropicFileUploader
UnknownProvider = str | BedrockFileUploader
| OpenAIFileUploader
)
GeminiProviderType = Literal["gemini", "google"]
AnthropicProviderType = Literal["anthropic", "claude"]
OpenAIProviderType = Literal["openai", "gpt"]
BedrockProviderType = Literal["bedrock", "aws"]
ProviderType: TypeAlias = (
GeminiProviderType
| AnthropicProviderType
| OpenAIProviderType
| BedrockProviderType
)
class _BaseOpts(TypedDict):
"""Kwargs for uploader factory."""
api_key: NotRequired[str | None]
class OpenAIOpts(_BaseOpts):
"""Kwargs for openai uploader factory."""
chunk_size: NotRequired[int]
class GeminiOpts(_BaseOpts):
"""Kwargs for gemini uploader factory."""
class AnthropicOpts(_BaseOpts):
"""Kwargs for anthropic uploader factory."""
class BedrockOpts(TypedDict):
"""Kwargs for bedrock uploader factory."""
bucket_name: NotRequired[str | None]
bucket_owner: NotRequired[str | None]
prefix: NotRequired[str]
region: NotRequired[str | None]
class AllOptions(TypedDict): class AllOptions(TypedDict):
"""Kwargs for uploader factory.""" """Kwargs for uploader factory."""
api_key: str | None api_key: NotRequired[str | None]
chunk_size: int chunk_size: NotRequired[int]
bucket_name: str bucket_name: NotRequired[str | None]
bucket_owner: str bucket_owner: NotRequired[str | None]
prefix: str prefix: NotRequired[str]
region: str region: NotRequired[str | None]
@overload
def get_uploader(provider: UnknownProvider, /) -> None:
"""Get file uploader for unknown provider."""
@overload @overload
def get_uploader( def get_uploader(
provider: Literal["gemini", "google"], provider: GeminiProviderType,
*, **kwargs: Unpack[GeminiOpts],
api_key: str | None = ..., ) -> GeminiFileUploader:
) -> FileUploader:
"""Get Gemini file uploader.""" """Get Gemini file uploader."""
@overload @overload
def get_uploader( def get_uploader(
provider: Literal["anthropic", "claude"], provider: AnthropicProviderType,
*, **kwargs: Unpack[AnthropicOpts],
api_key: str | None = ..., ) -> AnthropicFileUploader:
) -> FileUploader:
"""Get Anthropic file uploader.""" """Get Anthropic file uploader."""
@overload @overload
def get_uploader( def get_uploader(
provider: Literal["openai", "gpt"], provider: OpenAIProviderType,
*, **kwargs: Unpack[OpenAIOpts],
api_key: str | None = ..., ) -> OpenAIFileUploader:
chunk_size: int = ...,
) -> FileUploader | None:
"""Get OpenAI file uploader.""" """Get OpenAI file uploader."""
@overload @overload
def get_uploader( def get_uploader(
provider: Literal["bedrock", "aws"], provider: BedrockProviderType,
*, **kwargs: Unpack[BedrockOpts],
bucket_name: str | None = ..., ) -> BedrockFileUploader:
bucket_owner: str | None = ...,
prefix: str = ...,
region: str | None = ...,
) -> FileUploader | None:
"""Get Bedrock file uploader.""" """Get Bedrock file uploader."""
@overload @overload
def get_uploader( def get_uploader(
provider: ProviderType | UnknownProvider, **kwargs: Unpack[AllOptions] provider: ProviderType, **kwargs: Unpack[AllOptions]
) -> FileUploader | None: ) -> FileUploaderType:
"""Get any file uploader.""" """Get any file uploader."""
def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def] def get_uploader(
provider: ProviderType, **kwargs: Unpack[AllOptions]
) -> FileUploaderType:
"""Get a file uploader for a specific provider. """Get a file uploader for a specific provider.
Args: Args:
@@ -103,7 +138,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
logger.warning( logger.warning(
"google-genai not installed. Install with: pip install google-genai" "google-genai not installed. Install with: pip install google-genai"
) )
return None raise
if "anthropic" in provider_lower or "claude" in provider_lower: if "anthropic" in provider_lower or "claude" in provider_lower:
try: try:
@@ -114,7 +149,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
logger.warning( logger.warning(
"anthropic not installed. Install with: pip install anthropic" "anthropic not installed. Install with: pip install anthropic"
) )
return None raise
if "openai" in provider_lower or "gpt" in provider_lower: if "openai" in provider_lower or "gpt" in provider_lower:
try: try:
@@ -126,7 +161,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
) )
except ImportError: except ImportError:
logger.warning("openai not installed. Install with: pip install openai") logger.warning("openai not installed. Install with: pip install openai")
return None raise
if "bedrock" in provider_lower or "aws" in provider_lower: if "bedrock" in provider_lower or "aws" in provider_lower:
import os import os
@@ -139,7 +174,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
"Bedrock S3 uploader not configured. " "Bedrock S3 uploader not configured. "
"Set CREWAI_BEDROCK_S3_BUCKET environment variable to enable." "Set CREWAI_BEDROCK_S3_BUCKET environment variable to enable."
) )
return None raise
try: try:
from crewai.files.uploaders.bedrock import BedrockFileUploader from crewai.files.uploaders.bedrock import BedrockFileUploader
@@ -151,11 +186,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
) )
except ImportError: except ImportError:
logger.warning("boto3 not installed. Install with: pip install boto3") logger.warning("boto3 not installed. Install with: pip install boto3")
return None raise
logger.debug(f"No file uploader available for provider: {provider}") logger.debug(f"No file uploader available for provider: {provider}")
return None raise
t = get_uploader("openai")
reveal_type(t)