feat: allow LLM providers to pass clients to file uploaders

- Add get_file_uploader() method to BaseLLM (returns None by default)
- Implement get_file_uploader() in Anthropic, OpenAI, Gemini, Bedrock
- Pass both sync and async clients where applicable (the Gemini uploader accepts a single client only)
- Update uploaders to accept optional pre-instantiated clients
- Update factory to pass through client parameters

This lets file uploaders reuse the LLM provider's already-authenticated
clients for file uploads, avoiding redundant connections.
This commit is contained in:
Greyson LaLonde
2026-01-22 22:44:05 -05:00
parent 9a2b610b21
commit 2c5e794ea3
10 changed files with 139 additions and 12 deletions

View File

@@ -22,16 +22,23 @@ class AnthropicFileUploader(FileUploader):
until explicitly deleted.
"""
def __init__(self, api_key: str | None = None) -> None:
def __init__(
self,
api_key: str | None = None,
client: Any = None,
async_client: Any = None,
) -> None:
"""Initialize the Anthropic uploader.
Args:
api_key: Optional Anthropic API key. If not provided, uses
ANTHROPIC_API_KEY environment variable.
client: Optional pre-instantiated Anthropic client.
async_client: Optional pre-instantiated async Anthropic client.
"""
self._api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
self._client: Any = None
self._async_client: Any = None
self._client: Any = client
self._async_client: Any = async_client
@property
def provider_name(self) -> str:

View File

@@ -110,6 +110,8 @@ class BedrockFileUploader(FileUploader):
bucket_owner: str | None = None,
prefix: str = "crewai-files",
region: str | None = None,
client: Any = None,
async_client: Any = None,
) -> None:
"""Initialize the Bedrock S3 uploader.
@@ -120,6 +122,8 @@ class BedrockFileUploader(FileUploader):
Uses CREWAI_BEDROCK_S3_BUCKET_OWNER environment variable if not provided.
prefix: S3 key prefix for uploaded files (default: "crewai-files").
region: AWS region. Uses AWS_REGION or AWS_DEFAULT_REGION if not provided.
client: Optional pre-instantiated boto3 S3 client.
async_client: Optional pre-instantiated aioboto3 S3 client.
"""
self._bucket_name = bucket_name or os.environ.get("CREWAI_BEDROCK_S3_BUCKET")
self._bucket_owner = bucket_owner or os.environ.get(
@@ -129,8 +133,8 @@ class BedrockFileUploader(FileUploader):
self._region = region or os.environ.get(
"AWS_REGION", os.environ.get("AWS_DEFAULT_REGION")
)
self._client: Any = None
self._async_client: Any = None
self._client: Any = client
self._async_client: Any = async_client
@property
def provider_name(self) -> str:

View File

@@ -36,10 +36,15 @@ ProviderType: TypeAlias = (
)
from typing import Any as AnyType
class _BaseOpts(TypedDict):
"""Kwargs for uploader factory."""
api_key: NotRequired[str | None]
client: NotRequired[AnyType]
async_client: NotRequired[AnyType]
class OpenAIOpts(_BaseOpts):
@@ -48,9 +53,12 @@ class OpenAIOpts(_BaseOpts):
chunk_size: NotRequired[int]
class GeminiOpts(_BaseOpts):
class GeminiOpts(TypedDict):
"""Kwargs for gemini uploader factory."""
api_key: NotRequired[str | None]
client: NotRequired[AnyType]
class AnthropicOpts(_BaseOpts):
"""Kwargs for anthropic uploader factory."""
@@ -63,6 +71,8 @@ class BedrockOpts(TypedDict):
bucket_owner: NotRequired[str | None]
prefix: NotRequired[str]
region: NotRequired[str | None]
client: NotRequired[AnyType]
async_client: NotRequired[AnyType]
class AllOptions(TypedDict):
@@ -74,6 +84,8 @@ class AllOptions(TypedDict):
bucket_owner: NotRequired[str | None]
prefix: NotRequired[str]
region: NotRequired[str | None]
client: NotRequired[AnyType]
async_client: NotRequired[AnyType]
@overload
@@ -133,7 +145,10 @@ def get_uploader(
try:
from crewai_files.uploaders.gemini import GeminiFileUploader
return GeminiFileUploader(api_key=kwargs.get("api_key"))
return GeminiFileUploader(
api_key=kwargs.get("api_key"),
client=kwargs.get("client"),
)
except ImportError:
logger.warning(
"google-genai not installed. Install with: pip install google-genai"
@@ -144,7 +159,11 @@ def get_uploader(
try:
from crewai_files.uploaders.anthropic import AnthropicFileUploader
return AnthropicFileUploader(api_key=kwargs.get("api_key"))
return AnthropicFileUploader(
api_key=kwargs.get("api_key"),
client=kwargs.get("client"),
async_client=kwargs.get("async_client"),
)
except ImportError:
logger.warning(
"anthropic not installed. Install with: pip install anthropic"
@@ -162,6 +181,8 @@ def get_uploader(
return OpenAIFileUploader(
api_key=kwargs.get("api_key"),
chunk_size=kwargs.get("chunk_size", 67_108_864),
client=kwargs.get("client"),
async_client=kwargs.get("async_client"),
)
except ImportError:
logger.warning("openai not installed. Install with: pip install openai")
@@ -187,6 +208,8 @@ def get_uploader(
bucket_owner=kwargs.get("bucket_owner"),
prefix=kwargs.get("prefix", "crewai-files"),
region=kwargs.get("region"),
client=kwargs.get("client"),
async_client=kwargs.get("async_client"),
)
except ImportError:
logger.warning("boto3 not installed. Install with: pip install boto3")

View File

@@ -93,15 +93,20 @@ class GeminiFileUploader(FileUploader):
Uses the google-genai SDK to upload files. Files are stored for 48 hours.
"""
def __init__(self, api_key: str | None = None) -> None:
def __init__(
self,
api_key: str | None = None,
client: Any = None,
) -> None:
"""Initialize the Gemini uploader.
Args:
api_key: Optional Google API key. If not provided, uses
GOOGLE_API_KEY environment variable.
client: Optional pre-instantiated Gemini client.
"""
self._api_key = api_key or os.environ.get("GOOGLE_API_KEY")
self._client: Any = None
self._client: Any = client
@property
def provider_name(self) -> str:

View File

@@ -100,6 +100,8 @@ class OpenAIFileUploader(FileUploader):
self,
api_key: str | None = None,
chunk_size: int = DEFAULT_UPLOAD_CHUNK_SIZE,
client: Any = None,
async_client: Any = None,
) -> None:
"""Initialize the OpenAI uploader.
@@ -107,11 +109,13 @@ class OpenAIFileUploader(FileUploader):
api_key: Optional OpenAI API key. If not provided, uses
OPENAI_API_KEY environment variable.
chunk_size: Chunk size in bytes for multipart uploads (default 64MB).
client: Optional pre-instantiated OpenAI client.
async_client: Optional pre-instantiated async OpenAI client.
"""
self._api_key = api_key or os.environ.get("OPENAI_API_KEY")
self._chunk_size = chunk_size
self._client: Any = None
self._async_client: Any = None
self._client: Any = client
self._async_client: Any = async_client
@property
def provider_name(self) -> str: