mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 15:52:34 +00:00
feat: Add Azure AD token authentication support for Azure provider
This commit adds support for Azure AD token authentication (Microsoft Entra ID) to the Azure AI Inference native provider, addressing issue #4069. Changes: - Add credential parameter for passing TokenCredential directly - Add azure_ad_token parameter and AZURE_AD_TOKEN env var support - Add use_default_credential flag for DefaultAzureCredential - Add _StaticTokenCredential class for wrapping static tokens - Add _select_credential method with clear priority order - Update error messages to reflect all authentication options - Add comprehensive tests for all new authentication methods Authentication Priority: 1. credential parameter (explicit TokenCredential) 2. azure_ad_token parameter or AZURE_AD_TOKEN env var 3. api_key parameter or AZURE_API_KEY env var 4. use_default_credential=True (DefaultAzureCredential) Fixes #4069 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -389,12 +389,12 @@ def test_azure_raises_error_when_endpoint_missing():
|
||||
|
||||
|
||||
def test_azure_raises_error_when_api_key_missing():
|
||||
"""Test that AzureCompletion raises ValueError when API key is missing"""
|
||||
"""Test that AzureCompletion raises ValueError when no credentials are provided"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
# Clear environment variables
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with pytest.raises(ValueError, match="Azure API key is required"):
|
||||
with pytest.raises(ValueError, match="Azure credentials are required"):
|
||||
AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")
|
||||
|
||||
|
||||
@@ -1127,3 +1127,288 @@ def test_azure_streaming_returns_usage_metrics():
|
||||
assert result.token_usage.prompt_tokens > 0
|
||||
assert result.token_usage.completion_tokens > 0
|
||||
assert result.token_usage.successful_requests >= 1
|
||||
|
||||
|
||||
def test_azure_ad_token_authentication():
|
||||
"""
|
||||
Test that Azure AD token authentication works via AZURE_AD_TOKEN env var.
|
||||
"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion, _StaticTokenCredential
|
||||
|
||||
with patch.dict(os.environ, {
|
||||
"AZURE_AD_TOKEN": "test-ad-token",
|
||||
"AZURE_ENDPOINT": "https://test.openai.azure.com"
|
||||
}, clear=True):
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
assert isinstance(llm, AzureCompletion)
|
||||
assert llm.azure_ad_token == "test-ad-token"
|
||||
assert llm.api_key is None
|
||||
|
||||
|
||||
def test_azure_ad_token_parameter():
|
||||
"""
|
||||
Test that azure_ad_token parameter works for Azure AD authentication.
|
||||
"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
llm = LLM(
|
||||
model="azure/gpt-4",
|
||||
azure_ad_token="my-ad-token",
|
||||
endpoint="https://test.openai.azure.com"
|
||||
)
|
||||
|
||||
assert isinstance(llm, AzureCompletion)
|
||||
assert llm.azure_ad_token == "my-ad-token"
|
||||
|
||||
|
||||
def test_azure_credential_parameter():
|
||||
"""
|
||||
Test that credential parameter works for passing TokenCredential directly.
|
||||
"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
class MockTokenCredential:
|
||||
def get_token(self, *scopes, **kwargs):
|
||||
from azure.core.credentials import AccessToken
|
||||
return AccessToken("mock-token", 9999999999)
|
||||
|
||||
mock_credential = MockTokenCredential()
|
||||
|
||||
llm = LLM(
|
||||
model="azure/gpt-4",
|
||||
credential=mock_credential,
|
||||
endpoint="https://test.openai.azure.com"
|
||||
)
|
||||
|
||||
assert isinstance(llm, AzureCompletion)
|
||||
assert llm._explicit_credential is mock_credential
|
||||
|
||||
|
||||
def test_azure_use_default_credential():
|
||||
"""
|
||||
Test that use_default_credential=True uses DefaultAzureCredential.
|
||||
"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
try:
|
||||
from azure.identity import DefaultAzureCredential
|
||||
azure_identity_available = True
|
||||
except ImportError:
|
||||
azure_identity_available = False
|
||||
|
||||
if azure_identity_available:
|
||||
with patch('azure.identity.DefaultAzureCredential') as mock_default_cred:
|
||||
mock_default_cred.return_value = MagicMock()
|
||||
|
||||
with patch.dict(os.environ, {
|
||||
"AZURE_ENDPOINT": "https://test.openai.azure.com"
|
||||
}, clear=True):
|
||||
llm = LLM(
|
||||
model="azure/gpt-4",
|
||||
use_default_credential=True
|
||||
)
|
||||
|
||||
assert isinstance(llm, AzureCompletion)
|
||||
assert llm.use_default_credential is True
|
||||
mock_default_cred.assert_called_once()
|
||||
else:
|
||||
with patch.dict(os.environ, {
|
||||
"AZURE_ENDPOINT": "https://test.openai.azure.com"
|
||||
}, clear=True):
|
||||
with pytest.raises(ImportError, match="azure-identity package is required"):
|
||||
LLM(
|
||||
model="azure/gpt-4",
|
||||
use_default_credential=True
|
||||
)
|
||||
|
||||
|
||||
def test_azure_credential_priority_explicit_credential_first():
|
||||
"""
|
||||
Test that explicit credential takes priority over other auth methods.
|
||||
"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
class MockTokenCredential:
|
||||
def get_token(self, *scopes, **kwargs):
|
||||
from azure.core.credentials import AccessToken
|
||||
return AccessToken("mock-token", 9999999999)
|
||||
|
||||
mock_credential = MockTokenCredential()
|
||||
|
||||
with patch.dict(os.environ, {
|
||||
"AZURE_API_KEY": "test-key",
|
||||
"AZURE_AD_TOKEN": "test-ad-token",
|
||||
"AZURE_ENDPOINT": "https://test.openai.azure.com"
|
||||
}):
|
||||
llm = LLM(
|
||||
model="azure/gpt-4",
|
||||
credential=mock_credential,
|
||||
api_key="another-key",
|
||||
azure_ad_token="another-token"
|
||||
)
|
||||
|
||||
assert isinstance(llm, AzureCompletion)
|
||||
assert llm._explicit_credential is mock_credential
|
||||
|
||||
|
||||
def test_azure_credential_priority_ad_token_over_api_key():
|
||||
"""
|
||||
Test that azure_ad_token takes priority over api_key.
|
||||
"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion, _StaticTokenCredential
|
||||
|
||||
with patch.dict(os.environ, {
|
||||
"AZURE_ENDPOINT": "https://test.openai.azure.com"
|
||||
}, clear=True):
|
||||
llm = LLM(
|
||||
model="azure/gpt-4",
|
||||
azure_ad_token="my-ad-token",
|
||||
api_key="my-api-key"
|
||||
)
|
||||
|
||||
assert isinstance(llm, AzureCompletion)
|
||||
assert llm.azure_ad_token == "my-ad-token"
|
||||
assert llm.api_key == "my-api-key"
|
||||
|
||||
|
||||
def test_azure_raises_error_when_no_credentials():
|
||||
"""
|
||||
Test that AzureCompletion raises ValueError when no credentials are provided.
|
||||
"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
with patch.dict(os.environ, {
|
||||
"AZURE_ENDPOINT": "https://test.openai.azure.com"
|
||||
}, clear=True):
|
||||
with pytest.raises(ValueError, match="Azure credentials are required"):
|
||||
AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")
|
||||
|
||||
|
||||
def test_azure_static_token_credential():
|
||||
"""
|
||||
Test that _StaticTokenCredential properly wraps a static token.
|
||||
"""
|
||||
from crewai.llms.providers.azure.completion import _StaticTokenCredential
|
||||
from azure.core.credentials import AccessToken
|
||||
|
||||
token = "my-static-token"
|
||||
credential = _StaticTokenCredential(token)
|
||||
|
||||
access_token = credential.get_token("https://cognitiveservices.azure.com/.default")
|
||||
|
||||
assert isinstance(access_token, AccessToken)
|
||||
assert access_token.token == token
|
||||
assert access_token.expires_on > 0
|
||||
|
||||
|
||||
def test_azure_ad_token_env_var_used_when_no_api_key():
|
||||
"""
|
||||
Test that AZURE_AD_TOKEN env var is used when AZURE_API_KEY is not set.
|
||||
This reproduces the scenario from GitHub issue #4069.
|
||||
"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
with patch.dict(os.environ, {
|
||||
"AZURE_AD_TOKEN": "token-from-provider",
|
||||
"AZURE_ENDPOINT": "https://my-endpoint.openai.azure.com"
|
||||
}, clear=True):
|
||||
llm = LLM(
|
||||
model="azure/gpt-4o-mini",
|
||||
api_version="2024-02-01"
|
||||
)
|
||||
|
||||
assert isinstance(llm, AzureCompletion)
|
||||
assert llm.azure_ad_token == "token-from-provider"
|
||||
assert llm.api_key is None
|
||||
|
||||
|
||||
def test_azure_backward_compatibility_api_key():
|
||||
"""
|
||||
Test that existing API key authentication still works (backward compatibility).
|
||||
"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
|
||||
with patch.dict(os.environ, {
|
||||
"AZURE_API_KEY": "test-api-key",
|
||||
"AZURE_ENDPOINT": "https://test.openai.azure.com"
|
||||
}, clear=True):
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
assert isinstance(llm, AzureCompletion)
|
||||
assert llm.api_key == "test-api-key"
|
||||
assert llm.azure_ad_token is None
|
||||
|
||||
|
||||
def test_azure_select_credential_returns_correct_type():
|
||||
"""
|
||||
Test that _select_credential returns the correct credential type based on config.
|
||||
"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion, _StaticTokenCredential
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
|
||||
with patch.dict(os.environ, {
|
||||
"AZURE_ENDPOINT": "https://test.openai.azure.com"
|
||||
}, clear=True):
|
||||
llm_api_key = AzureCompletion(
|
||||
model="gpt-4",
|
||||
api_key="test-key",
|
||||
endpoint="https://test.openai.azure.com"
|
||||
)
|
||||
credential = llm_api_key._select_credential()
|
||||
assert isinstance(credential, AzureKeyCredential)
|
||||
|
||||
llm_ad_token = AzureCompletion(
|
||||
model="gpt-4",
|
||||
azure_ad_token="test-ad-token",
|
||||
endpoint="https://test.openai.azure.com"
|
||||
)
|
||||
credential = llm_ad_token._select_credential()
|
||||
assert isinstance(credential, _StaticTokenCredential)
|
||||
|
||||
|
||||
def test_azure_use_default_credential_import_error():
|
||||
"""
|
||||
Test that use_default_credential raises ImportError when azure-identity is not available.
|
||||
"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
import builtins
|
||||
|
||||
original_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == 'azure.identity':
|
||||
raise ImportError("No module named 'azure.identity'")
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with patch.dict(os.environ, {
|
||||
"AZURE_ENDPOINT": "https://test.openai.azure.com"
|
||||
}, clear=True):
|
||||
with patch.object(builtins, '__import__', side_effect=mock_import):
|
||||
with pytest.raises(ImportError, match="azure-identity package is required"):
|
||||
AzureCompletion(
|
||||
model="gpt-4",
|
||||
endpoint="https://test.openai.azure.com",
|
||||
use_default_credential=True
|
||||
)
|
||||
|
||||
|
||||
def test_azure_improved_error_message_no_credentials():
|
||||
"""
|
||||
Test that the error message when no credentials are provided is helpful.
|
||||
"""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")
|
||||
|
||||
error_message = str(excinfo.value)
|
||||
assert "Azure credentials are required" in error_message
|
||||
assert "api_key" in error_message
|
||||
assert "AZURE_API_KEY" in error_message
|
||||
assert "azure_ad_token" in error_message
|
||||
assert "AZURE_AD_TOKEN" in error_message
|
||||
assert "credential" in error_message
|
||||
assert "use_default_credential" in error_message
|
||||
|
||||
Reference in New Issue
Block a user