mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
feat: add Azure AD token-based authentication support
- Add azure_ad_token_provider parameter for token-based auth - Add credential parameter for TokenCredential instances - Create _TokenProviderCredential wrapper class for token providers - Update authentication logic with priority: credential > token_provider > api_key - Add support for base_url parameter as alternative to endpoint - Update error message to reflect new authentication options - Add comprehensive tests for all new authentication methods Fixes #4018 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -35,7 +37,9 @@ try:
|
|||||||
StreamingChatCompletionsUpdate,
|
StreamingChatCompletionsUpdate,
|
||||||
)
|
)
|
||||||
from azure.core.credentials import (
|
from azure.core.credentials import (
|
||||||
|
AccessToken,
|
||||||
AzureKeyCredential,
|
AzureKeyCredential,
|
||||||
|
TokenCredential,
|
||||||
)
|
)
|
||||||
from azure.core.exceptions import (
|
from azure.core.exceptions import (
|
||||||
HttpResponseError,
|
HttpResponseError,
|
||||||
@@ -50,6 +54,41 @@ except ImportError:
|
|||||||
) from None
|
) from None
|
||||||
|
|
||||||
|
|
||||||
|
class _TokenProviderCredential(TokenCredential):
|
||||||
|
"""Wrapper class to convert an azure_ad_token_provider callable into a TokenCredential.
|
||||||
|
|
||||||
|
This allows users to pass a token provider function (like the one returned by
|
||||||
|
azure.identity.get_bearer_token_provider) to the Azure AI Inference client.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, provider: Callable[..., Any]):
|
||||||
|
"""Initialize with a token provider callable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: A callable that returns an access token. This is typically
|
||||||
|
the result of azure.identity.get_bearer_token_provider().
|
||||||
|
"""
|
||||||
|
self._provider = provider
|
||||||
|
|
||||||
|
def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
|
||||||
|
"""Get an access token from the provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*scopes: The scopes for the token (ignored, as the provider handles this).
|
||||||
|
**kwargs: Additional keyword arguments (ignored).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An AccessToken instance.
|
||||||
|
"""
|
||||||
|
raw = self._provider()
|
||||||
|
|
||||||
|
if isinstance(raw, AccessToken):
|
||||||
|
return raw
|
||||||
|
|
||||||
|
# If it's a bare string, wrap it with a default expiry of 1 hour
|
||||||
|
return AccessToken(str(raw), int(time.time()) + 3600)
|
||||||
|
|
||||||
|
|
||||||
class AzureCompletion(BaseLLM):
|
class AzureCompletion(BaseLLM):
|
||||||
"""Azure AI Inference native completion implementation.
|
"""Azure AI Inference native completion implementation.
|
||||||
|
|
||||||
@@ -73,6 +112,8 @@ class AzureCompletion(BaseLLM):
|
|||||||
stop: list[str] | None = None,
|
stop: list[str] | None = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||||
|
azure_ad_token_provider: Callable[..., Any] | None = None,
|
||||||
|
credential: TokenCredential | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
"""Initialize Azure AI Inference chat completion client.
|
"""Initialize Azure AI Inference chat completion client.
|
||||||
@@ -92,6 +133,13 @@ class AzureCompletion(BaseLLM):
|
|||||||
stop: Stop sequences
|
stop: Stop sequences
|
||||||
stream: Enable streaming responses
|
stream: Enable streaming responses
|
||||||
interceptor: HTTP interceptor (not yet supported for Azure).
|
interceptor: HTTP interceptor (not yet supported for Azure).
|
||||||
|
azure_ad_token_provider: A callable that returns an Azure AD token.
|
||||||
|
This is typically the result of azure.identity.get_bearer_token_provider().
|
||||||
|
Use this for Azure AD token-based authentication instead of API keys.
|
||||||
|
credential: An Azure TokenCredential instance for authentication.
|
||||||
|
This can be any credential from azure.identity (e.g., DefaultAzureCredential,
|
||||||
|
ManagedIdentityCredential). Takes precedence over azure_ad_token_provider
|
||||||
|
and api_key.
|
||||||
**kwargs: Additional parameters
|
**kwargs: Additional parameters
|
||||||
"""
|
"""
|
||||||
if interceptor is not None:
|
if interceptor is not None:
|
||||||
@@ -107,6 +155,7 @@ class AzureCompletion(BaseLLM):
|
|||||||
self.api_key = api_key or os.getenv("AZURE_API_KEY")
|
self.api_key = api_key or os.getenv("AZURE_API_KEY")
|
||||||
self.endpoint = (
|
self.endpoint = (
|
||||||
endpoint
|
endpoint
|
||||||
|
or kwargs.get("base_url")
|
||||||
or os.getenv("AZURE_ENDPOINT")
|
or os.getenv("AZURE_ENDPOINT")
|
||||||
or os.getenv("AZURE_OPENAI_ENDPOINT")
|
or os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||||
or os.getenv("AZURE_API_BASE")
|
or os.getenv("AZURE_API_BASE")
|
||||||
@@ -115,31 +164,45 @@ class AzureCompletion(BaseLLM):
|
|||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.max_retries = max_retries
|
self.max_retries = max_retries
|
||||||
|
|
||||||
if not self.api_key:
|
|
||||||
raise ValueError(
|
|
||||||
"Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
|
|
||||||
)
|
|
||||||
if not self.endpoint:
|
if not self.endpoint:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
|
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Determine the credential to use (priority: credential > azure_ad_token_provider > api_key)
|
||||||
|
chosen_credential: TokenCredential | AzureKeyCredential | None = None
|
||||||
|
|
||||||
|
if credential is not None:
|
||||||
|
chosen_credential = credential
|
||||||
|
elif azure_ad_token_provider is not None:
|
||||||
|
chosen_credential = _TokenProviderCredential(azure_ad_token_provider)
|
||||||
|
elif self.api_key:
|
||||||
|
chosen_credential = AzureKeyCredential(self.api_key)
|
||||||
|
|
||||||
|
if chosen_credential is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Azure authentication is required. Provide one of: "
|
||||||
|
"api_key (or set AZURE_API_KEY environment variable), "
|
||||||
|
"azure_ad_token_provider (callable from azure.identity.get_bearer_token_provider), "
|
||||||
|
"or credential (TokenCredential instance from azure.identity)."
|
||||||
|
)
|
||||||
|
|
||||||
# Validate and potentially fix Azure OpenAI endpoint URL
|
# Validate and potentially fix Azure OpenAI endpoint URL
|
||||||
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
|
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
|
||||||
|
|
||||||
# Build client kwargs
|
# Build client kwargs
|
||||||
client_kwargs = {
|
client_kwargs: dict[str, Any] = {
|
||||||
"endpoint": self.endpoint,
|
"endpoint": self.endpoint,
|
||||||
"credential": AzureKeyCredential(self.api_key),
|
"credential": chosen_credential,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add api_version if specified (primarily for Azure OpenAI endpoints)
|
# Add api_version if specified (primarily for Azure OpenAI endpoints)
|
||||||
if self.api_version:
|
if self.api_version:
|
||||||
client_kwargs["api_version"] = self.api_version
|
client_kwargs["api_version"] = self.api_version
|
||||||
|
|
||||||
self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
|
self.client = ChatCompletionsClient(**client_kwargs)
|
||||||
|
|
||||||
self.async_client = AsyncChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
|
self.async_client = AsyncChatCompletionsClient(**client_kwargs)
|
||||||
|
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
self.frequency_penalty = frequency_penalty
|
self.frequency_penalty = frequency_penalty
|
||||||
|
|||||||
@@ -382,13 +382,13 @@ def test_azure_raises_error_when_endpoint_missing():
|
|||||||
AzureCompletion(model="gpt-4", api_key="test-key")
|
AzureCompletion(model="gpt-4", api_key="test-key")
|
||||||
|
|
||||||
|
|
||||||
def test_azure_raises_error_when_api_key_missing():
|
def test_azure_raises_error_when_no_auth_provided():
|
||||||
"""Test that AzureCompletion raises ValueError when API key is missing"""
|
"""Test that AzureCompletion raises ValueError when no authentication is provided"""
|
||||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||||
|
|
||||||
# Clear environment variables
|
# Clear environment variables
|
||||||
with patch.dict(os.environ, {}, clear=True):
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
with pytest.raises(ValueError, match="Azure API key is required"):
|
with pytest.raises(ValueError, match="Azure authentication is required"):
|
||||||
AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")
|
AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")
|
||||||
|
|
||||||
|
|
||||||
@@ -1112,4 +1112,208 @@ def test_azure_completion_params_preparation_with_drop_params():
|
|||||||
messages = [{"role": "user", "content": "Hello"}]
|
messages = [{"role": "user", "content": "Hello"}]
|
||||||
params = llm._prepare_completion_params(messages)
|
params = llm._prepare_completion_params(messages)
|
||||||
|
|
||||||
assert params.get('stop') == None
|
assert params.get('stop') == None
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_ad_token_provider_authentication():
|
||||||
|
"""
|
||||||
|
Test that AzureCompletion can be initialized with azure_ad_token_provider
|
||||||
|
for Azure AD token-based authentication instead of API keys.
|
||||||
|
"""
|
||||||
|
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||||
|
|
||||||
|
# Mock token provider that returns a string token
|
||||||
|
def mock_token_provider():
|
||||||
|
return "mock-azure-ad-token"
|
||||||
|
|
||||||
|
# Clear environment variables to ensure no API key is used
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
completion = AzureCompletion(
|
||||||
|
model="gpt-4",
|
||||||
|
endpoint="https://test.openai.azure.com",
|
||||||
|
azure_ad_token_provider=mock_token_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"
|
||||||
|
assert completion.api_key is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_ad_token_provider_with_access_token():
|
||||||
|
"""
|
||||||
|
Test that azure_ad_token_provider works when it returns an AccessToken object.
|
||||||
|
"""
|
||||||
|
from crewai.llms.providers.azure.completion import AzureCompletion, _TokenProviderCredential
|
||||||
|
from azure.core.credentials import AccessToken
|
||||||
|
|
||||||
|
# Mock token provider that returns an AccessToken object
|
||||||
|
mock_access_token = AccessToken("mock-token-string", 1234567890)
|
||||||
|
|
||||||
|
def mock_token_provider():
|
||||||
|
return mock_access_token
|
||||||
|
|
||||||
|
# Test the _TokenProviderCredential wrapper
|
||||||
|
credential = _TokenProviderCredential(mock_token_provider)
|
||||||
|
token = credential.get_token("https://cognitiveservices.azure.com/.default")
|
||||||
|
|
||||||
|
assert token.token == "mock-token-string"
|
||||||
|
assert token.expires_on == 1234567890
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_ad_token_provider_with_string_token():
|
||||||
|
"""
|
||||||
|
Test that azure_ad_token_provider works when it returns a plain string token.
|
||||||
|
"""
|
||||||
|
from crewai.llms.providers.azure.completion import _TokenProviderCredential
|
||||||
|
|
||||||
|
# Mock token provider that returns a plain string
|
||||||
|
def mock_token_provider():
|
||||||
|
return "plain-string-token"
|
||||||
|
|
||||||
|
credential = _TokenProviderCredential(mock_token_provider)
|
||||||
|
token = credential.get_token("https://cognitiveservices.azure.com/.default")
|
||||||
|
|
||||||
|
assert token.token == "plain-string-token"
|
||||||
|
# Should have a default expiry time (approximately 1 hour from now)
|
||||||
|
assert token.expires_on > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_credential_authentication():
|
||||||
|
"""
|
||||||
|
Test that AzureCompletion can be initialized with a TokenCredential instance.
|
||||||
|
"""
|
||||||
|
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||||
|
from azure.core.credentials import AccessToken, TokenCredential
|
||||||
|
|
||||||
|
# Create a mock TokenCredential
|
||||||
|
class MockTokenCredential(TokenCredential):
|
||||||
|
def get_token(self, *scopes, **kwargs):
|
||||||
|
return AccessToken("mock-credential-token", 1234567890)
|
||||||
|
|
||||||
|
mock_credential = MockTokenCredential()
|
||||||
|
|
||||||
|
# Clear environment variables to ensure no API key is used
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
completion = AzureCompletion(
|
||||||
|
model="gpt-4",
|
||||||
|
endpoint="https://test.openai.azure.com",
|
||||||
|
credential=mock_credential
|
||||||
|
)
|
||||||
|
|
||||||
|
assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"
|
||||||
|
assert completion.api_key is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_credential_takes_precedence_over_api_key():
|
||||||
|
"""
|
||||||
|
Test that credential parameter takes precedence over api_key when both are provided.
|
||||||
|
"""
|
||||||
|
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||||
|
from azure.core.credentials import AccessToken, TokenCredential
|
||||||
|
|
||||||
|
class MockTokenCredential(TokenCredential):
|
||||||
|
def get_token(self, *scopes, **kwargs):
|
||||||
|
return AccessToken("credential-token", 1234567890)
|
||||||
|
|
||||||
|
mock_credential = MockTokenCredential()
|
||||||
|
|
||||||
|
# Provide both credential and api_key
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
completion = AzureCompletion(
|
||||||
|
model="gpt-4",
|
||||||
|
endpoint="https://test.openai.azure.com",
|
||||||
|
api_key="should-not-be-used",
|
||||||
|
credential=mock_credential
|
||||||
|
)
|
||||||
|
|
||||||
|
# The completion should be created successfully with the credential
|
||||||
|
assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_ad_token_provider_takes_precedence_over_api_key():
|
||||||
|
"""
|
||||||
|
Test that azure_ad_token_provider takes precedence over api_key when both are provided.
|
||||||
|
"""
|
||||||
|
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||||
|
|
||||||
|
def mock_token_provider():
|
||||||
|
return "token-provider-token"
|
||||||
|
|
||||||
|
# Provide both azure_ad_token_provider and api_key
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
completion = AzureCompletion(
|
||||||
|
model="gpt-4",
|
||||||
|
endpoint="https://test.openai.azure.com",
|
||||||
|
api_key="should-not-be-used",
|
||||||
|
azure_ad_token_provider=mock_token_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
# The completion should be created successfully with the token provider
|
||||||
|
assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_credential_takes_precedence_over_token_provider():
|
||||||
|
"""
|
||||||
|
Test that credential takes precedence over azure_ad_token_provider when both are provided.
|
||||||
|
"""
|
||||||
|
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||||
|
from azure.core.credentials import AccessToken, TokenCredential
|
||||||
|
|
||||||
|
class MockTokenCredential(TokenCredential):
|
||||||
|
def get_token(self, *scopes, **kwargs):
|
||||||
|
return AccessToken("credential-token", 1234567890)
|
||||||
|
|
||||||
|
mock_credential = MockTokenCredential()
|
||||||
|
|
||||||
|
def mock_token_provider():
|
||||||
|
return "token-provider-token"
|
||||||
|
|
||||||
|
# Provide both credential and azure_ad_token_provider
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
completion = AzureCompletion(
|
||||||
|
model="gpt-4",
|
||||||
|
endpoint="https://test.openai.azure.com",
|
||||||
|
credential=mock_credential,
|
||||||
|
azure_ad_token_provider=mock_token_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
# The completion should be created successfully with the credential
|
||||||
|
assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_ad_token_provider_via_llm_factory():
|
||||||
|
"""
|
||||||
|
Test that azure_ad_token_provider can be passed through the LLM factory class.
|
||||||
|
"""
|
||||||
|
def mock_token_provider():
|
||||||
|
return "mock-token"
|
||||||
|
|
||||||
|
# Clear environment variables
|
||||||
|
with patch.dict(os.environ, {
|
||||||
|
"AZURE_ENDPOINT": "https://test.openai.azure.com"
|
||||||
|
}, clear=True):
|
||||||
|
llm = LLM(
|
||||||
|
model="azure/gpt-4",
|
||||||
|
azure_ad_token_provider=mock_token_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||||
|
assert isinstance(llm, AzureCompletion)
|
||||||
|
assert llm.api_key is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_azure_base_url_parameter_for_endpoint():
|
||||||
|
"""
|
||||||
|
Test that base_url parameter can be used as an alternative to endpoint.
|
||||||
|
This is useful for users migrating from other providers.
|
||||||
|
"""
|
||||||
|
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||||
|
|
||||||
|
# Clear environment variables
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
completion = AzureCompletion(
|
||||||
|
model="gpt-4",
|
||||||
|
base_url="https://test.openai.azure.com",
|
||||||
|
api_key="test-key"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"
|
||||||
|
|||||||
Reference in New Issue
Block a user