From aded8ef74a11f69233d9fdf36c9846f232ae71d4 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 11 Dec 2025 08:56:46 +0000 Subject: [PATCH] feat: Add Azure AD token authentication support for Azure provider MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit adds support for Azure AD token authentication (Microsoft Entra ID) to the Azure AI Inference native provider, addressing issue #4069. Changes: - Add credential parameter for passing TokenCredential directly - Add azure_ad_token parameter and AZURE_AD_TOKEN env var support - Add use_default_credential flag for DefaultAzureCredential - Add _StaticTokenCredential class for wrapping static tokens - Add _select_credential method with clear priority order - Update error messages to reflect all authentication options - Add comprehensive tests for all new authentication methods Authentication Priority: 1. credential parameter (explicit TokenCredential) 2. azure_ad_token parameter or AZURE_AD_TOKEN env var 3. api_key parameter or AZURE_API_KEY env var 4. use_default_credential=True (DefaultAzureCredential) Fixes #4069 Co-Authored-By: João --- .../crewai/llms/providers/azure/completion.py | 127 +++++++- lib/crewai/tests/llms/azure/test_azure.py | 289 +++++++++++++++++- 2 files changed, 406 insertions(+), 10 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/azure/completion.py b/lib/crewai/src/crewai/llms/providers/azure/completion.py index 687dee9c6..6bd6c9df5 100644 --- a/lib/crewai/src/crewai/llms/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llms/providers/azure/completion.py @@ -3,6 +3,7 @@ from __future__ import annotations import json import logging import os +import time from typing import TYPE_CHECKING, Any, TypedDict from pydantic import BaseModel @@ -17,6 +18,8 @@ from crewai.utilities.types import LLMMessage if TYPE_CHECKING: + from azure.core.credentials import AccessToken, TokenCredential + from crewai.llms.hooks.base import BaseInterceptor @@ -51,6 +54,39 @@ except ImportError: ) from None +class _StaticTokenCredential: + """A simple TokenCredential implementation for static Azure AD tokens. + + This class wraps a static token string and provides it as a TokenCredential + that can be used with Azure SDK clients. The token is assumed to be valid + and the user is responsible for token rotation. + """ + + def __init__(self, token: str) -> None: + """Initialize with a static token. + + Args: + token: The Azure AD bearer token string. + """ + self._token = token + + def get_token( + self, *scopes: str, **kwargs: Any + ) -> AccessToken: + """Get the static token as an AccessToken. + + Args: + *scopes: Token scopes (ignored for static tokens). + **kwargs: Additional arguments (ignored). + + Returns: + AccessToken with the static token and a far-future expiry. + """ + from azure.core.credentials import AccessToken + + return AccessToken(self._token, int(time.time()) + 3600) + + class AzureCompletionParams(TypedDict, total=False): """Type definition for Azure chat completion parameters.""" @@ -92,6 +128,9 @@ class AzureCompletion(BaseLLM): stop: list[str] | None = None, stream: bool = False, interceptor: BaseInterceptor[Any, Any] | None = None, + credential: TokenCredential | None = None, + azure_ad_token: str | None = None, + use_default_credential: bool = False, **kwargs: Any, ): """Initialize Azure AI Inference chat completion client. @@ -111,7 +150,36 @@ class AzureCompletion(BaseLLM): stop: Stop sequences stream: Enable streaming responses interceptor: HTTP interceptor (not yet supported for Azure). + credential: Azure TokenCredential for Azure AD authentication (e.g., + DefaultAzureCredential, ManagedIdentityCredential). Takes precedence + over other authentication methods. + azure_ad_token: Static Azure AD token string (defaults to AZURE_AD_TOKEN + env var). Use this for scenarios where you have a pre-fetched token. + use_default_credential: If True, automatically use DefaultAzureCredential + for Azure AD authentication. Requires azure-identity package. **kwargs: Additional parameters + + Authentication Priority: + 1. credential parameter (explicit TokenCredential) + 2. azure_ad_token parameter or AZURE_AD_TOKEN env var + 3. api_key parameter or AZURE_API_KEY env var + 4. use_default_credential=True (DefaultAzureCredential) + + Example: + # Using API key (existing behavior) + llm = LLM(model="azure/gpt-4", api_key="...", endpoint="...") + + # Using Azure AD token from environment + os.environ["AZURE_AD_TOKEN"] = token_provider() + llm = LLM(model="azure/gpt-4", endpoint="...") + + # Using DefaultAzureCredential (Managed Identity, Azure CLI, etc.) + llm = LLM(model="azure/gpt-4", endpoint="...", use_default_credential=True) + + # Using explicit TokenCredential + from azure.identity import ManagedIdentityCredential + llm = LLM(model="azure/gpt-4", endpoint="...", + credential=ManagedIdentityCredential()) """ if interceptor is not None: raise NotImplementedError( @@ -124,6 +192,9 @@ class AzureCompletion(BaseLLM): ) self.api_key = api_key or os.getenv("AZURE_API_KEY") + self.azure_ad_token = azure_ad_token or os.getenv("AZURE_AD_TOKEN") + self._explicit_credential = credential + self.use_default_credential = use_default_credential self.endpoint = ( endpoint or os.getenv("AZURE_ENDPOINT") @@ -134,10 +205,6 @@ class AzureCompletion(BaseLLM): self.timeout = timeout self.max_retries = max_retries - if not self.api_key: - raise ValueError( - "Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter." - ) if not self.endpoint: raise ValueError( "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter." @@ -146,19 +213,22 @@ class AzureCompletion(BaseLLM): # Validate and potentially fix Azure OpenAI endpoint URL self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model) + # Select credential based on priority + selected_credential = self._select_credential() + # Build client kwargs - client_kwargs = { + client_kwargs: dict[str, Any] = { "endpoint": self.endpoint, - "credential": AzureKeyCredential(self.api_key), + "credential": selected_credential, } # Add api_version if specified (primarily for Azure OpenAI endpoints) if self.api_version: client_kwargs["api_version"] = self.api_version - self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type] + self.client = ChatCompletionsClient(**client_kwargs) - self.async_client = AsyncChatCompletionsClient(**client_kwargs) # type: ignore[arg-type] + self.async_client = AsyncChatCompletionsClient(**client_kwargs) self.top_p = top_p self.frequency_penalty = frequency_penalty @@ -175,6 +245,47 @@ class AzureCompletion(BaseLLM): and "/openai/deployments/" in self.endpoint ) + def _select_credential(self) -> AzureKeyCredential | TokenCredential: + """Select the appropriate credential based on configuration priority. + + Priority order: + 1. Explicit credential parameter (TokenCredential) + 2. azure_ad_token parameter or AZURE_AD_TOKEN env var + 3. api_key parameter or AZURE_API_KEY env var + 4. use_default_credential=True (DefaultAzureCredential) + + Returns: + The selected credential for Azure authentication. + + Raises: + ValueError: If no valid credentials are configured. + """ + if self._explicit_credential is not None: + return self._explicit_credential + + if self.azure_ad_token: + return _StaticTokenCredential(self.azure_ad_token) + + if self.api_key: + return AzureKeyCredential(self.api_key) + + if self.use_default_credential: + try: + from azure.identity import DefaultAzureCredential + + return DefaultAzureCredential() + except ImportError: + raise ImportError( + "azure-identity package is required for use_default_credential=True. " + 'Install it with: uv add "azure-identity"' + ) from None + + raise ValueError( + "Azure credentials are required. Provide one of: " + "api_key / AZURE_API_KEY, azure_ad_token / AZURE_AD_TOKEN, " + "a TokenCredential via 'credential', or set use_default_credential=True." + ) + @staticmethod def _validate_and_fix_endpoint(endpoint: str, model: str) -> str: """Validate and fix Azure endpoint URL format. diff --git a/lib/crewai/tests/llms/azure/test_azure.py b/lib/crewai/tests/llms/azure/test_azure.py index 6c6ee5271..c69d78554 100644 --- a/lib/crewai/tests/llms/azure/test_azure.py +++ b/lib/crewai/tests/llms/azure/test_azure.py @@ -389,12 +389,12 @@ def test_azure_raises_error_when_endpoint_missing(): def test_azure_raises_error_when_api_key_missing(): - """Test that AzureCompletion raises ValueError when API key is missing""" + """Test that AzureCompletion raises ValueError when no credentials are provided""" from crewai.llms.providers.azure.completion import AzureCompletion # Clear environment variables with patch.dict(os.environ, {}, clear=True): - with pytest.raises(ValueError, match="Azure API key is required"): + with pytest.raises(ValueError, match="Azure credentials are required"): AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com") @@ -1127,3 +1127,288 @@ def test_azure_streaming_returns_usage_metrics(): assert result.token_usage.prompt_tokens > 0 assert result.token_usage.completion_tokens > 0 assert result.token_usage.successful_requests >= 1 + + +def test_azure_ad_token_authentication(): + """ + Test that Azure AD token authentication works via AZURE_AD_TOKEN env var. + """ + from crewai.llms.providers.azure.completion import AzureCompletion, _StaticTokenCredential + + with patch.dict(os.environ, { + "AZURE_AD_TOKEN": "test-ad-token", + "AZURE_ENDPOINT": "https://test.openai.azure.com" + }, clear=True): + llm = LLM(model="azure/gpt-4") + + assert isinstance(llm, AzureCompletion) + assert llm.azure_ad_token == "test-ad-token" + assert llm.api_key is None + + +def test_azure_ad_token_parameter(): + """ + Test that azure_ad_token parameter works for Azure AD authentication. + """ + from crewai.llms.providers.azure.completion import AzureCompletion + + llm = LLM( + model="azure/gpt-4", + azure_ad_token="my-ad-token", + endpoint="https://test.openai.azure.com" + ) + + assert isinstance(llm, AzureCompletion) + assert llm.azure_ad_token == "my-ad-token" + + +def test_azure_credential_parameter(): + """ + Test that credential parameter works for passing TokenCredential directly. + """ + from crewai.llms.providers.azure.completion import AzureCompletion + + class MockTokenCredential: + def get_token(self, *scopes, **kwargs): + from azure.core.credentials import AccessToken + return AccessToken("mock-token", 9999999999) + + mock_credential = MockTokenCredential() + + llm = LLM( + model="azure/gpt-4", + credential=mock_credential, + endpoint="https://test.openai.azure.com" + ) + + assert isinstance(llm, AzureCompletion) + assert llm._explicit_credential is mock_credential + + +def test_azure_use_default_credential(): + """ + Test that use_default_credential=True uses DefaultAzureCredential. + """ + from crewai.llms.providers.azure.completion import AzureCompletion + + try: + from azure.identity import DefaultAzureCredential + azure_identity_available = True + except ImportError: + azure_identity_available = False + + if azure_identity_available: + with patch('azure.identity.DefaultAzureCredential') as mock_default_cred: + mock_default_cred.return_value = MagicMock() + + with patch.dict(os.environ, { + "AZURE_ENDPOINT": "https://test.openai.azure.com" + }, clear=True): + llm = LLM( + model="azure/gpt-4", + use_default_credential=True + ) + + assert isinstance(llm, AzureCompletion) + assert llm.use_default_credential is True + mock_default_cred.assert_called_once() + else: + with patch.dict(os.environ, { + "AZURE_ENDPOINT": "https://test.openai.azure.com" + }, clear=True): + with pytest.raises(ImportError, match="azure-identity package is required"): + LLM( + model="azure/gpt-4", + use_default_credential=True + ) + + +def test_azure_credential_priority_explicit_credential_first(): + """ + Test that explicit credential takes priority over other auth methods. + """ + from crewai.llms.providers.azure.completion import AzureCompletion + + class MockTokenCredential: + def get_token(self, *scopes, **kwargs): + from azure.core.credentials import AccessToken + return AccessToken("mock-token", 9999999999) + + mock_credential = MockTokenCredential() + + with patch.dict(os.environ, { + "AZURE_API_KEY": "test-key", + "AZURE_AD_TOKEN": "test-ad-token", + "AZURE_ENDPOINT": "https://test.openai.azure.com" + }): + llm = LLM( + model="azure/gpt-4", + credential=mock_credential, + api_key="another-key", + azure_ad_token="another-token" + ) + + assert isinstance(llm, AzureCompletion) + assert llm._explicit_credential is mock_credential + + +def test_azure_credential_priority_ad_token_over_api_key(): + """ + Test that azure_ad_token takes priority over api_key. + """ + from crewai.llms.providers.azure.completion import AzureCompletion, _StaticTokenCredential + + with patch.dict(os.environ, { + "AZURE_ENDPOINT": "https://test.openai.azure.com" + }, clear=True): + llm = LLM( + model="azure/gpt-4", + azure_ad_token="my-ad-token", + api_key="my-api-key" + ) + + assert isinstance(llm, AzureCompletion) + assert llm.azure_ad_token == "my-ad-token" + assert llm.api_key == "my-api-key" + + +def test_azure_raises_error_when_no_credentials(): + """ + Test that AzureCompletion raises ValueError when no credentials are provided. + """ + from crewai.llms.providers.azure.completion import AzureCompletion + + with patch.dict(os.environ, { + "AZURE_ENDPOINT": "https://test.openai.azure.com" + }, clear=True): + with pytest.raises(ValueError, match="Azure credentials are required"): + AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com") + + +def test_azure_static_token_credential(): + """ + Test that _StaticTokenCredential properly wraps a static token. + """ + from crewai.llms.providers.azure.completion import _StaticTokenCredential + from azure.core.credentials import AccessToken + + token = "my-static-token" + credential = _StaticTokenCredential(token) + + access_token = credential.get_token("https://cognitiveservices.azure.com/.default") + + assert isinstance(access_token, AccessToken) + assert access_token.token == token + assert access_token.expires_on > 0 + + +def test_azure_ad_token_env_var_used_when_no_api_key(): + """ + Test that AZURE_AD_TOKEN env var is used when AZURE_API_KEY is not set. + This reproduces the scenario from GitHub issue #4069. + """ + from crewai.llms.providers.azure.completion import AzureCompletion + + with patch.dict(os.environ, { + "AZURE_AD_TOKEN": "token-from-provider", + "AZURE_ENDPOINT": "https://my-endpoint.openai.azure.com" + }, clear=True): + llm = LLM( + model="azure/gpt-4o-mini", + api_version="2024-02-01" + ) + + assert isinstance(llm, AzureCompletion) + assert llm.azure_ad_token == "token-from-provider" + assert llm.api_key is None + + +def test_azure_backward_compatibility_api_key(): + """ + Test that existing API key authentication still works (backward compatibility). + """ + from crewai.llms.providers.azure.completion import AzureCompletion + from azure.core.credentials import AzureKeyCredential + + with patch.dict(os.environ, { + "AZURE_API_KEY": "test-api-key", + "AZURE_ENDPOINT": "https://test.openai.azure.com" + }, clear=True): + llm = LLM(model="azure/gpt-4") + + assert isinstance(llm, AzureCompletion) + assert llm.api_key == "test-api-key" + assert llm.azure_ad_token is None + + +def test_azure_select_credential_returns_correct_type(): + """ + Test that _select_credential returns the correct credential type based on config. + """ + from crewai.llms.providers.azure.completion import AzureCompletion, _StaticTokenCredential + from azure.core.credentials import AzureKeyCredential + + with patch.dict(os.environ, { + "AZURE_ENDPOINT": "https://test.openai.azure.com" + }, clear=True): + llm_api_key = AzureCompletion( + model="gpt-4", + api_key="test-key", + endpoint="https://test.openai.azure.com" + ) + credential = llm_api_key._select_credential() + assert isinstance(credential, AzureKeyCredential) + + llm_ad_token = AzureCompletion( + model="gpt-4", + azure_ad_token="test-ad-token", + endpoint="https://test.openai.azure.com" + ) + credential = llm_ad_token._select_credential() + assert isinstance(credential, _StaticTokenCredential) + + +def test_azure_use_default_credential_import_error(): + """ + Test that use_default_credential raises ImportError when azure-identity is not available. + """ + from crewai.llms.providers.azure.completion import AzureCompletion + import builtins + + original_import = builtins.__import__ + + def mock_import(name, *args, **kwargs): + if name == 'azure.identity': + raise ImportError("No module named 'azure.identity'") + return original_import(name, *args, **kwargs) + + with patch.dict(os.environ, { + "AZURE_ENDPOINT": "https://test.openai.azure.com" + }, clear=True): + with patch.object(builtins, '__import__', side_effect=mock_import): + with pytest.raises(ImportError, match="azure-identity package is required"): + AzureCompletion( + model="gpt-4", + endpoint="https://test.openai.azure.com", + use_default_credential=True + ) + + +def test_azure_improved_error_message_no_credentials(): + """ + Test that the error message when no credentials are provided is helpful. + """ + from crewai.llms.providers.azure.completion import AzureCompletion + + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError) as excinfo: + AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com") + + error_message = str(excinfo.value) + assert "Azure credentials are required" in error_message + assert "api_key" in error_message + assert "AZURE_API_KEY" in error_message + assert "azure_ad_token" in error_message + assert "AZURE_AD_TOKEN" in error_message + assert "credential" in error_message + assert "use_default_credential" in error_message