mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-26 00:28:13 +00:00
refactor: improve factory typing with specific provider and uploader types
This commit is contained in:
@@ -3,86 +3,121 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Literal, TypedDict, overload, reveal_type
|
from typing import Literal, TypeAlias, TypedDict, overload
|
||||||
|
|
||||||
from typing_extensions import Unpack
|
from typing_extensions import NotRequired, Unpack
|
||||||
|
|
||||||
from crewai.files.uploaders.base import FileUploader
|
from crewai.files.uploaders.anthropic import AnthropicFileUploader
|
||||||
|
from crewai.files.uploaders.bedrock import BedrockFileUploader
|
||||||
|
from crewai.files.uploaders.gemini import GeminiFileUploader
|
||||||
|
from crewai.files.uploaders.openai import OpenAIFileUploader
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
ProviderType = Literal[
|
FileUploaderType: TypeAlias = (
|
||||||
"gemini", "google", "anthropic", "claude", "openai", "gpt", "bedrock", "aws"
|
GeminiFileUploader
|
||||||
]
|
| AnthropicFileUploader
|
||||||
UnknownProvider = str
|
| BedrockFileUploader
|
||||||
|
| OpenAIFileUploader
|
||||||
|
)
|
||||||
|
|
||||||
|
GeminiProviderType = Literal["gemini", "google"]
|
||||||
|
AnthropicProviderType = Literal["anthropic", "claude"]
|
||||||
|
OpenAIProviderType = Literal["openai", "gpt"]
|
||||||
|
BedrockProviderType = Literal["bedrock", "aws"]
|
||||||
|
|
||||||
|
ProviderType: TypeAlias = (
|
||||||
|
GeminiProviderType
|
||||||
|
| AnthropicProviderType
|
||||||
|
| OpenAIProviderType
|
||||||
|
| BedrockProviderType
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _BaseOpts(TypedDict):
|
||||||
|
"""Kwargs for uploader factory."""
|
||||||
|
|
||||||
|
api_key: NotRequired[str | None]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIOpts(_BaseOpts):
|
||||||
|
"""Kwargs for openai uploader factory."""
|
||||||
|
|
||||||
|
chunk_size: NotRequired[int]
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiOpts(_BaseOpts):
|
||||||
|
"""Kwargs for gemini uploader factory."""
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicOpts(_BaseOpts):
|
||||||
|
"""Kwargs for anthropic uploader factory."""
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockOpts(TypedDict):
|
||||||
|
"""Kwargs for bedrock uploader factory."""
|
||||||
|
|
||||||
|
bucket_name: NotRequired[str | None]
|
||||||
|
bucket_owner: NotRequired[str | None]
|
||||||
|
prefix: NotRequired[str]
|
||||||
|
region: NotRequired[str | None]
|
||||||
|
|
||||||
|
|
||||||
class AllOptions(TypedDict):
|
class AllOptions(TypedDict):
|
||||||
"""Kwargs for uploader factory."""
|
"""Kwargs for uploader factory."""
|
||||||
|
|
||||||
api_key: str | None
|
api_key: NotRequired[str | None]
|
||||||
chunk_size: int
|
chunk_size: NotRequired[int]
|
||||||
bucket_name: str
|
bucket_name: NotRequired[str | None]
|
||||||
bucket_owner: str
|
bucket_owner: NotRequired[str | None]
|
||||||
prefix: str
|
prefix: NotRequired[str]
|
||||||
region: str
|
region: NotRequired[str | None]
|
||||||
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def get_uploader(provider: UnknownProvider, /) -> None:
|
|
||||||
"""Get file uploader for unknown provider."""
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_uploader(
|
def get_uploader(
|
||||||
provider: Literal["gemini", "google"],
|
provider: GeminiProviderType,
|
||||||
*,
|
**kwargs: Unpack[GeminiOpts],
|
||||||
api_key: str | None = ...,
|
) -> GeminiFileUploader:
|
||||||
) -> FileUploader:
|
|
||||||
"""Get Gemini file uploader."""
|
"""Get Gemini file uploader."""
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_uploader(
|
def get_uploader(
|
||||||
provider: Literal["anthropic", "claude"],
|
provider: AnthropicProviderType,
|
||||||
*,
|
**kwargs: Unpack[AnthropicOpts],
|
||||||
api_key: str | None = ...,
|
) -> AnthropicFileUploader:
|
||||||
) -> FileUploader:
|
|
||||||
"""Get Anthropic file uploader."""
|
"""Get Anthropic file uploader."""
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_uploader(
|
def get_uploader(
|
||||||
provider: Literal["openai", "gpt"],
|
provider: OpenAIProviderType,
|
||||||
*,
|
**kwargs: Unpack[OpenAIOpts],
|
||||||
api_key: str | None = ...,
|
) -> OpenAIFileUploader:
|
||||||
chunk_size: int = ...,
|
|
||||||
) -> FileUploader | None:
|
|
||||||
"""Get OpenAI file uploader."""
|
"""Get OpenAI file uploader."""
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_uploader(
|
def get_uploader(
|
||||||
provider: Literal["bedrock", "aws"],
|
provider: BedrockProviderType,
|
||||||
*,
|
**kwargs: Unpack[BedrockOpts],
|
||||||
bucket_name: str | None = ...,
|
) -> BedrockFileUploader:
|
||||||
bucket_owner: str | None = ...,
|
|
||||||
prefix: str = ...,
|
|
||||||
region: str | None = ...,
|
|
||||||
) -> FileUploader | None:
|
|
||||||
"""Get Bedrock file uploader."""
|
"""Get Bedrock file uploader."""
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_uploader(
|
def get_uploader(
|
||||||
provider: ProviderType | UnknownProvider, **kwargs: Unpack[AllOptions]
|
provider: ProviderType, **kwargs: Unpack[AllOptions]
|
||||||
) -> FileUploader | None:
|
) -> FileUploaderType:
|
||||||
"""Get any file uploader."""
|
"""Get any file uploader."""
|
||||||
|
|
||||||
|
|
||||||
def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
|
def get_uploader(
|
||||||
|
provider: ProviderType, **kwargs: Unpack[AllOptions]
|
||||||
|
) -> FileUploaderType:
|
||||||
"""Get a file uploader for a specific provider.
|
"""Get a file uploader for a specific provider.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -103,7 +138,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"google-genai not installed. Install with: pip install google-genai"
|
"google-genai not installed. Install with: pip install google-genai"
|
||||||
)
|
)
|
||||||
return None
|
raise
|
||||||
|
|
||||||
if "anthropic" in provider_lower or "claude" in provider_lower:
|
if "anthropic" in provider_lower or "claude" in provider_lower:
|
||||||
try:
|
try:
|
||||||
@@ -114,7 +149,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"anthropic not installed. Install with: pip install anthropic"
|
"anthropic not installed. Install with: pip install anthropic"
|
||||||
)
|
)
|
||||||
return None
|
raise
|
||||||
|
|
||||||
if "openai" in provider_lower or "gpt" in provider_lower:
|
if "openai" in provider_lower or "gpt" in provider_lower:
|
||||||
try:
|
try:
|
||||||
@@ -126,7 +161,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
|
|||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("openai not installed. Install with: pip install openai")
|
logger.warning("openai not installed. Install with: pip install openai")
|
||||||
return None
|
raise
|
||||||
|
|
||||||
if "bedrock" in provider_lower or "aws" in provider_lower:
|
if "bedrock" in provider_lower or "aws" in provider_lower:
|
||||||
import os
|
import os
|
||||||
@@ -139,7 +174,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
|
|||||||
"Bedrock S3 uploader not configured. "
|
"Bedrock S3 uploader not configured. "
|
||||||
"Set CREWAI_BEDROCK_S3_BUCKET environment variable to enable."
|
"Set CREWAI_BEDROCK_S3_BUCKET environment variable to enable."
|
||||||
)
|
)
|
||||||
return None
|
raise
|
||||||
try:
|
try:
|
||||||
from crewai.files.uploaders.bedrock import BedrockFileUploader
|
from crewai.files.uploaders.bedrock import BedrockFileUploader
|
||||||
|
|
||||||
@@ -151,11 +186,7 @@ def get_uploader(provider, **kwargs): # type: ignore[no-untyped-def]
|
|||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("boto3 not installed. Install with: pip install boto3")
|
logger.warning("boto3 not installed. Install with: pip install boto3")
|
||||||
return None
|
raise
|
||||||
|
|
||||||
logger.debug(f"No file uploader available for provider: {provider}")
|
logger.debug(f"No file uploader available for provider: {provider}")
|
||||||
return None
|
raise
|
||||||
|
|
||||||
|
|
||||||
t = get_uploader("openai")
|
|
||||||
reveal_type(t)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user