mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 08:12:39 +00:00
feat: allow LLM providers to pass clients to file uploaders
- Add get_file_uploader() method to BaseLLM (returns None by default) - Implement get_file_uploader() in Anthropic, OpenAI, Gemini, Bedrock - Pass both sync and async clients where applicable - Update uploaders to accept optional pre-instantiated clients - Update factory to pass through client parameters This allows reusing authenticated LLM clients for file uploads, avoiding redundant connections.
This commit is contained in:
@@ -22,16 +22,23 @@ class AnthropicFileUploader(FileUploader):
|
||||
until explicitly deleted.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str | None = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
client: Any = None,
|
||||
async_client: Any = None,
|
||||
) -> None:
|
||||
"""Initialize the Anthropic uploader.
|
||||
|
||||
Args:
|
||||
api_key: Optional Anthropic API key. If not provided, uses
|
||||
ANTHROPIC_API_KEY environment variable.
|
||||
client: Optional pre-instantiated Anthropic client.
|
||||
async_client: Optional pre-instantiated async Anthropic client.
|
||||
"""
|
||||
self._api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
|
||||
self._client: Any = None
|
||||
self._async_client: Any = None
|
||||
self._client: Any = client
|
||||
self._async_client: Any = async_client
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
|
||||
@@ -110,6 +110,8 @@ class BedrockFileUploader(FileUploader):
|
||||
bucket_owner: str | None = None,
|
||||
prefix: str = "crewai-files",
|
||||
region: str | None = None,
|
||||
client: Any = None,
|
||||
async_client: Any = None,
|
||||
) -> None:
|
||||
"""Initialize the Bedrock S3 uploader.
|
||||
|
||||
@@ -120,6 +122,8 @@ class BedrockFileUploader(FileUploader):
|
||||
Uses CREWAI_BEDROCK_S3_BUCKET_OWNER environment variable if not provided.
|
||||
prefix: S3 key prefix for uploaded files (default: "crewai-files").
|
||||
region: AWS region. Uses AWS_REGION or AWS_DEFAULT_REGION if not provided.
|
||||
client: Optional pre-instantiated boto3 S3 client.
|
||||
async_client: Optional pre-instantiated aioboto3 S3 client.
|
||||
"""
|
||||
self._bucket_name = bucket_name or os.environ.get("CREWAI_BEDROCK_S3_BUCKET")
|
||||
self._bucket_owner = bucket_owner or os.environ.get(
|
||||
@@ -129,8 +133,8 @@ class BedrockFileUploader(FileUploader):
|
||||
self._region = region or os.environ.get(
|
||||
"AWS_REGION", os.environ.get("AWS_DEFAULT_REGION")
|
||||
)
|
||||
self._client: Any = None
|
||||
self._async_client: Any = None
|
||||
self._client: Any = client
|
||||
self._async_client: Any = async_client
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
|
||||
@@ -36,10 +36,15 @@ ProviderType: TypeAlias = (
|
||||
)
|
||||
|
||||
|
||||
from typing import Any as AnyType
|
||||
|
||||
|
||||
class _BaseOpts(TypedDict):
|
||||
"""Kwargs for uploader factory."""
|
||||
|
||||
api_key: NotRequired[str | None]
|
||||
client: NotRequired[AnyType]
|
||||
async_client: NotRequired[AnyType]
|
||||
|
||||
|
||||
class OpenAIOpts(_BaseOpts):
|
||||
@@ -48,9 +53,12 @@ class OpenAIOpts(_BaseOpts):
|
||||
chunk_size: NotRequired[int]
|
||||
|
||||
|
||||
class GeminiOpts(_BaseOpts):
|
||||
class GeminiOpts(TypedDict):
|
||||
"""Kwargs for gemini uploader factory."""
|
||||
|
||||
api_key: NotRequired[str | None]
|
||||
client: NotRequired[AnyType]
|
||||
|
||||
|
||||
class AnthropicOpts(_BaseOpts):
|
||||
"""Kwargs for anthropic uploader factory."""
|
||||
@@ -63,6 +71,8 @@ class BedrockOpts(TypedDict):
|
||||
bucket_owner: NotRequired[str | None]
|
||||
prefix: NotRequired[str]
|
||||
region: NotRequired[str | None]
|
||||
client: NotRequired[AnyType]
|
||||
async_client: NotRequired[AnyType]
|
||||
|
||||
|
||||
class AllOptions(TypedDict):
|
||||
@@ -74,6 +84,8 @@ class AllOptions(TypedDict):
|
||||
bucket_owner: NotRequired[str | None]
|
||||
prefix: NotRequired[str]
|
||||
region: NotRequired[str | None]
|
||||
client: NotRequired[AnyType]
|
||||
async_client: NotRequired[AnyType]
|
||||
|
||||
|
||||
@overload
|
||||
@@ -133,7 +145,10 @@ def get_uploader(
|
||||
try:
|
||||
from crewai_files.uploaders.gemini import GeminiFileUploader
|
||||
|
||||
return GeminiFileUploader(api_key=kwargs.get("api_key"))
|
||||
return GeminiFileUploader(
|
||||
api_key=kwargs.get("api_key"),
|
||||
client=kwargs.get("client"),
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"google-genai not installed. Install with: pip install google-genai"
|
||||
@@ -144,7 +159,11 @@ def get_uploader(
|
||||
try:
|
||||
from crewai_files.uploaders.anthropic import AnthropicFileUploader
|
||||
|
||||
return AnthropicFileUploader(api_key=kwargs.get("api_key"))
|
||||
return AnthropicFileUploader(
|
||||
api_key=kwargs.get("api_key"),
|
||||
client=kwargs.get("client"),
|
||||
async_client=kwargs.get("async_client"),
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"anthropic not installed. Install with: pip install anthropic"
|
||||
@@ -162,6 +181,8 @@ def get_uploader(
|
||||
return OpenAIFileUploader(
|
||||
api_key=kwargs.get("api_key"),
|
||||
chunk_size=kwargs.get("chunk_size", 67_108_864),
|
||||
client=kwargs.get("client"),
|
||||
async_client=kwargs.get("async_client"),
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("openai not installed. Install with: pip install openai")
|
||||
@@ -187,6 +208,8 @@ def get_uploader(
|
||||
bucket_owner=kwargs.get("bucket_owner"),
|
||||
prefix=kwargs.get("prefix", "crewai-files"),
|
||||
region=kwargs.get("region"),
|
||||
client=kwargs.get("client"),
|
||||
async_client=kwargs.get("async_client"),
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("boto3 not installed. Install with: pip install boto3")
|
||||
|
||||
@@ -93,15 +93,20 @@ class GeminiFileUploader(FileUploader):
|
||||
Uses the google-genai SDK to upload files. Files are stored for 48 hours.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str | None = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
client: Any = None,
|
||||
) -> None:
|
||||
"""Initialize the Gemini uploader.
|
||||
|
||||
Args:
|
||||
api_key: Optional Google API key. If not provided, uses
|
||||
GOOGLE_API_KEY environment variable.
|
||||
client: Optional pre-instantiated Gemini client.
|
||||
"""
|
||||
self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
|
||||
self._client: Any = None
|
||||
self._client: Any = client
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
|
||||
@@ -100,6 +100,8 @@ class OpenAIFileUploader(FileUploader):
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
chunk_size: int = DEFAULT_UPLOAD_CHUNK_SIZE,
|
||||
client: Any = None,
|
||||
async_client: Any = None,
|
||||
) -> None:
|
||||
"""Initialize the OpenAI uploader.
|
||||
|
||||
@@ -107,11 +109,13 @@ class OpenAIFileUploader(FileUploader):
|
||||
api_key: Optional OpenAI API key. If not provided, uses
|
||||
OPENAI_API_KEY environment variable.
|
||||
chunk_size: Chunk size in bytes for multipart uploads (default 64MB).
|
||||
client: Optional pre-instantiated OpenAI client.
|
||||
async_client: Optional pre-instantiated async OpenAI client.
|
||||
"""
|
||||
self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
||||
self._chunk_size = chunk_size
|
||||
self._client: Any = None
|
||||
self._async_client: Any = None
|
||||
self._client: Any = client
|
||||
self._async_client: Any = async_client
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
|
||||
Reference in New Issue
Block a user