mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-15 11:58:31 +00:00
feat: support CLI login with Entra ID (#3943)
This commit is contained in:
@@ -67,7 +67,11 @@ class ProviderFactory:
|
||||
module = importlib.import_module(
|
||||
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
|
||||
)
|
||||
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
|
||||
# Converts from snake_case to CamelCase to obtain the provider class name.
|
||||
provider = getattr(
|
||||
module,
|
||||
f"{''.join(word.capitalize() for word in settings.provider.split('_'))}Provider",
|
||||
)
|
||||
|
||||
return cast("BaseProvider", provider(settings))
|
||||
|
||||
@@ -91,7 +95,7 @@ class AuthenticationCommand:
|
||||
|
||||
device_code_payload = {
|
||||
"client_id": self.oauth2_provider.get_client_id(),
|
||||
"scope": "openid",
|
||||
"scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
|
||||
"audience": self.oauth2_provider.get_audience(),
|
||||
}
|
||||
response = requests.post(
|
||||
@@ -104,9 +108,14 @@ class AuthenticationCommand:
|
||||
|
||||
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
|
||||
"""Display the authentication instructions to the user."""
|
||||
console.print("1. Navigate to: ", device_code_data["verification_uri_complete"])
|
||||
|
||||
verification_uri = device_code_data.get(
|
||||
"verification_uri_complete", device_code_data.get("verification_uri", "")
|
||||
)
|
||||
|
||||
console.print("1. Navigate to: ", verification_uri)
|
||||
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
||||
webbrowser.open(device_code_data["verification_uri_complete"])
|
||||
webbrowser.open(verification_uri)
|
||||
|
||||
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
|
||||
"""Polls the server for the token until it is received, or max attempts are reached."""
|
||||
@@ -186,8 +195,9 @@ class AuthenticationCommand:
|
||||
)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
console.print(
|
||||
f"You are authenticated to the tool repository as [bold cyan]'{settings.org_name}'[/bold cyan] ({settings.org_uuid})",
|
||||
f"You are now authenticated to the tool repository for organization [bold cyan]'{settings.org_name if settings.org_name else settings.org_uuid}'[/bold cyan]",
|
||||
style="green",
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -28,3 +28,6 @@ class BaseProvider(ABC):
|
||||
def get_required_fields(self) -> list[str]:
|
||||
"""Returns which provider-specific fields inside the "extra" dict will be required"""
|
||||
return []
|
||||
|
||||
def get_oauth_scopes(self) -> list[str]:
|
||||
return ["openid", "profile", "email"]
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
from typing import cast
|
||||
|
||||
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||
|
||||
|
||||
class EntraIdProvider(BaseProvider):
|
||||
def get_authorize_url(self) -> str:
|
||||
return f"{self._base_url()}/oauth2/v2.0/devicecode"
|
||||
|
||||
def get_token_url(self) -> str:
|
||||
return f"{self._base_url()}/oauth2/v2.0/token"
|
||||
|
||||
def get_jwks_url(self) -> str:
|
||||
return f"{self._base_url()}/discovery/v2.0/keys"
|
||||
|
||||
def get_issuer(self) -> str:
|
||||
return f"{self._base_url()}/v2.0"
|
||||
|
||||
def get_audience(self) -> str:
|
||||
if self.settings.audience is None:
|
||||
raise ValueError(
|
||||
"Audience is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.audience
|
||||
|
||||
def get_client_id(self) -> str:
|
||||
if self.settings.client_id is None:
|
||||
raise ValueError(
|
||||
"Client ID is required. Please set it in the configuration."
|
||||
)
|
||||
return self.settings.client_id
|
||||
|
||||
def get_oauth_scopes(self) -> list[str]:
|
||||
return [
|
||||
*super().get_oauth_scopes(),
|
||||
*cast(str, self.settings.extra.get("scope", "")).split(),
|
||||
]
|
||||
|
||||
def get_required_fields(self) -> list[str]:
|
||||
return ["scope"]
|
||||
|
||||
def _base_url(self) -> str:
|
||||
return f"https://login.microsoftonline.com/{self.settings.domain}"
|
||||
@@ -1,10 +1,12 @@
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from jwt import PyJWKClient
|
||||
|
||||
|
||||
def validate_jwt_token(
|
||||
jwt_token: str, jwks_url: str, issuer: str, audience: str
|
||||
) -> dict:
|
||||
) -> Any:
|
||||
"""
|
||||
Verify the token's signature and claims using PyJWT.
|
||||
:param jwt_token: The JWT (JWS) string to validate.
|
||||
@@ -24,6 +26,7 @@ def validate_jwt_token(
|
||||
_unverified_decoded_token = jwt.decode(
|
||||
jwt_token, options={"verify_signature": False}
|
||||
)
|
||||
|
||||
return jwt.decode(
|
||||
jwt_token,
|
||||
signing_key.key,
|
||||
|
||||
@@ -162,7 +162,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
|
||||
if login_response.status_code != 200:
|
||||
console.print(
|
||||
"Authentication failed. Verify access to the tool repository, or try `crewai login`. ",
|
||||
"Authentication failed. Verify if the currently active organization access to the tool repository, and run 'crewai login' again. ",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
141
lib/crewai/tests/cli/authentication/providers/test_entra_id.py
Normal file
141
lib/crewai/tests/cli/authentication/providers/test_entra_id.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import pytest
|
||||
|
||||
from crewai.cli.authentication.main import Oauth2Settings
|
||||
from crewai.cli.authentication.providers.entra_id import EntraIdProvider
|
||||
|
||||
|
||||
class TestEntraIdProvider:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.valid_settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="tenant-id-abcdef123456",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"scope": "openid profile email api://crewai-cli-dev/read"
|
||||
}
|
||||
)
|
||||
self.provider = EntraIdProvider(self.valid_settings)
|
||||
|
||||
def test_initialization_with_valid_settings(self):
|
||||
provider = EntraIdProvider(self.valid_settings)
|
||||
assert provider.settings == self.valid_settings
|
||||
assert provider.settings.provider == "entra_id"
|
||||
assert provider.settings.domain == "tenant-id-abcdef123456"
|
||||
assert provider.settings.client_id == "test-client-id"
|
||||
assert provider.settings.audience == "test-audience"
|
||||
|
||||
def test_get_authorize_url(self):
|
||||
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode"
|
||||
assert self.provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="my-company.entra.id",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_url = "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_token_url(self):
|
||||
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token"
|
||||
assert self.provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="another-domain.entra.id",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_url = "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_jwks_url(self):
|
||||
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys"
|
||||
assert self.provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="dev.entra.id",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_url = "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_issuer(self):
|
||||
expected_issuer = "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0"
|
||||
assert self.provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="other-tenant-id-xpto",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_issuer = "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_audience(self):
|
||||
assert self.provider.get_audience() == "test-audience"
|
||||
|
||||
def test_get_audience_assertion_error_when_none(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="test-tenant-id",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
|
||||
with pytest.raises(ValueError, match="Audience is required"):
|
||||
provider.get_audience()
|
||||
|
||||
def test_get_client_id(self):
|
||||
assert self.provider.get_client_id() == "test-client-id"
|
||||
|
||||
def test_get_required_fields(self):
|
||||
assert set(self.provider.get_required_fields()) == set(["scope"])
|
||||
|
||||
def test_get_oauth_scopes(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="tenant-id-abcdef123456",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"scope": "api://crewai-cli-dev/read"
|
||||
}
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read"]
|
||||
|
||||
def test_get_oauth_scopes_with_multiple_custom_scopes(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="tenant-id-abcdef123456",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"scope": "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2"
|
||||
}
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read", "api://crewai-cli-dev/write", "custom-scope1", "custom-scope2"]
|
||||
|
||||
def test_base_url(self):
|
||||
assert self.provider._base_url() == "https://login.microsoftonline.com/tenant-id-abcdef123456"
|
||||
@@ -15,6 +15,8 @@ class TestAuthenticationCommand:
|
||||
def setup_method(self):
|
||||
self.auth_command = AuthenticationCommand()
|
||||
|
||||
# TODO: these expectations are reading from the actual settings, we should mock them.
|
||||
# E.g. if you change the client_id locally, this test will fail.
|
||||
@pytest.mark.parametrize(
|
||||
"user_provider,expected_urls",
|
||||
[
|
||||
@@ -181,7 +183,7 @@ class TestAuthenticationCommand:
|
||||
),
|
||||
call("Success!\n", style="bold green"),
|
||||
call(
|
||||
"You are authenticated to the tool repository as [bold cyan]'Test Org'[/bold cyan] (test-uuid-123)",
|
||||
"You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]",
|
||||
style="green",
|
||||
),
|
||||
]
|
||||
@@ -234,6 +236,7 @@ class TestAuthenticationCommand:
|
||||
"https://example.com/device"
|
||||
)
|
||||
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
|
||||
self.auth_command.oauth2_provider.get_oauth_scopes.return_value = ["openid", "profile", "email"]
|
||||
|
||||
result = self.auth_command._get_device_code()
|
||||
|
||||
@@ -241,7 +244,7 @@ class TestAuthenticationCommand:
|
||||
url="https://example.com/device",
|
||||
data={
|
||||
"client_id": "test_client",
|
||||
"scope": "openid",
|
||||
"scope": "openid profile email",
|
||||
"audience": "test_audience",
|
||||
},
|
||||
timeout=20,
|
||||
|
||||
Reference in New Issue
Block a user