diff --git a/lib/crewai/src/crewai/llms/providers/azure/completion.py b/lib/crewai/src/crewai/llms/providers/azure/completion.py
index 7f3db08a8..e6d6c7ef3 100644
--- a/lib/crewai/src/crewai/llms/providers/azure/completion.py
+++ b/lib/crewai/src/crewai/llms/providers/azure/completion.py
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
+from collections.abc import Callable
 import json
 import logging
 import os
+import time
 from typing import TYPE_CHECKING, Any
 
 from pydantic import BaseModel
@@ -35,7 +37,9 @@ try:
         StreamingChatCompletionsUpdate,
     )
     from azure.core.credentials import (
+        AccessToken,
         AzureKeyCredential,
+        TokenCredential,
     )
     from azure.core.exceptions import (
         HttpResponseError,
@@ -50,6 +54,41 @@ except ImportError:
     ) from None
 
 
+class _TokenProviderCredential(TokenCredential):
+    """Wrapper class to convert an azure_ad_token_provider callable into a TokenCredential.
+
+    This allows users to pass a token provider function (like the one returned by
+    azure.identity.get_bearer_token_provider) to the Azure AI Inference client.
+    """
+
+    def __init__(self, provider: Callable[..., Any]):
+        """Initialize with a token provider callable.
+
+        Args:
+            provider: A callable that returns an access token. This is typically
+                the result of azure.identity.get_bearer_token_provider().
+        """
+        self._provider = provider
+
+    def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
+        """Get an access token from the provider.
+
+        Args:
+            *scopes: The scopes for the token (ignored, as the provider handles this).
+            **kwargs: Additional keyword arguments (ignored).
+
+        Returns:
+            An AccessToken instance.
+        """
+        raw = self._provider()
+
+        if isinstance(raw, AccessToken):
+            return raw
+
+        # If it's a bare string, wrap it with a default expiry of 1 hour
+        return AccessToken(str(raw), int(time.time()) + 3600)
+
+
 class AzureCompletion(BaseLLM):
     """Azure AI Inference native completion implementation.
 
@@ -73,6 +112,8 @@ class AzureCompletion(BaseLLM):
         stop: list[str] | None = None,
         stream: bool = False,
         interceptor: BaseInterceptor[Any, Any] | None = None,
+        azure_ad_token_provider: Callable[..., Any] | None = None,
+        credential: TokenCredential | None = None,
         **kwargs: Any,
     ):
         """Initialize Azure AI Inference chat completion client.
@@ -92,6 +133,13 @@ class AzureCompletion(BaseLLM):
             stop: Stop sequences
             stream: Enable streaming responses
             interceptor: HTTP interceptor (not yet supported for Azure).
+            azure_ad_token_provider: A callable that returns an Azure AD token.
+                This is typically the result of azure.identity.get_bearer_token_provider().
+                Use this for Azure AD token-based authentication instead of API keys.
+            credential: An Azure TokenCredential instance for authentication.
+                This can be any credential from azure.identity (e.g., DefaultAzureCredential,
+                ManagedIdentityCredential). Takes precedence over azure_ad_token_provider
+                and api_key.
             **kwargs: Additional parameters
         """
         if interceptor is not None:
@@ -107,6 +155,7 @@ class AzureCompletion(BaseLLM):
         self.api_key = api_key or os.getenv("AZURE_API_KEY")
         self.endpoint = (
             endpoint
+            or kwargs.get("base_url")
             or os.getenv("AZURE_ENDPOINT")
             or os.getenv("AZURE_OPENAI_ENDPOINT")
             or os.getenv("AZURE_API_BASE")
@@ -115,31 +164,45 @@ class AzureCompletion(BaseLLM):
         self.timeout = timeout
         self.max_retries = max_retries
 
-        if not self.api_key:
-            raise ValueError(
-                "Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
-            )
         if not self.endpoint:
             raise ValueError(
                 "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
             )
 
+        # Determine the credential to use (priority: credential > azure_ad_token_provider > api_key)
+        chosen_credential: TokenCredential | AzureKeyCredential | None = None
+
+        if credential is not None:
+            chosen_credential = credential
+        elif azure_ad_token_provider is not None:
+            chosen_credential = _TokenProviderCredential(azure_ad_token_provider)
+        elif self.api_key:
+            chosen_credential = AzureKeyCredential(self.api_key)
+
+        if chosen_credential is None:
+            raise ValueError(
+                "Azure authentication is required. Provide one of: "
+                "api_key (or set AZURE_API_KEY environment variable), "
+                "azure_ad_token_provider (callable from azure.identity.get_bearer_token_provider), "
+                "or credential (TokenCredential instance from azure.identity)."
+            )
+
         # Validate and potentially fix Azure OpenAI endpoint URL
         self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
 
         # Build client kwargs
-        client_kwargs = {
+        client_kwargs: dict[str, Any] = {
             "endpoint": self.endpoint,
-            "credential": AzureKeyCredential(self.api_key),
+            "credential": chosen_credential,
         }
 
         # Add api_version if specified (primarily for Azure OpenAI endpoints)
         if self.api_version:
             client_kwargs["api_version"] = self.api_version
 
-        self.client = ChatCompletionsClient(**client_kwargs)  # type: ignore[arg-type]
+        self.client = ChatCompletionsClient(**client_kwargs)
 
-        self.async_client = AsyncChatCompletionsClient(**client_kwargs)  # type: ignore[arg-type]
+        self.async_client = AsyncChatCompletionsClient(**client_kwargs)
 
         self.top_p = top_p
         self.frequency_penalty = frequency_penalty
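
# ---------------------------------------------------------------------------
# Usage sketch (not part of the patch): one way the new azure_ad_token_provider
# parameter could be wired through crewai's LLM factory with azure.identity.
# Assumes AZURE_ENDPOINT is set in the environment (as in the factory test
# below); the token scope shown is the usual Cognitive Services scope and is
# an illustrative assumption, not something this patch prescribes.
# ---------------------------------------------------------------------------
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

from crewai import LLM

# get_bearer_token_provider() returns a zero-argument callable that yields a
# bearer token string; _TokenProviderCredential (added above) wraps it into the
# TokenCredential interface expected by the Azure AI Inference client.
token_provider = get_bearer_token_provider(
    DefaultAzureCredential(),
    "https://cognitiveservices.azure.com/.default",
)

llm = LLM(model="azure/gpt-4", azure_ad_token_provider=token_provider)
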
+ """ + from crewai.llms.providers.azure.completion import AzureCompletion + + # Mock token provider that returns a string token + def mock_token_provider(): + return "mock-azure-ad-token" + + # Clear environment variables to ensure no API key is used + with patch.dict(os.environ, {}, clear=True): + completion = AzureCompletion( + model="gpt-4", + endpoint="https://test.openai.azure.com", + azure_ad_token_provider=mock_token_provider + ) + + assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4" + assert completion.api_key is None + + +def test_azure_ad_token_provider_with_access_token(): + """ + Test that azure_ad_token_provider works when it returns an AccessToken object. + """ + from crewai.llms.providers.azure.completion import AzureCompletion, _TokenProviderCredential + from azure.core.credentials import AccessToken + + # Mock token provider that returns an AccessToken object + mock_access_token = AccessToken("mock-token-string", 1234567890) + + def mock_token_provider(): + return mock_access_token + + # Test the _TokenProviderCredential wrapper + credential = _TokenProviderCredential(mock_token_provider) + token = credential.get_token("https://cognitiveservices.azure.com/.default") + + assert token.token == "mock-token-string" + assert token.expires_on == 1234567890 + + +def test_azure_ad_token_provider_with_string_token(): + """ + Test that azure_ad_token_provider works when it returns a plain string token. + """ + from crewai.llms.providers.azure.completion import _TokenProviderCredential + + # Mock token provider that returns a plain string + def mock_token_provider(): + return "plain-string-token" + + credential = _TokenProviderCredential(mock_token_provider) + token = credential.get_token("https://cognitiveservices.azure.com/.default") + + assert token.token == "plain-string-token" + # Should have a default expiry time (approximately 1 hour from now) + assert token.expires_on > 0 + + +def test_azure_credential_authentication(): + """ + Test that AzureCompletion can be initialized with a TokenCredential instance. + """ + from crewai.llms.providers.azure.completion import AzureCompletion + from azure.core.credentials import AccessToken, TokenCredential + + # Create a mock TokenCredential + class MockTokenCredential(TokenCredential): + def get_token(self, *scopes, **kwargs): + return AccessToken("mock-credential-token", 1234567890) + + mock_credential = MockTokenCredential() + + # Clear environment variables to ensure no API key is used + with patch.dict(os.environ, {}, clear=True): + completion = AzureCompletion( + model="gpt-4", + endpoint="https://test.openai.azure.com", + credential=mock_credential + ) + + assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4" + assert completion.api_key is None + + +def test_azure_credential_takes_precedence_over_api_key(): + """ + Test that credential parameter takes precedence over api_key when both are provided. 
+ """ + from crewai.llms.providers.azure.completion import AzureCompletion + from azure.core.credentials import AccessToken, TokenCredential + + class MockTokenCredential(TokenCredential): + def get_token(self, *scopes, **kwargs): + return AccessToken("credential-token", 1234567890) + + mock_credential = MockTokenCredential() + + # Provide both credential and api_key + with patch.dict(os.environ, {}, clear=True): + completion = AzureCompletion( + model="gpt-4", + endpoint="https://test.openai.azure.com", + api_key="should-not-be-used", + credential=mock_credential + ) + + # The completion should be created successfully with the credential + assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4" + + +def test_azure_ad_token_provider_takes_precedence_over_api_key(): + """ + Test that azure_ad_token_provider takes precedence over api_key when both are provided. + """ + from crewai.llms.providers.azure.completion import AzureCompletion + + def mock_token_provider(): + return "token-provider-token" + + # Provide both azure_ad_token_provider and api_key + with patch.dict(os.environ, {}, clear=True): + completion = AzureCompletion( + model="gpt-4", + endpoint="https://test.openai.azure.com", + api_key="should-not-be-used", + azure_ad_token_provider=mock_token_provider + ) + + # The completion should be created successfully with the token provider + assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4" + + +def test_azure_credential_takes_precedence_over_token_provider(): + """ + Test that credential takes precedence over azure_ad_token_provider when both are provided. + """ + from crewai.llms.providers.azure.completion import AzureCompletion + from azure.core.credentials import AccessToken, TokenCredential + + class MockTokenCredential(TokenCredential): + def get_token(self, *scopes, **kwargs): + return AccessToken("credential-token", 1234567890) + + mock_credential = MockTokenCredential() + + def mock_token_provider(): + return "token-provider-token" + + # Provide both credential and azure_ad_token_provider + with patch.dict(os.environ, {}, clear=True): + completion = AzureCompletion( + model="gpt-4", + endpoint="https://test.openai.azure.com", + credential=mock_credential, + azure_ad_token_provider=mock_token_provider + ) + + # The completion should be created successfully with the credential + assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4" + + +def test_azure_ad_token_provider_via_llm_factory(): + """ + Test that azure_ad_token_provider can be passed through the LLM factory class. + """ + def mock_token_provider(): + return "mock-token" + + # Clear environment variables + with patch.dict(os.environ, { + "AZURE_ENDPOINT": "https://test.openai.azure.com" + }, clear=True): + llm = LLM( + model="azure/gpt-4", + azure_ad_token_provider=mock_token_provider + ) + + from crewai.llms.providers.azure.completion import AzureCompletion + assert isinstance(llm, AzureCompletion) + assert llm.api_key is None + + +def test_azure_base_url_parameter_for_endpoint(): + """ + Test that base_url parameter can be used as an alternative to endpoint. + This is useful for users migrating from other providers. 
+ """ + from crewai.llms.providers.azure.completion import AzureCompletion + + # Clear environment variables + with patch.dict(os.environ, {}, clear=True): + completion = AzureCompletion( + model="gpt-4", + base_url="https://test.openai.azure.com", + api_key="test-key" + ) + + assert completion.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"