mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-08 02:29:00 +00:00
Compare commits
5 Commits
1.14.5a1
...
feat/azure
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff1af96cc8 | ||
|
|
96f6f89c9d | ||
|
|
cb4c89f6ce | ||
|
|
04a2d606a3 | ||
|
|
31f480aac0 |
@@ -94,6 +94,7 @@ google-genai = [
|
||||
]
|
||||
azure-ai-inference = [
|
||||
"azure-ai-inference~=1.0.0b9",
|
||||
"azure-identity>=1.17.0,<2",
|
||||
]
|
||||
anthropic = [
|
||||
"anthropic~=0.73.0",
|
||||
|
||||
@@ -35,6 +35,7 @@ try:
|
||||
)
|
||||
from azure.core.credentials import (
|
||||
AzureKeyCredential,
|
||||
TokenCredential,
|
||||
)
|
||||
from azure.core.exceptions import (
|
||||
HttpResponseError,
|
||||
@@ -88,6 +89,8 @@ class AzureCompletion(BaseLLM):
|
||||
response_format: type[BaseModel] | None = None
|
||||
is_openai_model: bool = False
|
||||
is_azure_openai_endpoint: bool = False
|
||||
azure_tenant_id: str | None = None
|
||||
azure_client_id: str | None = None
|
||||
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
_async_client: Any = PrivateAttr(default=None)
|
||||
@@ -115,6 +118,12 @@ class AzureCompletion(BaseLLM):
|
||||
data["api_version"] = (
|
||||
data.get("api_version") or os.getenv("AZURE_API_VERSION") or "2024-06-01"
|
||||
)
|
||||
data["azure_tenant_id"] = data.get("azure_tenant_id") or os.getenv(
|
||||
"AZURE_TENANT_ID"
|
||||
)
|
||||
data["azure_client_id"] = data.get("azure_client_id") or os.getenv(
|
||||
"AZURE_CLIENT_ID"
|
||||
)
|
||||
|
||||
# Credentials and endpoint are validated lazily in `_init_clients`
|
||||
# so the LLM can be constructed before deployment env vars are set.
|
||||
@@ -149,7 +158,10 @@ class AzureCompletion(BaseLLM):
|
||||
try:
|
||||
self._client = self._build_sync_client()
|
||||
self._async_client = self._build_async_client()
|
||||
except ValueError:
|
||||
except (ValueError, ImportError):
|
||||
# Deferred initialization: client build is retried in
|
||||
# _ensure_clients() before the first API call, so it is
|
||||
# safe to suppress here when env vars are not yet set.
|
||||
pass
|
||||
return self
|
||||
|
||||
@@ -183,24 +195,96 @@ class AzureCompletion(BaseLLM):
|
||||
AzureCompletion._is_azure_openai_endpoint(self.endpoint)
|
||||
)
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"Azure API key is required. Set AZURE_API_KEY environment "
|
||||
"variable or pass api_key parameter."
|
||||
)
|
||||
# Re-read identity env vars for deferred builds
|
||||
if not self.azure_tenant_id:
|
||||
self.azure_tenant_id = os.getenv("AZURE_TENANT_ID")
|
||||
if not self.azure_client_id:
|
||||
self.azure_client_id = os.getenv("AZURE_CLIENT_ID")
|
||||
|
||||
if not self.endpoint:
|
||||
raise ValueError(
|
||||
"Azure endpoint is required. Set AZURE_ENDPOINT environment "
|
||||
"variable or pass endpoint parameter."
|
||||
)
|
||||
|
||||
credential = self._resolve_credential()
|
||||
|
||||
client_kwargs: dict[str, Any] = {
|
||||
"endpoint": self.endpoint,
|
||||
"credential": AzureKeyCredential(self.api_key),
|
||||
"credential": credential,
|
||||
}
|
||||
if self.api_version:
|
||||
client_kwargs["api_version"] = self.api_version
|
||||
return client_kwargs
|
||||
|
||||
def _resolve_credential(self) -> AzureKeyCredential | TokenCredential:
|
||||
"""Resolve the Azure credential using a priority chain.
|
||||
|
||||
Token-based credentials are checked first because the platform's
|
||||
workload-identity manager sets env vars at runtime to enable
|
||||
keyless auth. When those vars are present they intentionally
|
||||
take precedence over any static ``api_key`` so that enterprises
|
||||
can enforce SP / Managed Identity policies.
|
||||
|
||||
Priority:
|
||||
1. OIDC federation (WorkloadIdentityCredential) — auto-discovered
|
||||
from AZURE_FEDERATED_TOKEN_FILE + AZURE_TENANT_ID + AZURE_CLIENT_ID
|
||||
2. Client secret (ClientSecretCredential) — explicit SP credentials
|
||||
3. Default chain (DefaultAzureCredential) — Managed Identity et al.
|
||||
4. API key fallback (AzureKeyCredential) — existing path
|
||||
"""
|
||||
federated_token_file = os.getenv("AZURE_FEDERATED_TOKEN_FILE")
|
||||
client_secret = os.getenv("AZURE_CLIENT_SECRET")
|
||||
|
||||
# Path 1: OIDC Workload Identity Federation
|
||||
if federated_token_file and self.azure_tenant_id and self.azure_client_id:
|
||||
try:
|
||||
from azure.identity import WorkloadIdentityCredential
|
||||
|
||||
return WorkloadIdentityCredential(
|
||||
tenant_id=self.azure_tenant_id,
|
||||
client_id=self.azure_client_id,
|
||||
token_file_path=federated_token_file,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"azure-identity is required for workload identity federation. "
|
||||
'Install with: uv add "crewai[azure-ai-inference]"'
|
||||
) from None
|
||||
|
||||
# Path 2: Client Secret (Service Principal)
|
||||
if client_secret and self.azure_tenant_id and self.azure_client_id:
|
||||
try:
|
||||
from azure.identity import ClientSecretCredential
|
||||
|
||||
return ClientSecretCredential(
|
||||
tenant_id=self.azure_tenant_id,
|
||||
client_id=self.azure_client_id,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"azure-identity is required for service principal authentication. "
|
||||
'Install with: uv add "crewai[azure-ai-inference]"'
|
||||
) from None
|
||||
|
||||
# Path 3: DefaultAzureCredential (Managed Identity, Azure CLI, etc.)
|
||||
# Only attempt if azure-identity is installed and no API key is available
|
||||
if not self.api_key:
|
||||
try:
|
||||
from azure.identity import DefaultAzureCredential
|
||||
|
||||
return DefaultAzureCredential()
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Azure API key is required when azure-identity is not installed. "
|
||||
"Set AZURE_API_KEY environment variable, pass api_key parameter, "
|
||||
'or install azure-identity: uv add "crewai[azure-ai-inference]"'
|
||||
) from None
|
||||
|
||||
# Path 4: API Key (existing path)
|
||||
return AzureKeyCredential(self.api_key)
|
||||
|
||||
def _get_sync_client(self) -> Any:
|
||||
if self._client is None:
|
||||
self._client = self._build_sync_client()
|
||||
|
||||
@@ -390,16 +390,26 @@ def test_azure_raises_error_when_endpoint_missing():
|
||||
|
||||
|
||||
def test_azure_raises_error_when_api_key_missing():
|
||||
"""Credentials are validated lazily: construction succeeds, first
|
||||
client build raises the descriptive error."""
|
||||
"""When no API key AND azure-identity is not installed, credentials
|
||||
are validated lazily: construction succeeds, first client build raises.
|
||||
With azure-identity installed, DefaultAzureCredential is used instead."""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
llm = AzureCompletion(
|
||||
model="gpt-4", endpoint="https://test.openai.azure.com"
|
||||
)
|
||||
with pytest.raises(ValueError, match="Azure API key is required"):
|
||||
llm._get_sync_client()
|
||||
# With azure-identity installed, DefaultAzureCredential is used as
|
||||
# fallback instead of raising. Only raises when azure-identity is
|
||||
# not available.
|
||||
try:
|
||||
import azure.identity # noqa: F401
|
||||
# azure-identity is installed — DefaultAzureCredential will be used
|
||||
client = llm._get_sync_client()
|
||||
assert client is not None
|
||||
except ImportError:
|
||||
with pytest.raises(ValueError, match="Azure API key is required"):
|
||||
llm._get_sync_client()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
258
lib/crewai/tests/llms/azure/test_azure_credentials.py
Normal file
258
lib/crewai/tests/llms/azure/test_azure_credentials.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""Tests for Azure credential resolution chain in AzureCompletion.
|
||||
|
||||
Covers the four credential paths:
|
||||
1. WorkloadIdentityCredential (OIDC federation)
|
||||
2. ClientSecretCredential (Service Principal)
|
||||
3. DefaultAzureCredential (Managed Identity / CLI fallback)
|
||||
4. AzureKeyCredential (API key - existing path)
|
||||
"""
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# Use a non-Azure-OpenAI endpoint to avoid _validate_and_fix_endpoint suffixing
|
||||
ENDPOINT = "https://test-ai.services.example.com"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _clear_azure_env(monkeypatch):
|
||||
"""Remove all Azure env vars to start clean."""
|
||||
for key in [
|
||||
"AZURE_API_KEY", "AZURE_ENDPOINT", "AZURE_OPENAI_ENDPOINT",
|
||||
"AZURE_API_BASE", "AZURE_API_VERSION", "AZURE_TENANT_ID",
|
||||
"AZURE_CLIENT_ID", "AZURE_CLIENT_SECRET", "AZURE_FEDERATED_TOKEN_FILE",
|
||||
]:
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_clear_azure_env")
|
||||
class TestCredentialResolution:
|
||||
"""Tests for AzureCompletion._resolve_credential."""
|
||||
|
||||
def test_api_key_credential_when_api_key_set(self):
|
||||
"""Path 4: API key produces AzureKeyCredential."""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
|
||||
completion = AzureCompletion(
|
||||
model="gpt-4",
|
||||
api_key="test-key",
|
||||
endpoint=ENDPOINT,
|
||||
)
|
||||
cred = completion._resolve_credential()
|
||||
assert isinstance(cred, AzureKeyCredential)
|
||||
|
||||
def test_api_key_from_env(self, monkeypatch):
|
||||
"""Path 4: api_key picked up from AZURE_API_KEY env var."""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
|
||||
monkeypatch.setenv("AZURE_API_KEY", "env-key")
|
||||
monkeypatch.setenv("AZURE_ENDPOINT", ENDPOINT)
|
||||
|
||||
completion = AzureCompletion(model="gpt-4")
|
||||
cred = completion._resolve_credential()
|
||||
assert isinstance(cred, AzureKeyCredential)
|
||||
|
||||
def test_workload_identity_credential(self, monkeypatch, tmp_path):
|
||||
"""Path 1: OIDC federation via WorkloadIdentityCredential."""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
token_file = tmp_path / "token.txt"
|
||||
token_file.write_text("eyJhbGciOiJSUzI1NiJ9.test")
|
||||
|
||||
monkeypatch.setenv("AZURE_FEDERATED_TOKEN_FILE", str(token_file))
|
||||
monkeypatch.setenv("AZURE_ENDPOINT", ENDPOINT)
|
||||
|
||||
mock_wi_cred = MagicMock()
|
||||
with patch(
|
||||
"azure.identity.WorkloadIdentityCredential",
|
||||
return_value=mock_wi_cred,
|
||||
) as mock_cls:
|
||||
completion = AzureCompletion(
|
||||
model="gpt-4",
|
||||
azure_tenant_id="tenant-123",
|
||||
azure_client_id="client-456",
|
||||
)
|
||||
cred = completion._resolve_credential()
|
||||
assert cred is mock_wi_cred
|
||||
# Called at least once with the right args (init may also call it)
|
||||
mock_cls.assert_any_call(
|
||||
tenant_id="tenant-123",
|
||||
client_id="client-456",
|
||||
token_file_path=str(token_file),
|
||||
)
|
||||
|
||||
def test_workload_identity_from_env_vars(self, monkeypatch, tmp_path):
|
||||
"""Path 1: All WI fields discovered from environment."""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
token_file = tmp_path / "token.txt"
|
||||
token_file.write_text("eyJhbGciOiJSUzI1NiJ9.test")
|
||||
|
||||
monkeypatch.setenv("AZURE_FEDERATED_TOKEN_FILE", str(token_file))
|
||||
monkeypatch.setenv("AZURE_TENANT_ID", "env-tenant")
|
||||
monkeypatch.setenv("AZURE_CLIENT_ID", "env-client")
|
||||
monkeypatch.setenv("AZURE_ENDPOINT", ENDPOINT)
|
||||
|
||||
mock_wi_cred = MagicMock()
|
||||
with patch(
|
||||
"azure.identity.WorkloadIdentityCredential",
|
||||
return_value=mock_wi_cred,
|
||||
) as mock_cls:
|
||||
completion = AzureCompletion(model="gpt-4")
|
||||
cred = completion._resolve_credential()
|
||||
assert cred is mock_wi_cred
|
||||
mock_cls.assert_any_call(
|
||||
tenant_id="env-tenant",
|
||||
client_id="env-client",
|
||||
token_file_path=str(token_file),
|
||||
)
|
||||
|
||||
def test_client_secret_credential(self, monkeypatch):
|
||||
"""Path 2: Service Principal with client secret."""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
monkeypatch.setenv("AZURE_CLIENT_SECRET", "sp-secret")
|
||||
monkeypatch.setenv("AZURE_ENDPOINT", ENDPOINT)
|
||||
|
||||
mock_cs_cred = MagicMock()
|
||||
with patch(
|
||||
"azure.identity.ClientSecretCredential",
|
||||
return_value=mock_cs_cred,
|
||||
) as mock_cls:
|
||||
completion = AzureCompletion(
|
||||
model="gpt-4",
|
||||
azure_tenant_id="tenant-123",
|
||||
azure_client_id="client-456",
|
||||
)
|
||||
cred = completion._resolve_credential()
|
||||
assert cred is mock_cs_cred
|
||||
mock_cls.assert_any_call(
|
||||
tenant_id="tenant-123",
|
||||
client_id="client-456",
|
||||
client_secret="sp-secret",
|
||||
)
|
||||
|
||||
def test_default_azure_credential_when_no_api_key(self, monkeypatch):
|
||||
"""Path 3: DefaultAzureCredential when no api_key and no SP/WI vars."""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
monkeypatch.setenv("AZURE_ENDPOINT", ENDPOINT)
|
||||
|
||||
mock_default_cred = MagicMock()
|
||||
with patch(
|
||||
"azure.identity.DefaultAzureCredential",
|
||||
return_value=mock_default_cred,
|
||||
):
|
||||
completion = AzureCompletion(model="gpt-4")
|
||||
cred = completion._resolve_credential()
|
||||
assert cred is mock_default_cred
|
||||
|
||||
def test_workload_identity_takes_priority_over_api_key(self, monkeypatch, tmp_path):
|
||||
"""WI credential should take priority even when api_key is also set."""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
token_file = tmp_path / "token.txt"
|
||||
token_file.write_text("eyJhbGciOiJSUzI1NiJ9.test")
|
||||
|
||||
monkeypatch.setenv("AZURE_FEDERATED_TOKEN_FILE", str(token_file))
|
||||
monkeypatch.setenv("AZURE_API_KEY", "should-not-use-this")
|
||||
monkeypatch.setenv("AZURE_ENDPOINT", ENDPOINT)
|
||||
|
||||
mock_wi_cred = MagicMock()
|
||||
with patch(
|
||||
"azure.identity.WorkloadIdentityCredential",
|
||||
return_value=mock_wi_cred,
|
||||
):
|
||||
completion = AzureCompletion(
|
||||
model="gpt-4",
|
||||
azure_tenant_id="tenant-123",
|
||||
azure_client_id="client-456",
|
||||
)
|
||||
cred = completion._resolve_credential()
|
||||
assert cred is mock_wi_cred
|
||||
|
||||
def test_client_secret_takes_priority_over_api_key(self, monkeypatch):
|
||||
"""SP credential should take priority over API key."""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
monkeypatch.setenv("AZURE_CLIENT_SECRET", "sp-secret")
|
||||
monkeypatch.setenv("AZURE_API_KEY", "should-not-use-this")
|
||||
monkeypatch.setenv("AZURE_ENDPOINT", ENDPOINT)
|
||||
|
||||
mock_cs_cred = MagicMock()
|
||||
with patch(
|
||||
"azure.identity.ClientSecretCredential",
|
||||
return_value=mock_cs_cred,
|
||||
):
|
||||
completion = AzureCompletion(
|
||||
model="gpt-4",
|
||||
azure_tenant_id="tenant-123",
|
||||
azure_client_id="client-456",
|
||||
)
|
||||
cred = completion._resolve_credential()
|
||||
assert cred is mock_cs_cred
|
||||
|
||||
def test_raises_when_no_api_key_and_no_azure_identity(self, monkeypatch):
|
||||
"""ValueError when no api_key and azure-identity not installed."""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
monkeypatch.setenv("AZURE_ENDPOINT", ENDPOINT)
|
||||
|
||||
with patch.dict("sys.modules", {"azure.identity": None}):
|
||||
completion = AzureCompletion(model="gpt-4")
|
||||
with pytest.raises(ValueError, match="Azure API key is required"):
|
||||
completion._resolve_credential()
|
||||
|
||||
def test_endpoint_still_required(self, monkeypatch, tmp_path):
|
||||
"""Endpoint is always required regardless of credential type."""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
token_file = tmp_path / "token.txt"
|
||||
token_file.write_text("test-jwt")
|
||||
|
||||
monkeypatch.setenv("AZURE_FEDERATED_TOKEN_FILE", str(token_file))
|
||||
monkeypatch.setenv("AZURE_TENANT_ID", "tenant-123")
|
||||
monkeypatch.setenv("AZURE_CLIENT_ID", "client-456")
|
||||
|
||||
completion = AzureCompletion(model="gpt-4")
|
||||
with pytest.raises(ValueError, match="Azure endpoint is required"):
|
||||
completion._make_client_kwargs()
|
||||
|
||||
def test_deferred_build_picks_up_wi_env_vars(self, monkeypatch, tmp_path):
|
||||
"""Env vars set after construction are picked up on deferred build."""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
# Construct with endpoint only — no credentials yet
|
||||
monkeypatch.setenv("AZURE_ENDPOINT", ENDPOINT)
|
||||
completion = AzureCompletion(model="gpt-4")
|
||||
|
||||
# Now set WI env vars (simulating WI manager setting them before crew run)
|
||||
token_file = tmp_path / "token.txt"
|
||||
token_file.write_text("eyJhbGciOiJSUzI1NiJ9.deferred")
|
||||
monkeypatch.setenv("AZURE_FEDERATED_TOKEN_FILE", str(token_file))
|
||||
monkeypatch.setenv("AZURE_TENANT_ID", "deferred-tenant")
|
||||
monkeypatch.setenv("AZURE_CLIENT_ID", "deferred-client")
|
||||
|
||||
mock_wi_cred = MagicMock()
|
||||
with patch(
|
||||
"azure.identity.WorkloadIdentityCredential",
|
||||
return_value=mock_wi_cred,
|
||||
):
|
||||
kwargs = completion._make_client_kwargs()
|
||||
assert kwargs["credential"] is mock_wi_cred
|
||||
|
||||
def test_make_client_kwargs_includes_api_version(self, monkeypatch):
|
||||
"""api_version is included in client kwargs."""
|
||||
from crewai.llms.providers.azure.completion import AzureCompletion
|
||||
|
||||
monkeypatch.setenv("AZURE_API_KEY", "test-key")
|
||||
monkeypatch.setenv("AZURE_ENDPOINT", ENDPOINT)
|
||||
|
||||
completion = AzureCompletion(model="gpt-4", api_version="2025-01-01")
|
||||
kwargs = completion._make_client_kwargs()
|
||||
assert kwargs["api_version"] == "2025-01-01"
|
||||
assert kwargs["endpoint"] == ENDPOINT
|
||||
Reference in New Issue
Block a user