feat: add Azure AD token-based authentication support

- Add azure_ad_token_provider parameter for token-based auth
- Add credential parameter for TokenCredential instances
- Create _TokenProviderCredential wrapper class for token providers
- Update authentication logic with priority: credential > token_provider > api_key
- Add support for base_url parameter as alternative to endpoint
- Update error message to reflect new authentication options
- Add comprehensive tests for all new authentication methods

Fixes #4018

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-12-02 10:48:13 +00:00
parent 20704742e2
commit 347381be57
2 changed files with 279 additions and 12 deletions

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
from collections.abc import Callable
import json
import logging
import os
import time
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
@@ -35,7 +37,9 @@ try:
StreamingChatCompletionsUpdate,
)
from azure.core.credentials import (
AccessToken,
AzureKeyCredential,
TokenCredential,
)
from azure.core.exceptions import (
HttpResponseError,
@@ -50,6 +54,41 @@ except ImportError:
) from None
class _TokenProviderCredential(TokenCredential):
"""Wrapper class to convert an azure_ad_token_provider callable into a TokenCredential.
This allows users to pass a token provider function (like the one returned by
azure.identity.get_bearer_token_provider) to the Azure AI Inference client.
"""
def __init__(self, provider: Callable[..., Any]):
"""Initialize with a token provider callable.
Args:
provider: A callable that returns an access token. This is typically
the result of azure.identity.get_bearer_token_provider().
"""
self._provider = provider
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"""Get an access token from the provider.
Args:
*scopes: The scopes for the token (ignored, as the provider handles this).
**kwargs: Additional keyword arguments (ignored).
Returns:
An AccessToken instance.
"""
raw = self._provider()
if isinstance(raw, AccessToken):
return raw
# If it's a bare string, wrap it with a default expiry of 1 hour
return AccessToken(str(raw), int(time.time()) + 3600)
class AzureCompletion(BaseLLM):
"""Azure AI Inference native completion implementation.
@@ -73,6 +112,8 @@ class AzureCompletion(BaseLLM):
stop: list[str] | None = None,
stream: bool = False,
interceptor: BaseInterceptor[Any, Any] | None = None,
azure_ad_token_provider: Callable[..., Any] | None = None,
credential: TokenCredential | None = None,
**kwargs: Any,
):
"""Initialize Azure AI Inference chat completion client.
@@ -92,6 +133,13 @@ class AzureCompletion(BaseLLM):
stop: Stop sequences
stream: Enable streaming responses
interceptor: HTTP interceptor (not yet supported for Azure).
azure_ad_token_provider: A callable that returns an Azure AD token.
This is typically the result of azure.identity.get_bearer_token_provider().
Use this for Azure AD token-based authentication instead of API keys.
credential: An Azure TokenCredential instance for authentication.
This can be any credential from azure.identity (e.g., DefaultAzureCredential,
ManagedIdentityCredential). Takes precedence over azure_ad_token_provider
and api_key.
**kwargs: Additional parameters
"""
if interceptor is not None:
@@ -107,6 +155,7 @@ class AzureCompletion(BaseLLM):
self.api_key = api_key or os.getenv("AZURE_API_KEY")
self.endpoint = (
endpoint
or kwargs.get("base_url")
or os.getenv("AZURE_ENDPOINT")
or os.getenv("AZURE_OPENAI_ENDPOINT")
or os.getenv("AZURE_API_BASE")
@@ -115,31 +164,45 @@ class AzureCompletion(BaseLLM):
self.timeout = timeout
self.max_retries = max_retries
if not self.api_key:
raise ValueError(
"Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
)
if not self.endpoint:
raise ValueError(
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
)
# Determine the credential to use (priority: credential > azure_ad_token_provider > api_key)
chosen_credential: TokenCredential | AzureKeyCredential | None = None
if credential is not None:
chosen_credential = credential
elif azure_ad_token_provider is not None:
chosen_credential = _TokenProviderCredential(azure_ad_token_provider)
elif self.api_key:
chosen_credential = AzureKeyCredential(self.api_key)
if chosen_credential is None:
raise ValueError(
"Azure authentication is required. Provide one of: "
"api_key (or set AZURE_API_KEY environment variable), "
"azure_ad_token_provider (callable from azure.identity.get_bearer_token_provider), "
"or credential (TokenCredential instance from azure.identity)."
)
# Validate and potentially fix Azure OpenAI endpoint URL
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
# Build client kwargs
client_kwargs = {
client_kwargs: dict[str, Any] = {
"endpoint": self.endpoint,
"credential": AzureKeyCredential(self.api_key),
"credential": chosen_credential,
}
# Add api_version if specified (primarily for Azure OpenAI endpoints)
if self.api_version:
client_kwargs["api_version"] = self.api_version
self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
self.client = ChatCompletionsClient(**client_kwargs)
self.async_client = AsyncChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
self.async_client = AsyncChatCompletionsClient(**client_kwargs)
self.top_p = top_p
self.frequency_penalty = frequency_penalty