mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
feat: support CLI login with Entra ID (#3943)
This commit is contained in:
@@ -67,7 +67,11 @@ class ProviderFactory:
|
|||||||
module = importlib.import_module(
|
module = importlib.import_module(
|
||||||
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
|
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
|
||||||
)
|
)
|
||||||
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
|
# Converts from snake_case to CamelCase to obtain the provider class name.
|
||||||
|
provider = getattr(
|
||||||
|
module,
|
||||||
|
f"{''.join(word.capitalize() for word in settings.provider.split('_'))}Provider",
|
||||||
|
)
|
||||||
|
|
||||||
return cast("BaseProvider", provider(settings))
|
return cast("BaseProvider", provider(settings))
|
||||||
|
|
||||||
@@ -91,7 +95,7 @@ class AuthenticationCommand:
|
|||||||
|
|
||||||
device_code_payload = {
|
device_code_payload = {
|
||||||
"client_id": self.oauth2_provider.get_client_id(),
|
"client_id": self.oauth2_provider.get_client_id(),
|
||||||
"scope": "openid",
|
"scope": " ".join(self.oauth2_provider.get_oauth_scopes()),
|
||||||
"audience": self.oauth2_provider.get_audience(),
|
"audience": self.oauth2_provider.get_audience(),
|
||||||
}
|
}
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
@@ -104,9 +108,14 @@ class AuthenticationCommand:
|
|||||||
|
|
||||||
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
|
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
|
||||||
"""Display the authentication instructions to the user."""
|
"""Display the authentication instructions to the user."""
|
||||||
console.print("1. Navigate to: ", device_code_data["verification_uri_complete"])
|
|
||||||
|
verification_uri = device_code_data.get(
|
||||||
|
"verification_uri_complete", device_code_data.get("verification_uri", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
console.print("1. Navigate to: ", verification_uri)
|
||||||
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
||||||
webbrowser.open(device_code_data["verification_uri_complete"])
|
webbrowser.open(verification_uri)
|
||||||
|
|
||||||
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
|
def _poll_for_token(self, device_code_data: dict[str, Any]) -> None:
|
||||||
"""Polls the server for the token until it is received, or max attempts are reached."""
|
"""Polls the server for the token until it is received, or max attempts are reached."""
|
||||||
@@ -186,8 +195,9 @@ class AuthenticationCommand:
|
|||||||
)
|
)
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
f"You are authenticated to the tool repository as [bold cyan]'{settings.org_name}'[/bold cyan] ({settings.org_uuid})",
|
f"You are now authenticated to the tool repository for organization [bold cyan]'{settings.org_name if settings.org_name else settings.org_uuid}'[/bold cyan]",
|
||||||
style="green",
|
style="green",
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -28,3 +28,6 @@ class BaseProvider(ABC):
|
|||||||
def get_required_fields(self) -> list[str]:
|
def get_required_fields(self) -> list[str]:
|
||||||
"""Returns which provider-specific fields inside the "extra" dict will be required"""
|
"""Returns which provider-specific fields inside the "extra" dict will be required"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def get_oauth_scopes(self) -> list[str]:
|
||||||
|
return ["openid", "profile", "email"]
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||||
|
|
||||||
|
|
||||||
|
class EntraIdProvider(BaseProvider):
|
||||||
|
def get_authorize_url(self) -> str:
|
||||||
|
return f"{self._base_url()}/oauth2/v2.0/devicecode"
|
||||||
|
|
||||||
|
def get_token_url(self) -> str:
|
||||||
|
return f"{self._base_url()}/oauth2/v2.0/token"
|
||||||
|
|
||||||
|
def get_jwks_url(self) -> str:
|
||||||
|
return f"{self._base_url()}/discovery/v2.0/keys"
|
||||||
|
|
||||||
|
def get_issuer(self) -> str:
|
||||||
|
return f"{self._base_url()}/v2.0"
|
||||||
|
|
||||||
|
def get_audience(self) -> str:
|
||||||
|
if self.settings.audience is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Audience is required. Please set it in the configuration."
|
||||||
|
)
|
||||||
|
return self.settings.audience
|
||||||
|
|
||||||
|
def get_client_id(self) -> str:
|
||||||
|
if self.settings.client_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Client ID is required. Please set it in the configuration."
|
||||||
|
)
|
||||||
|
return self.settings.client_id
|
||||||
|
|
||||||
|
def get_oauth_scopes(self) -> list[str]:
|
||||||
|
return [
|
||||||
|
*super().get_oauth_scopes(),
|
||||||
|
*cast(str, self.settings.extra.get("scope", "")).split(),
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_required_fields(self) -> list[str]:
|
||||||
|
return ["scope"]
|
||||||
|
|
||||||
|
def _base_url(self) -> str:
|
||||||
|
return f"https://login.microsoftonline.com/{self.settings.domain}"
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
from jwt import PyJWKClient
|
from jwt import PyJWKClient
|
||||||
|
|
||||||
|
|
||||||
def validate_jwt_token(
|
def validate_jwt_token(
|
||||||
jwt_token: str, jwks_url: str, issuer: str, audience: str
|
jwt_token: str, jwks_url: str, issuer: str, audience: str
|
||||||
) -> dict:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Verify the token's signature and claims using PyJWT.
|
Verify the token's signature and claims using PyJWT.
|
||||||
:param jwt_token: The JWT (JWS) string to validate.
|
:param jwt_token: The JWT (JWS) string to validate.
|
||||||
@@ -24,6 +26,7 @@ def validate_jwt_token(
|
|||||||
_unverified_decoded_token = jwt.decode(
|
_unverified_decoded_token = jwt.decode(
|
||||||
jwt_token, options={"verify_signature": False}
|
jwt_token, options={"verify_signature": False}
|
||||||
)
|
)
|
||||||
|
|
||||||
return jwt.decode(
|
return jwt.decode(
|
||||||
jwt_token,
|
jwt_token,
|
||||||
signing_key.key,
|
signing_key.key,
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
|||||||
|
|
||||||
if login_response.status_code != 200:
|
if login_response.status_code != 200:
|
||||||
console.print(
|
console.print(
|
||||||
"Authentication failed. Verify access to the tool repository, or try `crewai login`. ",
|
"Authentication failed. Verify if the currently active organization access to the tool repository, and run 'crewai login' again. ",
|
||||||
style="bold red",
|
style="bold red",
|
||||||
)
|
)
|
||||||
raise SystemExit
|
raise SystemExit
|
||||||
|
|||||||
141
lib/crewai/tests/cli/authentication/providers/test_entra_id.py
Normal file
141
lib/crewai/tests/cli/authentication/providers/test_entra_id.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.cli.authentication.main import Oauth2Settings
|
||||||
|
from crewai.cli.authentication.providers.entra_id import EntraIdProvider
|
||||||
|
|
||||||
|
|
||||||
|
class TestEntraIdProvider:
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_method(self):
|
||||||
|
self.valid_settings = Oauth2Settings(
|
||||||
|
provider="entra_id",
|
||||||
|
domain="tenant-id-abcdef123456",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience="test-audience",
|
||||||
|
extra={
|
||||||
|
"scope": "openid profile email api://crewai-cli-dev/read"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.provider = EntraIdProvider(self.valid_settings)
|
||||||
|
|
||||||
|
def test_initialization_with_valid_settings(self):
|
||||||
|
provider = EntraIdProvider(self.valid_settings)
|
||||||
|
assert provider.settings == self.valid_settings
|
||||||
|
assert provider.settings.provider == "entra_id"
|
||||||
|
assert provider.settings.domain == "tenant-id-abcdef123456"
|
||||||
|
assert provider.settings.client_id == "test-client-id"
|
||||||
|
assert provider.settings.audience == "test-audience"
|
||||||
|
|
||||||
|
def test_get_authorize_url(self):
|
||||||
|
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode"
|
||||||
|
assert self.provider.get_authorize_url() == expected_url
|
||||||
|
|
||||||
|
def test_get_authorize_url_with_different_domain(self):
|
||||||
|
# For EntraID, the domain is the tenant ID.
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="entra_id",
|
||||||
|
domain="my-company.entra.id",
|
||||||
|
client_id="test-client",
|
||||||
|
audience="test-audience",
|
||||||
|
)
|
||||||
|
provider = EntraIdProvider(settings)
|
||||||
|
expected_url = "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode"
|
||||||
|
assert provider.get_authorize_url() == expected_url
|
||||||
|
|
||||||
|
def test_get_token_url(self):
|
||||||
|
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token"
|
||||||
|
assert self.provider.get_token_url() == expected_url
|
||||||
|
|
||||||
|
def test_get_token_url_with_different_domain(self):
|
||||||
|
# For EntraID, the domain is the tenant ID.
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="entra_id",
|
||||||
|
domain="another-domain.entra.id",
|
||||||
|
client_id="test-client",
|
||||||
|
audience="test-audience",
|
||||||
|
)
|
||||||
|
provider = EntraIdProvider(settings)
|
||||||
|
expected_url = "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token"
|
||||||
|
assert provider.get_token_url() == expected_url
|
||||||
|
|
||||||
|
def test_get_jwks_url(self):
|
||||||
|
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys"
|
||||||
|
assert self.provider.get_jwks_url() == expected_url
|
||||||
|
|
||||||
|
def test_get_jwks_url_with_different_domain(self):
|
||||||
|
# For EntraID, the domain is the tenant ID.
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="entra_id",
|
||||||
|
domain="dev.entra.id",
|
||||||
|
client_id="test-client",
|
||||||
|
audience="test-audience",
|
||||||
|
)
|
||||||
|
provider = EntraIdProvider(settings)
|
||||||
|
expected_url = "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys"
|
||||||
|
assert provider.get_jwks_url() == expected_url
|
||||||
|
|
||||||
|
def test_get_issuer(self):
|
||||||
|
expected_issuer = "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0"
|
||||||
|
assert self.provider.get_issuer() == expected_issuer
|
||||||
|
|
||||||
|
def test_get_issuer_with_different_domain(self):
|
||||||
|
# For EntraID, the domain is the tenant ID.
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="entra_id",
|
||||||
|
domain="other-tenant-id-xpto",
|
||||||
|
client_id="test-client",
|
||||||
|
audience="test-audience",
|
||||||
|
)
|
||||||
|
provider = EntraIdProvider(settings)
|
||||||
|
expected_issuer = "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0"
|
||||||
|
assert provider.get_issuer() == expected_issuer
|
||||||
|
|
||||||
|
def test_get_audience(self):
|
||||||
|
assert self.provider.get_audience() == "test-audience"
|
||||||
|
|
||||||
|
def test_get_audience_assertion_error_when_none(self):
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="entra_id",
|
||||||
|
domain="test-tenant-id",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience=None,
|
||||||
|
)
|
||||||
|
provider = EntraIdProvider(settings)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Audience is required"):
|
||||||
|
provider.get_audience()
|
||||||
|
|
||||||
|
def test_get_client_id(self):
|
||||||
|
assert self.provider.get_client_id() == "test-client-id"
|
||||||
|
|
||||||
|
def test_get_required_fields(self):
|
||||||
|
assert set(self.provider.get_required_fields()) == set(["scope"])
|
||||||
|
|
||||||
|
def test_get_oauth_scopes(self):
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="entra_id",
|
||||||
|
domain="tenant-id-abcdef123456",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience="test-audience",
|
||||||
|
extra={
|
||||||
|
"scope": "api://crewai-cli-dev/read"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
provider = EntraIdProvider(settings)
|
||||||
|
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read"]
|
||||||
|
|
||||||
|
def test_get_oauth_scopes_with_multiple_custom_scopes(self):
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="entra_id",
|
||||||
|
domain="tenant-id-abcdef123456",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience="test-audience",
|
||||||
|
extra={
|
||||||
|
"scope": "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
provider = EntraIdProvider(settings)
|
||||||
|
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read", "api://crewai-cli-dev/write", "custom-scope1", "custom-scope2"]
|
||||||
|
|
||||||
|
def test_base_url(self):
|
||||||
|
assert self.provider._base_url() == "https://login.microsoftonline.com/tenant-id-abcdef123456"
|
||||||
@@ -15,6 +15,8 @@ class TestAuthenticationCommand:
|
|||||||
def setup_method(self):
|
def setup_method(self):
|
||||||
self.auth_command = AuthenticationCommand()
|
self.auth_command = AuthenticationCommand()
|
||||||
|
|
||||||
|
# TODO: these expectations are reading from the actual settings, we should mock them.
|
||||||
|
# E.g. if you change the client_id locally, this test will fail.
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"user_provider,expected_urls",
|
"user_provider,expected_urls",
|
||||||
[
|
[
|
||||||
@@ -181,7 +183,7 @@ class TestAuthenticationCommand:
|
|||||||
),
|
),
|
||||||
call("Success!\n", style="bold green"),
|
call("Success!\n", style="bold green"),
|
||||||
call(
|
call(
|
||||||
"You are authenticated to the tool repository as [bold cyan]'Test Org'[/bold cyan] (test-uuid-123)",
|
"You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]",
|
||||||
style="green",
|
style="green",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
@@ -234,6 +236,7 @@ class TestAuthenticationCommand:
|
|||||||
"https://example.com/device"
|
"https://example.com/device"
|
||||||
)
|
)
|
||||||
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
|
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
|
||||||
|
self.auth_command.oauth2_provider.get_oauth_scopes.return_value = ["openid", "profile", "email"]
|
||||||
|
|
||||||
result = self.auth_command._get_device_code()
|
result = self.auth_command._get_device_code()
|
||||||
|
|
||||||
@@ -241,7 +244,7 @@ class TestAuthenticationCommand:
|
|||||||
url="https://example.com/device",
|
url="https://example.com/device",
|
||||||
data={
|
data={
|
||||||
"client_id": "test_client",
|
"client_id": "test_client",
|
||||||
"scope": "openid",
|
"scope": "openid profile email",
|
||||||
"audience": "test_audience",
|
"audience": "test_audience",
|
||||||
},
|
},
|
||||||
timeout=20,
|
timeout=20,
|
||||||
|
|||||||
Reference in New Issue
Block a user