fix: defer native LLM client construction when credentials are missing
All native LLM providers built their SDK clients inside `@model_validator(mode="after")`, which required the API key at `LLM(...)` construction time. Instantiating an LLM at module scope (e.g. `chat_llm=LLM(model="openai/gpt-4o-mini")` on a `@crew` method) therefore crashed during downstream crew-metadata extraction with a confusing `ImportError: Error importing native provider: 1 validation error...` before the process env vars were ever consulted.

This commit wraps the eager client construction in a try/except in each provider and adds `_get_sync_client` / `_get_async_client` methods that build clients on first use. Call sites in every provider are routed through the lazy getters, so calls made without eager construction still work, and the descriptive "X_API_KEY is required" errors are re-raised from the lazy path at the first real call. Two Azure tests that asserted the old eager-error contract are updated to assert the new lazy-error contract.
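For review context, here is a minimal self-contained sketch of the pattern. It is illustrative only: `LazyClientLLM` and its string "client" are stand-ins, not code from this diff; real providers build SDK clients.

import os


class LazyClientLLM:
    """Illustrative stand-in for a native provider completion class."""

    def __init__(self, model: str) -> None:
        self.model = model
        self._client = None
        try:
            # Eager path: only succeeds if credentials are already set.
            self._client = self._build_sync_client()
        except ValueError:
            pass  # Defer: construction succeeds, client builds on first call.

    def _build_sync_client(self) -> str:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            # The descriptive error now surfaces at the first real call
            # instead of at LLM(...) construction time.
            raise ValueError("OPENAI_API_KEY is required.")
        return f"client({api_key})"  # stand-in for OpenAI(api_key=api_key)

    def _get_sync_client(self) -> str:
        if self._client is None:
            self._client = self._build_sync_client()
        return self._client


llm = LazyClientLLM(model="openai/gpt-4o-mini")  # module scope: no crash
os.environ["OPENAI_API_KEY"] = "sk-placeholder"  # env arrives later, e.g. at deploy
print(llm._get_sync_client())                    # client built on first use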
@@ -189,16 +189,37 @@ class AnthropicCompletion(BaseLLM):
     @model_validator(mode="after")
     def _init_clients(self) -> AnthropicCompletion:
-        self._client = Anthropic(**self._get_client_params())
-
+        """Eagerly build clients when the API key is available, otherwise
+        defer so ``LLM(model="anthropic/...")`` can be constructed at module
+        import time even before deployment env vars are set.
+        """
+        try:
+            self._client = self._build_sync_client()
+            self._async_client = self._build_async_client()
+        except ValueError:
+            pass
+        return self
+
+    def _build_sync_client(self) -> Any:
+        return Anthropic(**self._get_client_params())
+
+    def _build_async_client(self) -> Any:
         async_client_params = self._get_client_params()
         if self.interceptor:
             async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
             async_http_client = httpx.AsyncClient(transport=async_transport)
             async_client_params["http_client"] = async_http_client
+        return AsyncAnthropic(**async_client_params)
+
+    def _get_sync_client(self) -> Any:
+        if self._client is None:
+            self._client = self._build_sync_client()
+        return self._client
 
-        self._async_client = AsyncAnthropic(**async_client_params)
-        return self
+    def _get_async_client(self) -> Any:
+        if self._async_client is None:
+            self._async_client = self._build_async_client()
+        return self._async_client
 
     def to_config_dict(self) -> dict[str, Any]:
         """Extend base config with Anthropic-specific fields."""
@@ -790,11 +811,11 @@ class AnthropicCompletion(BaseLLM):
         try:
             if betas:
                 params["betas"] = betas
-                response = self._client.beta.messages.create(
+                response = self._get_sync_client().beta.messages.create(
                     **params, extra_body=extra_body
                 )
             else:
-                response = self._client.messages.create(**params)
+                response = self._get_sync_client().messages.create(**params)
 
         except Exception as e:
             if is_context_length_exceeded(e):
@@ -942,9 +963,11 @@ class AnthropicCompletion(BaseLLM):
         current_tool_calls: dict[int, dict[str, Any]] = {}
 
         stream_context = (
-            self._client.beta.messages.stream(**stream_params, extra_body=extra_body)
+            self._get_sync_client().beta.messages.stream(
+                **stream_params, extra_body=extra_body
+            )
             if betas
-            else self._client.messages.stream(**stream_params)
+            else self._get_sync_client().messages.stream(**stream_params)
         )
         with stream_context as stream:
             response_id = None
@@ -1223,7 +1246,9 @@ class AnthropicCompletion(BaseLLM):
 
         try:
             # Send tool results back to Claude for final response
-            final_response: Message = self._client.messages.create(**follow_up_params)
+            final_response: Message = self._get_sync_client().messages.create(
+                **follow_up_params
+            )
 
             # Track token usage for follow-up call
             follow_up_usage = self._extract_anthropic_token_usage(final_response)
@@ -1319,11 +1344,11 @@ class AnthropicCompletion(BaseLLM):
         try:
             if betas:
                 params["betas"] = betas
-                response = await self._async_client.beta.messages.create(
+                response = await self._get_async_client().beta.messages.create(
                     **params, extra_body=extra_body
                 )
             else:
-                response = await self._async_client.messages.create(**params)
+                response = await self._get_async_client().messages.create(**params)
 
         except Exception as e:
             if is_context_length_exceeded(e):
@@ -1457,11 +1482,11 @@ class AnthropicCompletion(BaseLLM):
         current_tool_calls: dict[int, dict[str, Any]] = {}
 
         stream_context = (
-            self._async_client.beta.messages.stream(
+            self._get_async_client().beta.messages.stream(
                 **stream_params, extra_body=extra_body
             )
             if betas
-            else self._async_client.messages.stream(**stream_params)
+            else self._get_async_client().messages.stream(**stream_params)
         )
         async with stream_context as stream:
             response_id = None
@@ -1626,7 +1651,7 @@ class AnthropicCompletion(BaseLLM):
         ]
 
         try:
-            final_response: Message = await self._async_client.messages.create(
+            final_response: Message = await self._get_async_client().messages.create(
                 **follow_up_params
             )
 
@@ -1754,8 +1779,8 @@ class AnthropicCompletion(BaseLLM):
             from crewai_files.uploaders.anthropic import AnthropicFileUploader
 
             return AnthropicFileUploader(
-                client=self._client,
-                async_client=self._async_client,
+                client=self._get_sync_client(),
+                async_client=self._get_async_client(),
             )
         except ImportError:
             return None
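The async builder above threads an interceptor through an httpx transport. A hedged, self-contained illustration of that wiring — `LoggingTransport` here is a hypothetical stand-in for crewAI's `AsyncHTTPTransport`:

import httpx


class LoggingTransport(httpx.AsyncBaseTransport):
    """Observe every request before it reaches the SDK's HTTP layer."""

    def __init__(self) -> None:
        self._inner = httpx.AsyncHTTPTransport()

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        print(f"-> {request.method} {request.url}")  # interception point
        return await self._inner.handle_async_request(request)


# Passed to the SDK the same way _build_async_client passes http_client:
async_http_client = httpx.AsyncClient(transport=LoggingTransport())
# AsyncAnthropic(api_key="...", http_client=async_http_client)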
@@ -116,43 +116,72 @@ class AzureCompletion(BaseLLM):
             data.get("api_version") or os.getenv("AZURE_API_VERSION") or "2024-06-01"
         )
 
-        if not data["api_key"]:
-            raise ValueError(
-                "Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
-            )
-        if not data["endpoint"]:
-            raise ValueError(
-                "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
-            )
-
+        # Credentials and endpoint are validated lazily in `_init_clients`
+        # so the LLM can be constructed before deployment env vars are set.
         model = data.get("model", "")
-        data["endpoint"] = AzureCompletion._validate_and_fix_endpoint(
-            data["endpoint"], model
-        )
+        if data["endpoint"]:
+            data["endpoint"] = AzureCompletion._validate_and_fix_endpoint(
+                data["endpoint"], model
+            )
+            parsed = urlparse(data["endpoint"])
+            hostname = parsed.hostname or ""
+            data["is_azure_openai_endpoint"] = (
+                hostname == "openai.azure.com" or hostname.endswith(".openai.azure.com")
+            ) and "/openai/deployments/" in data["endpoint"]
+        else:
+            data["is_azure_openai_endpoint"] = False
         data["is_openai_model"] = any(
             prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
         )
-        parsed = urlparse(data["endpoint"])
-        hostname = parsed.hostname or ""
-        data["is_azure_openai_endpoint"] = (
-            hostname == "openai.azure.com" or hostname.endswith(".openai.azure.com")
-        ) and "/openai/deployments/" in data["endpoint"]
         return data
 
     @model_validator(mode="after")
     def _init_clients(self) -> AzureCompletion:
+        """Eagerly build clients when credentials are available, otherwise
+        defer so ``LLM(model="azure/...")`` can be constructed at module
+        import time even before deployment env vars are set.
+        """
+        try:
+            self._client = self._build_sync_client()
+            self._async_client = self._build_async_client()
+        except ValueError:
+            pass
+        return self
+
+    def _build_sync_client(self) -> Any:
+        return ChatCompletionsClient(**self._make_client_kwargs())
+
+    def _build_async_client(self) -> Any:
+        return AsyncChatCompletionsClient(**self._make_client_kwargs())
+
+    def _make_client_kwargs(self) -> dict[str, Any]:
         if not self.api_key:
-            raise ValueError("Azure API key is required.")
+            raise ValueError(
+                "Azure API key is required. Set AZURE_API_KEY environment "
+                "variable or pass api_key parameter."
+            )
+        if not self.endpoint:
+            raise ValueError(
+                "Azure endpoint is required. Set AZURE_ENDPOINT environment "
+                "variable or pass endpoint parameter."
+            )
         client_kwargs: dict[str, Any] = {
             "endpoint": self.endpoint,
             "credential": AzureKeyCredential(self.api_key),
         }
         if self.api_version:
             client_kwargs["api_version"] = self.api_version
+        return client_kwargs
 
-        self._client = ChatCompletionsClient(**client_kwargs)
-        self._async_client = AsyncChatCompletionsClient(**client_kwargs)
-        return self
+    def _get_sync_client(self) -> Any:
+        if self._client is None:
+            self._client = self._build_sync_client()
+        return self._client
+
+    def _get_async_client(self) -> Any:
+        if self._async_client is None:
+            self._async_client = self._build_async_client()
+        return self._async_client
 
     def to_config_dict(self) -> dict[str, Any]:
         """Extend base config with Azure-specific fields."""
@@ -713,8 +742,7 @@ class AzureCompletion(BaseLLM):
     ) -> str | Any:
         """Handle non-streaming chat completion."""
         try:
-            # Cast params to Any to avoid type checking issues with TypedDict unpacking
-            response: ChatCompletions = self._client.complete(**params)
+            response: ChatCompletions = self._get_sync_client().complete(**params)
             return self._process_completion_response(
                 response=response,
                 params=params,
@@ -913,7 +941,7 @@ class AzureCompletion(BaseLLM):
         tool_calls: dict[int, dict[str, Any]] = {}
 
         usage_data: dict[str, Any] | None = None
-        for update in self._client.complete(**params):
+        for update in self._get_sync_client().complete(**params):
             if isinstance(update, StreamingChatCompletionsUpdate):
                 if update.usage:
                     usage = update.usage
@@ -953,8 +981,9 @@ class AzureCompletion(BaseLLM):
     ) -> str | Any:
         """Handle non-streaming chat completion asynchronously."""
         try:
-            # Cast params to Any to avoid type checking issues with TypedDict unpacking
-            response: ChatCompletions = await self._async_client.complete(**params)
+            response: ChatCompletions = await self._get_async_client().complete(
+                **params
+            )
             return self._process_completion_response(
                 response=response,
                 params=params,
@@ -980,7 +1009,7 @@ class AzureCompletion(BaseLLM):
 
         usage_data: dict[str, Any] | None = None
 
-        stream = await self._async_client.complete(**params)
+        stream = await self._get_async_client().complete(**params)
         async for update in stream:
             if isinstance(update, StreamingChatCompletionsUpdate):
                 if hasattr(update, "usage") and update.usage:
@@ -1105,8 +1134,8 @@ class AzureCompletion(BaseLLM):
         This ensures proper cleanup of the underlying aiohttp session
         to avoid unclosed connector warnings.
        """
-        if hasattr(self._async_client, "close"):
-            await self._async_client.close()
+        if hasattr(self._get_async_client(), "close"):
+            await self._get_async_client().close()
 
     async def __aenter__(self) -> Self:
         """Async context manager entry."""
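The caller-visible contract after this change, sketched as a pytest-style check (this mirrors the updated Azure tests at the end of this diff):

import os
from unittest.mock import patch

import pytest

from crewai.llms.providers.azure.completion import AzureCompletion

with patch.dict(os.environ, {}, clear=True):
    llm = AzureCompletion(model="gpt-4", api_key="test-key")  # no longer raises
    with pytest.raises(ValueError, match="Azure endpoint is required"):
        llm._get_sync_client()  # descriptive error surfaces at first build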
@@ -303,6 +303,18 @@ class BedrockCompletion(BaseLLM):
 
     @model_validator(mode="after")
     def _init_clients(self) -> BedrockCompletion:
+        """Eagerly build the sync client when AWS credentials resolve,
+        otherwise defer so ``LLM(model="bedrock/...")`` can be constructed
+        at module import time even before deployment env vars are set.
+        """
+        try:
+            self._client = self._build_sync_client()
+        except Exception as e:
+            logging.debug("Deferring Bedrock client construction: %s", e)
+        self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
+        return self
+
+    def _build_sync_client(self) -> Any:
         config = Config(
             read_timeout=300,
             retries={"max_attempts": 3, "mode": "adaptive"},
@@ -314,9 +326,17 @@ class BedrockCompletion(BaseLLM):
             aws_session_token=self.aws_session_token,
             region_name=self.region_name,
         )
-        self._client = session.client("bedrock-runtime", config=config)
-        self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
-        return self
+        return session.client("bedrock-runtime", config=config)
+
+    def _get_sync_client(self) -> Any:
+        if self._client is None:
+            self._client = self._build_sync_client()
+        return self._client
+
+    def _get_async_client(self) -> Any:
+        """Async client is set up separately by ``_ensure_async_client``
+        using ``aiobotocore`` inside an exit stack."""
+        return self._async_client
 
     def to_config_dict(self) -> dict[str, Any]:
         """Extend base config with Bedrock-specific fields."""
@@ -656,7 +676,7 @@ class BedrockCompletion(BaseLLM):
                 raise ValueError(f"Invalid message format at index {i}")
 
         # Call Bedrock Converse API with proper error handling
-        response = self._client.converse(
+        response = self._get_sync_client().converse(
             modelId=self.model_id,
             messages=cast(
                 "Sequence[MessageTypeDef | MessageOutputTypeDef]",
@@ -945,7 +965,7 @@ class BedrockCompletion(BaseLLM):
         usage_data: dict[str, Any] | None = None
 
         try:
-            response = self._client.converse_stream(
+            response = self._get_sync_client().converse_stream(
                 modelId=self.model_id,
                 messages=cast(
                     "Sequence[MessageTypeDef | MessageOutputTypeDef]",
@@ -1174,7 +1194,7 @@ class BedrockCompletion(BaseLLM):
             )
             self._async_client = client
             self._async_client_initialized = True
-        return self._async_client
+        return self._get_async_client()
 
     async def _ahandle_converse(
         self,
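Bedrock is the one provider that catches broad `Exception` rather than `ValueError`, since boto3 session and client setup can fail in SDK-specific ways (for example `NoRegionError` when no region can be resolved). A hedged sketch of that deferral:

import logging

import boto3
from botocore.config import Config


def build_sync_client(region_name: str | None):
    config = Config(read_timeout=300, retries={"max_attempts": 3, "mode": "adaptive"})
    session = boto3.Session(region_name=region_name)
    return session.client("bedrock-runtime", config=config)


client = None
try:
    client = build_sync_client(None)  # may raise if no region resolves anywhere
except Exception as e:  # broad on purpose: defer until credentials resolve
    logging.debug("Deferring Bedrock client construction: %s", e)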
@@ -118,9 +118,25 @@ class GeminiCompletion(BaseLLM):
 
     @model_validator(mode="after")
     def _init_client(self) -> GeminiCompletion:
-        self._client = self._initialize_client(self.use_vertexai)
+        """Eagerly build the client when credentials resolve, otherwise defer
+        so ``LLM(model="gemini/...")`` can be constructed at module import time
+        even before deployment env vars are set.
+        """
+        try:
+            self._client = self._initialize_client(self.use_vertexai)
+        except ValueError:
+            pass
         return self
 
+    def _get_sync_client(self) -> Any:
+        if self._client is None:
+            self._client = self._initialize_client(self.use_vertexai)
+        return self._client
+
+    def _get_async_client(self) -> Any:
+        """Gemini uses a single client for both sync and async calls."""
+        return self._get_sync_client()
+
     def to_config_dict(self) -> dict[str, Any]:
         """Extend base config with Gemini/Vertex-specific fields."""
         config = super().to_config_dict()
@@ -228,8 +244,8 @@ class GeminiCompletion(BaseLLM):
 
         if (
             hasattr(self, "client")
-            and hasattr(self._client, "vertexai")
-            and self._client.vertexai
+            and hasattr(self._get_sync_client(), "vertexai")
+            and self._get_sync_client().vertexai
         ):
             # Vertex AI configuration
             params.update(
@@ -1112,7 +1128,7 @@ class GeminiCompletion(BaseLLM):
         try:
             # The API accepts list[Content] but mypy is overly strict about variance
             contents_for_api: Any = contents
-            response = self._client.models.generate_content(
+            response = self._get_sync_client().models.generate_content(
                 model=self.model,
                 contents=contents_for_api,
                 config=config,
@@ -1153,7 +1169,7 @@ class GeminiCompletion(BaseLLM):
 
         # The API accepts list[Content] but mypy is overly strict about variance
         contents_for_api: Any = contents
-        for chunk in self._client.models.generate_content_stream(
+        for chunk in self._get_sync_client().models.generate_content_stream(
             model=self.model,
             contents=contents_for_api,
             config=config,
@@ -1191,7 +1207,7 @@ class GeminiCompletion(BaseLLM):
         try:
             # The API accepts list[Content] but mypy is overly strict about variance
             contents_for_api: Any = contents
-            response = await self._client.aio.models.generate_content(
+            response = await self._get_sync_client().aio.models.generate_content(
                 model=self.model,
                 contents=contents_for_api,
                 config=config,
@@ -1232,7 +1248,7 @@ class GeminiCompletion(BaseLLM):
 
         # The API accepts list[Content] but mypy is overly strict about variance
         contents_for_api: Any = contents
-        stream = await self._client.aio.models.generate_content_stream(
+        stream = await self._get_sync_client().aio.models.generate_content_stream(
             model=self.model,
             contents=contents_for_api,
             config=config,
@@ -1439,6 +1455,6 @@ class GeminiCompletion(BaseLLM):
         try:
             from crewai_files.uploaders.gemini import GeminiFileUploader
 
-            return GeminiFileUploader(client=self._client)
+            return GeminiFileUploader(client=self._get_sync_client())
         except ImportError:
             return None
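Why Gemini's `_get_async_client` simply returns the sync client: the google-genai client exposes async variants of every call under `.aio`, so one client serves both paths. A minimal sketch — the model name is illustrative and it assumes `GEMINI_API_KEY` is set:

import asyncio

from google import genai

client = genai.Client()  # reads GEMINI_API_KEY / GOOGLE_API_KEY from the env


def sync_ping() -> str:
    response = client.models.generate_content(
        model="gemini-2.0-flash", contents="ping"
    )
    return response.text


async def async_ping() -> str:
    response = await client.aio.models.generate_content(
        model="gemini-2.0-flash", contents="ping"
    )
    return response.text


print(sync_ping())
print(asyncio.run(async_ping()))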
@@ -253,22 +253,40 @@ class OpenAICompletion(BaseLLM):
 
     @model_validator(mode="after")
     def _init_clients(self) -> OpenAICompletion:
+        """Eagerly build clients when the API key is available, otherwise
+        defer so ``LLM(model="openai/...")`` can be constructed at module
+        import time even before deployment env vars are set.
+        """
+        try:
+            self._client = self._build_sync_client()
+            self._async_client = self._build_async_client()
+        except ValueError:
+            pass
+        return self
+
+    def _build_sync_client(self) -> Any:
         client_config = self._get_client_params()
         if self.interceptor:
             transport = HTTPTransport(interceptor=self.interceptor)
-            http_client = httpx.Client(transport=transport)
-            client_config["http_client"] = http_client
+            client_config["http_client"] = httpx.Client(transport=transport)
+        return OpenAI(**client_config)
 
-        self._client = OpenAI(**client_config)
-
-        async_client_config = self._get_client_params()
+    def _build_async_client(self) -> Any:
+        client_config = self._get_client_params()
         if self.interceptor:
-            async_transport = AsyncHTTPTransport(interceptor=self.interceptor)
-            async_http_client = httpx.AsyncClient(transport=async_transport)
-            async_client_config["http_client"] = async_http_client
+            transport = AsyncHTTPTransport(interceptor=self.interceptor)
+            client_config["http_client"] = httpx.AsyncClient(transport=transport)
+        return AsyncOpenAI(**client_config)
 
-        self._async_client = AsyncOpenAI(**async_client_config)
-        return self
+    def _get_sync_client(self) -> Any:
+        if self._client is None:
+            self._client = self._build_sync_client()
+        return self._client
+
+    def _get_async_client(self) -> Any:
+        if self._async_client is None:
+            self._async_client = self._build_async_client()
+        return self._async_client
 
     @property
     def last_response_id(self) -> str | None:
@@ -797,7 +815,7 @@ class OpenAICompletion(BaseLLM):
     ) -> str | ResponsesAPIResult | Any:
         """Handle non-streaming Responses API call."""
         try:
-            response: Response = self._client.responses.create(**params)
+            response: Response = self._get_sync_client().responses.create(**params)
 
             # Track response ID for auto-chaining
             if self.auto_chain and response.id:
@@ -933,7 +951,9 @@ class OpenAICompletion(BaseLLM):
     ) -> str | ResponsesAPIResult | Any:
         """Handle async non-streaming Responses API call."""
         try:
-            response: Response = await self._async_client.responses.create(**params)
+            response: Response = await self._get_async_client().responses.create(
+                **params
+            )
 
             # Track response ID for auto-chaining
             if self.auto_chain and response.id:
@@ -1069,7 +1089,7 @@ class OpenAICompletion(BaseLLM):
         final_response: Response | None = None
         usage: dict[str, Any] | None = None
 
-        stream = self._client.responses.create(**params)
+        stream = self._get_sync_client().responses.create(**params)
         response_id_stream = None
 
         for event in stream:
@@ -1197,7 +1217,7 @@ class OpenAICompletion(BaseLLM):
         final_response: Response | None = None
         usage: dict[str, Any] | None = None
 
-        stream = await self._async_client.responses.create(**params)
+        stream = await self._get_async_client().responses.create(**params)
         response_id_stream = None
 
         async for event in stream:
@@ -1591,7 +1611,7 @@ class OpenAICompletion(BaseLLM):
                 parse_params = {
                     k: v for k, v in params.items() if k != "response_format"
                 }
-                parsed_response = self._client.beta.chat.completions.parse(
+                parsed_response = self._get_sync_client().beta.chat.completions.parse(
                     **parse_params,
                     response_format=response_model,
                 )
@@ -1615,7 +1635,9 @@ class OpenAICompletion(BaseLLM):
                 )
                 return parsed_object
 
-            response: ChatCompletion = self._client.chat.completions.create(**params)
+            response: ChatCompletion = self._get_sync_client().chat.completions.create(
+                **params
+            )
 
             usage = self._extract_openai_token_usage(response)
 
@@ -1842,7 +1864,7 @@ class OpenAICompletion(BaseLLM):
         }
 
         stream: ChatCompletionStream[BaseModel]
-        with self._client.beta.chat.completions.stream(
+        with self._get_sync_client().beta.chat.completions.stream(
             **parse_params, response_format=response_model
         ) as stream:
             for chunk in stream:
@@ -1879,7 +1901,7 @@ class OpenAICompletion(BaseLLM):
             return ""
 
         completion_stream: Stream[ChatCompletionChunk] = (
-            self._client.chat.completions.create(**params)
+            self._get_sync_client().chat.completions.create(**params)
         )
 
         usage_data: dict[str, Any] | None = None
@@ -1976,9 +1998,11 @@ class OpenAICompletion(BaseLLM):
                 parse_params = {
                     k: v for k, v in params.items() if k != "response_format"
                 }
-                parsed_response = await self._async_client.beta.chat.completions.parse(
-                    **parse_params,
-                    response_format=response_model,
+                parsed_response = (
+                    await self._get_async_client().beta.chat.completions.parse(
+                        **parse_params,
+                        response_format=response_model,
+                    )
                 )
                 math_reasoning = parsed_response.choices[0].message
 
@@ -2000,8 +2024,8 @@ class OpenAICompletion(BaseLLM):
             )
             return parsed_object
 
-        response: ChatCompletion = await self._async_client.chat.completions.create(
-            **params
+        response: ChatCompletion = (
+            await self._get_async_client().chat.completions.create(**params)
         )
 
         usage = self._extract_openai_token_usage(response)
@@ -2127,7 +2151,7 @@ class OpenAICompletion(BaseLLM):
             if response_model:
                 completion_stream: AsyncIterator[
                     ChatCompletionChunk
-                ] = await self._async_client.chat.completions.create(**params)
+                ] = await self._get_async_client().chat.completions.create(**params)
 
                 accumulated_content = ""
                 usage_data: dict[str, Any] | None = None
@@ -2183,7 +2207,7 @@ class OpenAICompletion(BaseLLM):
 
             stream: AsyncIterator[
                 ChatCompletionChunk
-            ] = await self._async_client.chat.completions.create(**params)
+            ] = await self._get_async_client().chat.completions.create(**params)
 
             usage_data = None
 
@@ -2379,8 +2403,8 @@ class OpenAICompletion(BaseLLM):
             from crewai_files.uploaders.openai import OpenAIFileUploader
 
             return OpenAIFileUploader(
-                client=self._client,
-                async_client=self._async_client,
+                client=self._get_sync_client(),
+                async_client=self._get_async_client(),
             )
         except ImportError:
             return None
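The getters memoize, so the credential check and client construction run at most once per instance. A hedged sketch against the OpenAI SDK (assumes `OPENAI_API_KEY` is set by the time the first call happens):

from openai import OpenAI


class Completion:
    def __init__(self) -> None:
        self._client: OpenAI | None = None  # construction never touches the env

    def _get_sync_client(self) -> OpenAI:
        if self._client is None:
            # OpenAI() raises here, not in __init__, if the key is missing.
            self._client = OpenAI()
        return self._client


llm = Completion()  # safe without credentials
assert llm._get_sync_client() is llm._get_sync_client()  # one client, reused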
@@ -378,23 +378,27 @@ def test_azure_completion_with_tools():
 
 
 def test_azure_raises_error_when_endpoint_missing():
-    """Test that AzureCompletion raises ValueError when endpoint is missing"""
+    """Credentials are validated lazily: construction succeeds, first
+    client build raises the descriptive error."""
     from crewai.llms.providers.azure.completion import AzureCompletion
 
     # Clear environment variables
     with patch.dict(os.environ, {}, clear=True):
+        llm = AzureCompletion(model="gpt-4", api_key="test-key")
         with pytest.raises(ValueError, match="Azure endpoint is required"):
-            AzureCompletion(model="gpt-4", api_key="test-key")
+            llm._get_sync_client()
 
 
 def test_azure_raises_error_when_api_key_missing():
-    """Test that AzureCompletion raises ValueError when API key is missing"""
+    """Credentials are validated lazily: construction succeeds, first
+    client build raises the descriptive error."""
     from crewai.llms.providers.azure.completion import AzureCompletion
 
     # Clear environment variables
     with patch.dict(os.environ, {}, clear=True):
+        llm = AzureCompletion(
+            model="gpt-4", endpoint="https://test.openai.azure.com"
+        )
         with pytest.raises(ValueError, match="Azure API key is required"):
-            AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")
+            llm._get_sync_client()
 
 
 def test_azure_endpoint_configuration():