feat: support CLI login with Entra ID (#3943)

This commit is contained in:
Heitor Carvalho
2025-11-24 15:35:59 -03:00
committed by GitHub
parent 9da1f0c0aa
commit b759654e7d
7 changed files with 212 additions and 9 deletions

View File

@@ -67,7 +67,11 @@ class ProviderFactory:
module = importlib.import_module(
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
)
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
# Converts from snake_case to CamelCase to obtain the provider class name.
provider = getattr(
module,
f"{''.join(word.capitalize() for word in settings.provider.split('_'))}Provider",
)
return cast("BaseProvider", provider(settings))
@@ -91,7 +95,7 @@ class AuthenticationCommand:
device_code_payload = {
"client_id": self.oauth2_provider.get_client_id(),
"scope": "openid",
"scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
"audience": self.oauth2_provider.get_audience(),
}
response = requests.post(
@@ -104,9 +108,14 @@ class AuthenticationCommand:
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
"""Display the authentication instructions to the user."""
console.print("1. Navigate to: ", device_code_data["verification_uri_complete"])
verification_uri = device_code_data.get(
"verification_uri_complete", device_code_data.get("verification_uri", "")
)
console.print("1. Navigate to: ", verification_uri)
console.print("2. Enter the following code: ", device_code_data["user_code"])
webbrowser.open(device_code_data["verification_uri_complete"])
webbrowser.open(verification_uri)
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
"""Polls the server for the token until it is received, or max attempts are reached."""
@@ -186,8 +195,9 @@ class AuthenticationCommand:
)
settings = Settings()
console.print(
f"You are authenticated to the tool repository as [bold cyan]'{settings.org_name}'[/bold cyan] ({settings.org_uuid})",
f"You are now authenticated to the tool repository for organization [bold cyan]'{settings.org_name if settings.org_name else settings.org_uuid}'[/bold cyan]",
style="green",
)
except Exception:

View File

@@ -28,3 +28,6 @@ class BaseProvider(ABC):
def get_required_fields(self) -> list[str]:
"""Returns which provider-specific fields inside the "extra" dict will be required"""
return []
def get_oauth_scopes(self) -> list[str]:
return ["openid", "profile", "email"]

View File

@@ -0,0 +1,43 @@
from typing import cast
from crewai.cli.authentication.providers.base_provider import BaseProvider
class EntraIdProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"{self._base_url()}/oauth2/v2.0/devicecode"
def get_token_url(self) -> str:
return f"{self._base_url()}/oauth2/v2.0/token"
def get_jwks_url(self) -> str:
return f"{self._base_url()}/discovery/v2.0/keys"
def get_issuer(self) -> str:
return f"{self._base_url()}/v2.0"
def get_audience(self) -> str:
if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience
def get_client_id(self) -> str:
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def get_oauth_scopes(self) -> list[str]:
return [
*super().get_oauth_scopes(),
*cast(str, self.settings.extra.get("scope", "")).split(),
]
def get_required_fields(self) -> list[str]:
return ["scope"]
def _base_url(self) -> str:
return f"https://login.microsoftonline.com/{self.settings.domain}"

View File

@@ -1,10 +1,12 @@
from typing import Any
import jwt
from jwt import PyJWKClient
def validate_jwt_token(
jwt_token: str, jwks_url: str, issuer: str, audience: str
) -> dict:
) -> Any:
"""
Verify the token's signature and claims using PyJWT.
:param jwt_token: The JWT (JWS) string to validate.
@@ -24,6 +26,7 @@ def validate_jwt_token(
_unverified_decoded_token = jwt.decode(
jwt_token, options={"verify_signature": False}
)
return jwt.decode(
jwt_token,
signing_key.key,

View File

@@ -162,7 +162,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
if login_response.status_code != 200:
console.print(
"Authentication failed. Verify access to the tool repository, or try `crewai login`. ",
"Authentication failed. Verify if the currently active organization access to the tool repository, and run 'crewai login' again. ",
style="bold red",
)
raise SystemExit

View File

@@ -0,0 +1,141 @@
import pytest
from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.entra_id import EntraIdProvider
class TestEntraIdProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "openid profile email api://crewai-cli-dev/read"
}
)
self.provider = EntraIdProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = EntraIdProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "entra_id"
assert provider.settings.domain == "tenant-id-abcdef123456"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="my-company.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="another-domain.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="dev.entra.id",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_url = "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
# For EntraID, the domain is the tenant ID.
settings = Oauth2Settings(
provider="entra_id",
domain="other-tenant-id-xpto",
client_id="test-client",
audience="test-audience",
)
provider = EntraIdProvider(settings)
expected_issuer = "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_assertion_error_when_none(self):
settings = Oauth2Settings(
provider="entra_id",
domain="test-tenant-id",
client_id="test-client-id",
audience=None,
)
provider = EntraIdProvider(settings)
with pytest.raises(ValueError, match="Audience is required"):
provider.get_audience()
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert set(self.provider.get_required_fields()) == set(["scope"])
def test_get_oauth_scopes(self):
settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "api://crewai-cli-dev/read"
}
)
provider = EntraIdProvider(settings)
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read"]
def test_get_oauth_scopes_with_multiple_custom_scopes(self):
settings = Oauth2Settings(
provider="entra_id",
domain="tenant-id-abcdef123456",
client_id="test-client-id",
audience="test-audience",
extra={
"scope": "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2"
}
)
provider = EntraIdProvider(settings)
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read", "api://crewai-cli-dev/write", "custom-scope1", "custom-scope2"]
def test_base_url(self):
assert self.provider._base_url() == "https://login.microsoftonline.com/tenant-id-abcdef123456"

View File

@@ -15,6 +15,8 @@ class TestAuthenticationCommand:
def setup_method(self):
self.auth_command = AuthenticationCommand()
# TODO: these expectations are reading from the actual settings, we should mock them.
# E.g. if you change the client_id locally, this test will fail.
@pytest.mark.parametrize(
"user_provider,expected_urls",
[
@@ -181,7 +183,7 @@ class TestAuthenticationCommand:
),
call("Success!\n", style="bold green"),
call(
"You are authenticated to the tool repository as [bold cyan]'Test Org'[/bold cyan] (test-uuid-123)",
"You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]",
style="green",
),
]
@@ -234,6 +236,7 @@ class TestAuthenticationCommand:
"https://example.com/device"
)
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
self.auth_command.oauth2_provider.get_oauth_scopes.return_value = ["openid", "profile", "email"]
result = self.auth_command._get_device_code()
@@ -241,7 +244,7 @@ class TestAuthenticationCommand:
url="https://example.com/device",
data={
"client_id": "test_client",
"scope": "openid",
"scope": "openid profile email",
"audience": "test_audience",
},
timeout=20,