fix: defer native LLM client construction when credentials are missing

All native LLM providers built their SDK clients inside
`@model_validator(mode="after")`, which required the API key at
`LLM(...)` construction time. Instantiating an LLM at module scope
(e.g. `chat_llm=LLM(model="openai/gpt-4o-mini")` on a `@crew` method)
crashed during downstream crew-metadata extraction with a confusing
`ImportError: Error importing native provider: 1 validation error...`
before the process env vars were ever consulted.

Wrap eager client construction in a try/except in each provider and
add `_get_sync_client` / `_get_async_client` helpers that build the
clients on first use. Call sites in every provider are routed through
the lazy getters so calls made without eager construction still work,
and the descriptive "X_API_KEY is required" errors are raised from the
lazy path at the first real call.

Update two Azure tests that asserted the old eager-error contract to
assert the new lazy-error contract.
Greyson LaLonde
2026-04-11 05:48:42 +08:00
parent 298fc7b9c0
commit 851df79a82
6 changed files with 209 additions and 91 deletions
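For orientation, here is a minimal, self-contained sketch of the pattern every provider in this commit now follows. The names (`ExampleCompletion`, `EXAMPLE_API_KEY`) are illustrative only, not part of the codebase: construction attempts the eager build but tolerates a missing key, and the first real use rebuilds and surfaces the descriptive error.

```python
import os
from typing import Any

from pydantic import BaseModel, PrivateAttr, model_validator


class ExampleCompletion(BaseModel):
    """Hypothetical provider illustrating the deferred-client pattern."""

    model: str
    _client: Any = PrivateAttr(default=None)

    def _build_sync_client(self) -> Any:
        # Stand-in for the real SDK constructor; raises the descriptive
        # error when the credential is absent.
        if not os.getenv("EXAMPLE_API_KEY"):
            raise ValueError("EXAMPLE_API_KEY is required.")
        return object()

    @model_validator(mode="after")
    def _init_clients(self) -> "ExampleCompletion":
        try:
            self._client = self._build_sync_client()  # eager when possible
        except ValueError:
            pass  # defer: env vars may not be set at import time
        return self

    def _get_sync_client(self) -> Any:
        if self._client is None:
            self._client = self._build_sync_client()  # raises here if still missing
        return self._client


llm = ExampleCompletion(model="example/model")  # no longer crashes at import time
try:
    llm._get_sync_client()  # the first real use surfaces the error instead
except ValueError as err:
    print(err)
```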

View File

@@ -189,16 +189,37 @@ class AnthropicCompletion(BaseLLM):
 @model_validator(mode="after")
 def _init_clients(self) -> AnthropicCompletion:
-    self._client = Anthropic(**self._get_client_params())
+    """Eagerly build clients when the API key is available, otherwise
+    defer so ``LLM(model="anthropic/...")`` can be constructed at module
+    import time even before deployment env vars are set.
+    """
+    try:
+        self._client = self._build_sync_client()
+        self._async_client = self._build_async_client()
+    except ValueError:
+        pass
+    return self
+
+def _build_sync_client(self) -> Any:
+    return Anthropic(**self._get_client_params())
+
+def _build_async_client(self) -> Any:
     async_client_params = self._get_client_params()
     if self.interceptor:
         async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
         async_http_client = httpx.AsyncClient(transport=async_transport)
         async_client_params["http_client"] = async_http_client
-    self._async_client = AsyncAnthropic(**async_client_params)
-    return self
+    return AsyncAnthropic(**async_client_params)
+
+def _get_sync_client(self) -> Any:
+    if self._client is None:
+        self._client = self._build_sync_client()
+    return self._client
+
+def _get_async_client(self) -> Any:
+    if self._async_client is None:
+        self._async_client = self._build_async_client()
+    return self._async_client
 
 def to_config_dict(self) -> dict[str, Any]:
     """Extend base config with Anthropic-specific fields."""
@@ -790,11 +811,11 @@ class AnthropicCompletion(BaseLLM):
try:
if betas:
params["betas"] = betas
response = self._client.beta.messages.create(
response = self._get_sync_client().beta.messages.create(
**params, extra_body=extra_body
)
else:
response = self._client.messages.create(**params)
response = self._get_sync_client().messages.create(**params)
except Exception as e:
if is_context_length_exceeded(e):
@@ -942,9 +963,11 @@ class AnthropicCompletion(BaseLLM):
current_tool_calls: dict[int, dict[str, Any]] = {}
stream_context = (
self._client.beta.messages.stream(**stream_params, extra_body=extra_body)
self._get_sync_client().beta.messages.stream(
**stream_params, extra_body=extra_body
)
if betas
else self._client.messages.stream(**stream_params)
else self._get_sync_client().messages.stream(**stream_params)
)
with stream_context as stream:
response_id = None
@@ -1223,7 +1246,9 @@ class AnthropicCompletion(BaseLLM):
try:
# Send tool results back to Claude for final response
final_response: Message = self._client.messages.create(**follow_up_params)
final_response: Message = self._get_sync_client().messages.create(
**follow_up_params
)
# Track token usage for follow-up call
follow_up_usage = self._extract_anthropic_token_usage(final_response)
@@ -1319,11 +1344,11 @@ class AnthropicCompletion(BaseLLM):
try:
if betas:
params["betas"] = betas
response = await self._async_client.beta.messages.create(
response = await self._get_async_client().beta.messages.create(
**params, extra_body=extra_body
)
else:
response = await self._async_client.messages.create(**params)
response = await self._get_async_client().messages.create(**params)
except Exception as e:
if is_context_length_exceeded(e):
@@ -1457,11 +1482,11 @@ class AnthropicCompletion(BaseLLM):
current_tool_calls: dict[int, dict[str, Any]] = {}
stream_context = (
self._async_client.beta.messages.stream(
self._get_async_client().beta.messages.stream(
**stream_params, extra_body=extra_body
)
if betas
else self._async_client.messages.stream(**stream_params)
else self._get_async_client().messages.stream(**stream_params)
)
async with stream_context as stream:
response_id = None
@@ -1626,7 +1651,7 @@ class AnthropicCompletion(BaseLLM):
]
try:
final_response: Message = await self._async_client.messages.create(
final_response: Message = await self._get_async_client().messages.create(
**follow_up_params
)
@@ -1754,8 +1779,8 @@ class AnthropicCompletion(BaseLLM):
from crewai_files.uploaders.anthropic import AnthropicFileUploader
return AnthropicFileUploader(
client=self._client,
async_client=self._async_client,
client=self._get_sync_client(),
async_client=self._get_async_client(),
)
except ImportError:
return None

View File

@@ -116,43 +116,72 @@ class AzureCompletion(BaseLLM):
data.get("api_version") or os.getenv("AZURE_API_VERSION") or "2024-06-01"
)
if not data["api_key"]:
raise ValueError(
"Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
)
if not data["endpoint"]:
raise ValueError(
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
)
# Credentials and endpoint are validated lazily in `_init_clients`
# so the LLM can be constructed before deployment env vars are set.
model = data.get("model", "")
data["endpoint"] = AzureCompletion._validate_and_fix_endpoint(
data["endpoint"], model
)
if data["endpoint"]:
data["endpoint"] = AzureCompletion._validate_and_fix_endpoint(
data["endpoint"], model
)
parsed = urlparse(data["endpoint"])
hostname = parsed.hostname or ""
data["is_azure_openai_endpoint"] = (
hostname == "openai.azure.com" or hostname.endswith(".openai.azure.com")
) and "/openai/deployments/" in data["endpoint"]
else:
data["is_azure_openai_endpoint"] = False
data["is_openai_model"] = any(
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
)
parsed = urlparse(data["endpoint"])
hostname = parsed.hostname or ""
data["is_azure_openai_endpoint"] = (
hostname == "openai.azure.com" or hostname.endswith(".openai.azure.com")
) and "/openai/deployments/" in data["endpoint"]
return data
@model_validator(mode="after")
def _init_clients(self) -> AzureCompletion:
"""Eagerly build clients when credentials are available, otherwise
defer so ``LLM(model="azure/...")`` can be constructed at module
import time even before deployment env vars are set.
"""
try:
self._client = self._build_sync_client()
self._async_client = self._build_async_client()
except ValueError:
pass
return self
def _build_sync_client(self) -> Any:
return ChatCompletionsClient(**self._make_client_kwargs())
def _build_async_client(self) -> Any:
return AsyncChatCompletionsClient(**self._make_client_kwargs())
def _make_client_kwargs(self) -> dict[str, Any]:
if not self.api_key:
raise ValueError("Azure API key is required.")
raise ValueError(
"Azure API key is required. Set AZURE_API_KEY environment "
"variable or pass api_key parameter."
)
if not self.endpoint:
raise ValueError(
"Azure endpoint is required. Set AZURE_ENDPOINT environment "
"variable or pass endpoint parameter."
)
client_kwargs: dict[str, Any] = {
"endpoint": self.endpoint,
"credential": AzureKeyCredential(self.api_key),
}
if self.api_version:
client_kwargs["api_version"] = self.api_version
return client_kwargs
self._client = ChatCompletionsClient(**client_kwargs)
self._async_client = AsyncChatCompletionsClient(**client_kwargs)
return self
def _get_sync_client(self) -> Any:
if self._client is None:
self._client = self._build_sync_client()
return self._client
def _get_async_client(self) -> Any:
if self._async_client is None:
self._async_client = self._build_async_client()
return self._async_client
def to_config_dict(self) -> dict[str, Any]:
"""Extend base config with Azure-specific fields."""
@@ -713,8 +742,7 @@ class AzureCompletion(BaseLLM):
) -> str | Any:
"""Handle non-streaming chat completion."""
try:
# Cast params to Any to avoid type checking issues with TypedDict unpacking
response: ChatCompletions = self._client.complete(**params)
response: ChatCompletions = self._get_sync_client().complete(**params)
return self._process_completion_response(
response=response,
params=params,
@@ -913,7 +941,7 @@ class AzureCompletion(BaseLLM):
tool_calls: dict[int, dict[str, Any]] = {}
usage_data: dict[str, Any] | None = None
for update in self._client.complete(**params):
for update in self._get_sync_client().complete(**params):
if isinstance(update, StreamingChatCompletionsUpdate):
if update.usage:
usage = update.usage
@@ -953,8 +981,9 @@ class AzureCompletion(BaseLLM):
) -> str | Any:
"""Handle non-streaming chat completion asynchronously."""
try:
# Cast params to Any to avoid type checking issues with TypedDict unpacking
response: ChatCompletions = await self._async_client.complete(**params)
response: ChatCompletions = await self._get_async_client().complete(
**params
)
return self._process_completion_response(
response=response,
params=params,
@@ -980,7 +1009,7 @@ class AzureCompletion(BaseLLM):
usage_data: dict[str, Any] | None = None
stream = await self._async_client.complete(**params)
stream = await self._get_async_client().complete(**params)
async for update in stream:
if isinstance(update, StreamingChatCompletionsUpdate):
if hasattr(update, "usage") and update.usage:
@@ -1105,8 +1134,8 @@ class AzureCompletion(BaseLLM):
This ensures proper cleanup of the underlying aiohttp session
to avoid unclosed connector warnings.
"""
if hasattr(self._async_client, "close"):
await self._async_client.close()
if hasattr(self._get_async_client(), "close"):
await self._get_async_client().close()
async def __aenter__(self) -> Self:
"""Async context manager entry."""

View File

@@ -303,6 +303,18 @@ class BedrockCompletion(BaseLLM):
 @model_validator(mode="after")
 def _init_clients(self) -> BedrockCompletion:
+    """Eagerly build the sync client when AWS credentials resolve,
+    otherwise defer so ``LLM(model="bedrock/...")`` can be constructed
+    at module import time even before deployment env vars are set.
+    """
+    try:
+        self._client = self._build_sync_client()
+    except Exception as e:
+        logging.debug("Deferring Bedrock client construction: %s", e)
+    self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
+    return self
+
+def _build_sync_client(self) -> Any:
     config = Config(
         read_timeout=300,
         retries={"max_attempts": 3, "mode": "adaptive"},
@@ -314,9 +326,17 @@
         aws_session_token=self.aws_session_token,
         region_name=self.region_name,
     )
-    self._client = session.client("bedrock-runtime", config=config)
-    self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
-    return self
+    return session.client("bedrock-runtime", config=config)
+
+def _get_sync_client(self) -> Any:
+    if self._client is None:
+        self._client = self._build_sync_client()
+    return self._client
+
+def _get_async_client(self) -> Any:
+    """Async client is set up separately by ``_ensure_async_client``
+    using ``aiobotocore`` inside an exit stack."""
+    return self._async_client
def to_config_dict(self) -> dict[str, Any]:
"""Extend base config with Bedrock-specific fields."""
@@ -656,7 +676,7 @@ class BedrockCompletion(BaseLLM):
raise ValueError(f"Invalid message format at index {i}")
# Call Bedrock Converse API with proper error handling
response = self._client.converse(
response = self._get_sync_client().converse(
modelId=self.model_id,
messages=cast(
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
@@ -945,7 +965,7 @@ class BedrockCompletion(BaseLLM):
usage_data: dict[str, Any] | None = None
try:
response = self._client.converse_stream(
response = self._get_sync_client().converse_stream(
modelId=self.model_id,
messages=cast(
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
@@ -1174,7 +1194,7 @@ class BedrockCompletion(BaseLLM):
)
self._async_client = client
self._async_client_initialized = True
return self._async_client
return self._get_async_client()
async def _ahandle_converse(
self,

View File

@@ -118,9 +118,25 @@ class GeminiCompletion(BaseLLM):
 @model_validator(mode="after")
 def _init_client(self) -> GeminiCompletion:
-    self._client = self._initialize_client(self.use_vertexai)
+    """Eagerly build the client when credentials resolve, otherwise defer
+    so ``LLM(model="gemini/...")`` can be constructed at module import time
+    even before deployment env vars are set.
+    """
+    try:
+        self._client = self._initialize_client(self.use_vertexai)
+    except ValueError:
+        pass
     return self
+
+def _get_sync_client(self) -> Any:
+    if self._client is None:
+        self._client = self._initialize_client(self.use_vertexai)
+    return self._client
+
+def _get_async_client(self) -> Any:
+    """Gemini uses a single client for both sync and async calls."""
+    return self._get_sync_client()
 
 def to_config_dict(self) -> dict[str, Any]:
     """Extend base config with Gemini/Vertex-specific fields."""
     config = super().to_config_dict()
@@ -228,8 +244,8 @@ class GeminiCompletion(BaseLLM):
if (
hasattr(self, "client")
and hasattr(self._client, "vertexai")
and self._client.vertexai
and hasattr(self._get_sync_client(), "vertexai")
and self._get_sync_client().vertexai
):
# Vertex AI configuration
params.update(
@@ -1112,7 +1128,7 @@ class GeminiCompletion(BaseLLM):
try:
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
response = self._client.models.generate_content(
response = self._get_sync_client().models.generate_content(
model=self.model,
contents=contents_for_api,
config=config,
@@ -1153,7 +1169,7 @@ class GeminiCompletion(BaseLLM):
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
for chunk in self._client.models.generate_content_stream(
for chunk in self._get_sync_client().models.generate_content_stream(
model=self.model,
contents=contents_for_api,
config=config,
@@ -1191,7 +1207,7 @@ class GeminiCompletion(BaseLLM):
try:
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
response = await self._client.aio.models.generate_content(
response = await self._get_sync_client().aio.models.generate_content(
model=self.model,
contents=contents_for_api,
config=config,
@@ -1232,7 +1248,7 @@ class GeminiCompletion(BaseLLM):
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
stream = await self._client.aio.models.generate_content_stream(
stream = await self._get_sync_client().aio.models.generate_content_stream(
model=self.model,
contents=contents_for_api,
config=config,
@@ -1439,6 +1455,6 @@ class GeminiCompletion(BaseLLM):
try:
from crewai_files.uploaders.gemini import GeminiFileUploader
return GeminiFileUploader(client=self._client)
return GeminiFileUploader(client=self._get_sync_client())
except ImportError:
return None
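The "single client" note above is a property of the google-genai SDK itself: the synchronous surface lives on `client.models` and the async surface on `client.aio.models`, so one lazily built client can back both call paths. A generic google-genai sketch, independent of CrewAI and assuming `GOOGLE_API_KEY` is set in the environment:

```python
import asyncio

from google import genai

client = genai.Client()  # picks up GOOGLE_API_KEY from the environment


def sync_call() -> str:
    resp = client.models.generate_content(
        model="gemini-2.0-flash", contents="Say hi in one word."
    )
    return resp.text


async def async_call() -> str:
    resp = await client.aio.models.generate_content(
        model="gemini-2.0-flash", contents="Say hi in one word."
    )
    return resp.text


print(sync_call())
print(asyncio.run(async_call()))
```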

View File

@@ -253,22 +253,40 @@ class OpenAICompletion(BaseLLM):
 @model_validator(mode="after")
 def _init_clients(self) -> OpenAICompletion:
+    """Eagerly build clients when the API key is available, otherwise
+    defer so ``LLM(model="openai/...")`` can be constructed at module
+    import time even before deployment env vars are set.
+    """
+    try:
+        self._client = self._build_sync_client()
+        self._async_client = self._build_async_client()
+    except ValueError:
+        pass
+    return self
+
+def _build_sync_client(self) -> Any:
     client_config = self._get_client_params()
     if self.interceptor:
         transport = HTTPTransport(interceptor=self.interceptor)
-        http_client = httpx.Client(transport=transport)
-        client_config["http_client"] = http_client
-    self._client = OpenAI(**client_config)
-
-    async_client_config = self._get_client_params()
+        client_config["http_client"] = httpx.Client(transport=transport)
+    return OpenAI(**client_config)
+
+def _build_async_client(self) -> Any:
+    client_config = self._get_client_params()
     if self.interceptor:
-        async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
-        async_http_client = httpx.AsyncClient(transport=async_transport)
-        async_client_config["http_client"] = async_http_client
-    self._async_client = AsyncOpenAI(**async_client_config)
-    return self
+        transport = AsyncHTTPTransport(interceptor=self.interceptor)
+        client_config["http_client"] = httpx.AsyncClient(transport=transport)
+    return AsyncOpenAI(**client_config)
+
+def _get_sync_client(self) -> Any:
+    if self._client is None:
+        self._client = self._build_sync_client()
+    return self._client
+
+def _get_async_client(self) -> Any:
+    if self._async_client is None:
+        self._async_client = self._build_async_client()
+    return self._async_client
 
 @property
 def last_response_id(self) -> str | None:
@@ -797,7 +815,7 @@ class OpenAICompletion(BaseLLM):
) -> str | ResponsesAPIResult | Any:
"""Handle non-streaming Responses API call."""
try:
response: Response = self._client.responses.create(**params)
response: Response = self._get_sync_client().responses.create(**params)
# Track response ID for auto-chaining
if self.auto_chain and response.id:
@@ -933,7 +951,9 @@ class OpenAICompletion(BaseLLM):
) -> str | ResponsesAPIResult | Any:
"""Handle async non-streaming Responses API call."""
try:
response: Response = await self._async_client.responses.create(**params)
response: Response = await self._get_async_client().responses.create(
**params
)
# Track response ID for auto-chaining
if self.auto_chain and response.id:
@@ -1069,7 +1089,7 @@ class OpenAICompletion(BaseLLM):
final_response: Response | None = None
usage: dict[str, Any] | None = None
stream = self._client.responses.create(**params)
stream = self._get_sync_client().responses.create(**params)
response_id_stream = None
for event in stream:
@@ -1197,7 +1217,7 @@ class OpenAICompletion(BaseLLM):
final_response: Response | None = None
usage: dict[str, Any] | None = None
stream = await self._async_client.responses.create(**params)
stream = await self._get_async_client().responses.create(**params)
response_id_stream = None
async for event in stream:
@@ -1591,7 +1611,7 @@ class OpenAICompletion(BaseLLM):
parse_params = {
k: v for k, v in params.items() if k != "response_format"
}
parsed_response = self._client.beta.chat.completions.parse(
parsed_response = self._get_sync_client().beta.chat.completions.parse(
**parse_params,
response_format=response_model,
)
@@ -1615,7 +1635,9 @@ class OpenAICompletion(BaseLLM):
)
return parsed_object
response: ChatCompletion = self._client.chat.completions.create(**params)
response: ChatCompletion = self._get_sync_client().chat.completions.create(
**params
)
usage = self._extract_openai_token_usage(response)
@@ -1842,7 +1864,7 @@ class OpenAICompletion(BaseLLM):
}
stream: ChatCompletionStream[BaseModel]
with self._client.beta.chat.completions.stream(
with self._get_sync_client().beta.chat.completions.stream(
**parse_params, response_format=response_model
) as stream:
for chunk in stream:
@@ -1879,7 +1901,7 @@ class OpenAICompletion(BaseLLM):
return ""
completion_stream: Stream[ChatCompletionChunk] = (
self._client.chat.completions.create(**params)
self._get_sync_client().chat.completions.create(**params)
)
usage_data: dict[str, Any] | None = None
@@ -1976,9 +1998,11 @@ class OpenAICompletion(BaseLLM):
parse_params = {
k: v for k, v in params.items() if k != "response_format"
}
parsed_response = await self._async_client.beta.chat.completions.parse(
**parse_params,
response_format=response_model,
parsed_response = (
await self._get_async_client().beta.chat.completions.parse(
**parse_params,
response_format=response_model,
)
)
math_reasoning = parsed_response.choices[0].message
@@ -2000,8 +2024,8 @@ class OpenAICompletion(BaseLLM):
)
return parsed_object
response: ChatCompletion = await self._async_client.chat.completions.create(
**params
response: ChatCompletion = (
await self._get_async_client().chat.completions.create(**params)
)
usage = self._extract_openai_token_usage(response)
@@ -2127,7 +2151,7 @@ class OpenAICompletion(BaseLLM):
if response_model:
completion_stream: AsyncIterator[
ChatCompletionChunk
] = await self._async_client.chat.completions.create(**params)
] = await self._get_async_client().chat.completions.create(**params)
accumulated_content = ""
usage_data: dict[str, Any] | None = None
@@ -2183,7 +2207,7 @@ class OpenAICompletion(BaseLLM):
stream: AsyncIterator[
ChatCompletionChunk
] = await self._async_client.chat.completions.create(**params)
] = await self._get_async_client().chat.completions.create(**params)
usage_data = None
@@ -2379,8 +2403,8 @@ class OpenAICompletion(BaseLLM):
from crewai_files.uploaders.openai import OpenAIFileUploader
return OpenAIFileUploader(
client=self._client,
async_client=self._async_client,
client=self._get_sync_client(),
async_client=self._get_async_client(),
)
except ImportError:
return None

View File

@@ -378,23 +378,27 @@ def test_azure_completion_with_tools():
 def test_azure_raises_error_when_endpoint_missing():
-    """Test that AzureCompletion raises ValueError when endpoint is missing"""
+    """Credentials are validated lazily: construction succeeds, first
+    client build raises the descriptive error."""
     from crewai.llms.providers.azure.completion import AzureCompletion
 
     # Clear environment variables
     with patch.dict(os.environ, {}, clear=True):
+        llm = AzureCompletion(model="gpt-4", api_key="test-key")
         with pytest.raises(ValueError, match="Azure endpoint is required"):
-            AzureCompletion(model="gpt-4", api_key="test-key")
+            llm._get_sync_client()
 
 
 def test_azure_raises_error_when_api_key_missing():
-    """Test that AzureCompletion raises ValueError when API key is missing"""
+    """Credentials are validated lazily: construction succeeds, first
+    client build raises the descriptive error."""
     from crewai.llms.providers.azure.completion import AzureCompletion
 
     # Clear environment variables
     with patch.dict(os.environ, {}, clear=True):
+        llm = AzureCompletion(
+            model="gpt-4", endpoint="https://test.openai.azure.com"
+        )
         with pytest.raises(ValueError, match="Azure API key is required"):
-            AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")
+            llm._get_sync_client()
def test_azure_endpoint_configuration():
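Taken together with the commit message, the user-visible contract after this change is roughly the following. This is a hedged sketch of the original failure scenario: the `LLM.call` usage and the exact error surfaced may differ slightly from what the wrapper exposes.

```python
import os

from crewai import LLM

# Module-scope construction, as in `chat_llm=LLM(model="openai/gpt-4o-mini")`
# on a @crew method, no longer requires OPENAI_API_KEY to be set yet.
chat_llm = LLM(model="openai/gpt-4o-mini")

if not os.getenv("OPENAI_API_KEY"):
    try:
        chat_llm.call("ping")  # first real call builds the client...
    except ValueError as err:
        print(err)  # ...and raises the descriptive "OPENAI_API_KEY is required" error
```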