mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-23 15:18:14 +00:00
refactor: improve factory typing with specific provider and uploader types
This commit is contained in:
@@ -3,86 +3,121 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Literal, TypedDict, overload, reveal_type
|
||||
from typing import Literal, TypeAlias, TypedDict, overload
|
||||
|
||||
from typing_extensions import Unpack
|
||||
from typing_extensions import NotRequired, Unpack
|
||||
|
||||
from crewai.files.uploaders.base import FileUploader
|
||||
from crewai.files.uploaders.anthropic import AnthropicFileUploader
|
||||
from crewai.files.uploaders.bedrock import BedrockFileUploader
|
||||
from crewai.files.uploaders.gemini import GeminiFileUploader
|
||||
from crewai.files.uploaders.openai import OpenAIFileUploader
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ProviderType = Literal[
|
||||
"gemini", "google", "anthropic", "claude", "openai", "gpt", "bedrock", "aws"
|
||||
]
|
||||
UnknownProvider = str
|
||||
FileUploaderType: TypeAlias = (
|
||||
GeminiFileUploader
|
||||
| AnthropicFileUploader
|
||||
| BedrockFileUploader
|
||||
| OpenAIFileUploader
|
||||
)
|
||||
|
||||
GeminiProviderType = Literal["gemini", "google"]
|
||||
AnthropicProviderType = Literal["anthropic", "claude"]
|
||||
OpenAIProviderType = Literal["openai", "gpt"]
|
||||
BedrockProviderType = Literal["bedrock", "aws"]
|
||||
|
||||
ProviderType: TypeAlias = (
|
||||
GeminiProviderType
|
||||
| AnthropicProviderType
|
||||
| OpenAIProviderType
|
||||
| BedrockProviderType
|
||||
)
|
||||
|
||||
|
||||
class _BaseOpts(TypedDict):
|
||||
"""Kwargs for uploader factory."""
|
||||
|
||||
api_key: NotRequired[str | None]
|
||||
|
||||
|
||||
class OpenAIOpts(_BaseOpts):
|
||||
"""Kwargs for openai uploader factory."""
|
||||
|
||||
chunk_size: NotRequired[int]
|
||||
|
||||
|
||||
class GeminiOpts(_BaseOpts):
|
||||
"""Kwargs for gemini uploader factory."""
|
||||
|
||||
|
||||
class AnthropicOpts(_BaseOpts):
|
||||
"""Kwargs for anthropic uploader factory."""
|
||||
|
||||
|
||||
class BedrockOpts(TypedDict):
|
||||
"""Kwargs for bedrock uploader factory."""
|
||||
|
||||
bucket_name: NotRequired[str | None]
|
||||
bucket_owner: NotRequired[str | None]
|
||||
prefix: NotRequired[str]
|
||||
region: NotRequired[str | None]
|
||||
|
||||
|
||||
class AllOptions(TypedDict):
|
||||
"""Kwargs for uploader factory."""
|
||||
|
||||
api_key: str | None
|
||||
chunk_size: int
|
||||
bucket_name: str
|
||||
bucket_owner: str
|
||||
prefix: str
|
||||
region: str
|
||||
|
||||
|
||||
@overload
|
||||
def get_uploader(provider: UnknownProvider, /) -> None:
|
||||
"""Get file uploader for unknown provider."""
|
||||
api_key: NotRequired[str | None]
|
||||
chunk_size: NotRequired[int]
|
||||
bucket_name: NotRequired[str | None]
|
||||
bucket_owner: NotRequired[str | None]
|
||||
prefix: NotRequired[str]
|
||||
region: NotRequired[str | None]
|
||||
|
||||
|
||||
@overload
|
||||
def get_uploader(
|
||||
provider: Literal["gemini", "google"],
|
||||
*,
|
||||
api_key: str | None = ...,
|
||||
) -> FileUploader:
|
||||
provider: GeminiProviderType,
|
||||
**kwargs: Unpack[GeminiOpts],
|
||||
) -> GeminiFileUploader:
|
||||
"""Get Gemini file uploader."""
|
||||
|
||||
|
||||
@overload
|
||||
def get_uploader(
|
||||
provider: Literal["anthropic", "claude"],
|
||||
*,
|
||||
api_key: str | None = ...,
|
||||
) -> FileUploader:
|
||||
provider: AnthropicProviderType,
|
||||
**kwargs: Unpack[AnthropicOpts],
|
||||
) -> AnthropicFileUploader:
|
||||
"""Get Anthropic file uploader."""
|
||||
|
||||
|
||||
@overload
|
||||
def get_uploader(
|
||||
provider: Literal["openai", "gpt"],
|
||||
*,
|
||||
api_key: str | None = ...,
|
||||
chunk_size: int = ...,
|
||||
) -> FileUploader | None:
|
||||
provider: OpenAIProviderType,
|
||||
**kwargs: Unpack[OpenAIOpts],
|
||||
) -> OpenAIFileUploader:
|
||||
"""Get OpenAI file uploader."""
|
||||
|
||||
|
||||
@overload
|
||||
def get_uploader(
|
||||
provider: Literal["bedrock", "aws"],
|
||||
*,
|
||||
bucket_name: str | None = ...,
|
||||
bucket_owner: str | None = ...,
|
||||
prefix: str = ...,
|
||||
region: str | None = ...,
|
||||
) -> FileUploader | None:
|
||||
provider: BedrockProviderType,
|
||||
**kwargs: Unpack[BedrockOpts],
|
||||
) -> BedrockFileUploader:
|
||||
"""Get Bedrock file uploader."""
|
||||
|
||||
|
||||
@overload
|
||||
def get_uploader(
|
||||
provider: ProviderType | UnknownProvider, **kwargs: Unpack[AllOptions]
|
||||
) -> FileUploader | None:
|
||||
provider: ProviderType, **kwargs: Unpack[AllOptions]
|
||||
) -> FileUploaderType:
|
||||
"""Get any file uploader."""
|
||||
|
||||
|
||||
def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
|
||||
def get_uploader(
|
||||
provider: ProviderType, **kwargs: Unpack[AllOptions]
|
||||
) -> FileUploaderType:
|
||||
"""Get a file uploader for a specific provider.
|
||||
|
||||
Args:
|
||||
@@ -103,7 +138,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
|
||||
logger.warning(
|
||||
"google-genai not installed. Install with: pip install google-genai"
|
||||
)
|
||||
return None
|
||||
raise
|
||||
|
||||
if "anthropic" in provider_lower or "claude" in provider_lower:
|
||||
try:
|
||||
@@ -114,7 +149,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
|
||||
logger.warning(
|
||||
"anthropic not installed. Install with: pip install anthropic"
|
||||
)
|
||||
return None
|
||||
raise
|
||||
|
||||
if "openai" in provider_lower or "gpt" in provider_lower:
|
||||
try:
|
||||
@@ -126,7 +161,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("openai not installed. Install with: pip install openai")
|
||||
return None
|
||||
raise
|
||||
|
||||
if "bedrock" in provider_lower or "aws" in provider_lower:
|
||||
import os
|
||||
@@ -139,7 +174,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
|
||||
"Bedrock S3 uploader not configured. "
|
||||
"Set CREWAI_BEDROCK_S3_BUCKET environment variable to enable."
|
||||
)
|
||||
return None
|
||||
raise
|
||||
try:
|
||||
from crewai.files.uploaders.bedrock import BedrockFileUploader
|
||||
|
||||
@@ -151,11 +186,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("boto3 not installed. Install with: pip install boto3")
|
||||
return None
|
||||
raise
|
||||
|
||||
logger.debug(f"No file uploader available for provider: {provider}")
|
||||
return None
|
||||
|
||||
|
||||
t = get_uploader("openai")
|
||||
reveal_type(t)
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user