diff --git a/lib/crewai/src/crewai/cli/authentication/main.py b/lib/crewai/src/crewai/cli/authentication/main.py index 7bda8fe08..acc7f5c56 100644 --- a/lib/crewai/src/crewai/cli/authentication/main.py +++ b/lib/crewai/src/crewai/cli/authentication/main.py @@ -67,7 +67,11 @@ class ProviderFactory: module = importlib.import_module( f"crewai.cli.authentication.providers.{settings.provider.lower()}" ) - provider = getattr(module, f"{settings.provider.capitalize()}Provider") + # Converts from snake_case to CamelCase to obtain the provider class name. + provider = getattr( + module, + f"{''.join(word.capitalize() for word in settings.provider.split('_'))}Provider", + ) return cast("BaseProvider", provider(settings)) @@ -91,7 +95,7 @@ class AuthenticationCommand: device_code_payload = { "client_id": self.oauth2_provider.get_client_id(), - "scope": "openid", + "scope": " ".join(self.oauth2_provider.get_oauth_scopes()), "audience": self.oauth2_provider.get_audience(), } response = requests.post( @@ -104,9 +108,14 @@ class AuthenticationCommand: def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None: """Display the authentication instructions to the user.""" - console.print("1. Navigate to: ", device_code_data["verification_uri_complete"]) + + verification_uri = device_code_data.get( + "verification_uri_complete", device_code_data.get("verification_uri", "") + ) + + console.print("1. Navigate to: ", verification_uri) console.print("2. Enter the following code: ", device_code_data["user_code"]) - webbrowser.open(device_code_data["verification_uri_complete"]) + webbrowser.open(verification_uri) def _poll_for_token(self, device_code_data: dict[str, Any]) -> None: """Polls the server for the token until it is received, or max attempts are reached.""" @@ -186,8 +195,9 @@ class AuthenticationCommand: ) settings = Settings() + console.print( - f"You are authenticated to the tool repository as [bold cyan]'{settings.org_name}'[/bold cyan] ({settings.org_uuid})", + f"You are now authenticated to the tool repository for organization [bold cyan]'{settings.org_name if settings.org_name else settings.org_uuid}'[/bold cyan]", style="green", ) except Exception: diff --git a/lib/crewai/src/crewai/cli/authentication/providers/base_provider.py b/lib/crewai/src/crewai/cli/authentication/providers/base_provider.py index 0c8057b4d..9412ca283 100644 --- a/lib/crewai/src/crewai/cli/authentication/providers/base_provider.py +++ b/lib/crewai/src/crewai/cli/authentication/providers/base_provider.py @@ -28,3 +28,6 @@ class BaseProvider(ABC): def get_required_fields(self) -> list[str]: """Returns which provider-specific fields inside the "extra" dict will be required""" return [] + + def get_oauth_scopes(self) -> list[str]: + return ["openid", "profile", "email"] diff --git a/lib/crewai/src/crewai/cli/authentication/providers/entra_id.py b/lib/crewai/src/crewai/cli/authentication/providers/entra_id.py new file mode 100644 index 000000000..c08ea4ec7 --- /dev/null +++ b/lib/crewai/src/crewai/cli/authentication/providers/entra_id.py @@ -0,0 +1,43 @@ +from typing import cast + +from crewai.cli.authentication.providers.base_provider import BaseProvider + + +class EntraIdProvider(BaseProvider): + def get_authorize_url(self) -> str: + return f"{self._base_url()}/oauth2/v2.0/devicecode" + + def get_token_url(self) -> str: + return f"{self._base_url()}/oauth2/v2.0/token" + + def get_jwks_url(self) -> str: + return f"{self._base_url()}/discovery/v2.0/keys" + + def get_issuer(self) -> str: + return f"{self._base_url()}/v2.0" + + def get_audience(self) -> str: + if self.settings.audience is None: + raise ValueError( + "Audience is required. Please set it in the configuration." + ) + return self.settings.audience + + def get_client_id(self) -> str: + if self.settings.client_id is None: + raise ValueError( + "Client ID is required. Please set it in the configuration." + ) + return self.settings.client_id + + def get_oauth_scopes(self) -> list[str]: + return [ + *super().get_oauth_scopes(), + *cast(str, self.settings.extra.get("scope", "")).split(), + ] + + def get_required_fields(self) -> list[str]: + return ["scope"] + + def _base_url(self) -> str: + return f"https://login.microsoftonline.com/{self.settings.domain}" diff --git a/lib/crewai/src/crewai/cli/authentication/utils.py b/lib/crewai/src/crewai/cli/authentication/utils.py index 08955092b..7311b9d42 100644 --- a/lib/crewai/src/crewai/cli/authentication/utils.py +++ b/lib/crewai/src/crewai/cli/authentication/utils.py @@ -1,10 +1,12 @@ +from typing import Any + import jwt from jwt import PyJWKClient def validate_jwt_token( jwt_token: str, jwks_url: str, issuer: str, audience: str -) -> dict: +) -> Any: """ Verify the token's signature and claims using PyJWT. :param jwt_token: The JWT (JWS) string to validate. @@ -24,6 +26,7 @@ def validate_jwt_token( _unverified_decoded_token = jwt.decode( jwt_token, options={"verify_signature": False} ) + return jwt.decode( jwt_token, signing_key.key, diff --git a/lib/crewai/src/crewai/cli/tools/main.py b/lib/crewai/src/crewai/cli/tools/main.py index 2705388c5..13fd257fe 100644 --- a/lib/crewai/src/crewai/cli/tools/main.py +++ b/lib/crewai/src/crewai/cli/tools/main.py @@ -162,7 +162,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin): if login_response.status_code != 200: console.print( - "Authentication failed. Verify access to the tool repository, or try `crewai login`. ", + "Authentication failed. Verify if the currently active organization access to the tool repository, and run 'crewai login' again. ", style="bold red", ) raise SystemExit diff --git a/lib/crewai/tests/cli/authentication/providers/test_entra_id.py b/lib/crewai/tests/cli/authentication/providers/test_entra_id.py new file mode 100644 index 000000000..889023955 --- /dev/null +++ b/lib/crewai/tests/cli/authentication/providers/test_entra_id.py @@ -0,0 +1,141 @@ +import pytest + +from crewai.cli.authentication.main import Oauth2Settings +from crewai.cli.authentication.providers.entra_id import EntraIdProvider + + +class TestEntraIdProvider: + @pytest.fixture(autouse=True) + def setup_method(self): + self.valid_settings = Oauth2Settings( + provider="entra_id", + domain="tenant-id-abcdef123456", + client_id="test-client-id", + audience="test-audience", + extra={ + "scope": "openid profile email api://crewai-cli-dev/read" + } + ) + self.provider = EntraIdProvider(self.valid_settings) + + def test_initialization_with_valid_settings(self): + provider = EntraIdProvider(self.valid_settings) + assert provider.settings == self.valid_settings + assert provider.settings.provider == "entra_id" + assert provider.settings.domain == "tenant-id-abcdef123456" + assert provider.settings.client_id == "test-client-id" + assert provider.settings.audience == "test-audience" + + def test_get_authorize_url(self): + expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode" + assert self.provider.get_authorize_url() == expected_url + + def test_get_authorize_url_with_different_domain(self): + # For EntraID, the domain is the tenant ID. + settings = Oauth2Settings( + provider="entra_id", + domain="my-company.entra.id", + client_id="test-client", + audience="test-audience", + ) + provider = EntraIdProvider(settings) + expected_url = "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode" + assert provider.get_authorize_url() == expected_url + + def test_get_token_url(self): + expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token" + assert self.provider.get_token_url() == expected_url + + def test_get_token_url_with_different_domain(self): + # For EntraID, the domain is the tenant ID. + settings = Oauth2Settings( + provider="entra_id", + domain="another-domain.entra.id", + client_id="test-client", + audience="test-audience", + ) + provider = EntraIdProvider(settings) + expected_url = "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token" + assert provider.get_token_url() == expected_url + + def test_get_jwks_url(self): + expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys" + assert self.provider.get_jwks_url() == expected_url + + def test_get_jwks_url_with_different_domain(self): + # For EntraID, the domain is the tenant ID. + settings = Oauth2Settings( + provider="entra_id", + domain="dev.entra.id", + client_id="test-client", + audience="test-audience", + ) + provider = EntraIdProvider(settings) + expected_url = "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys" + assert provider.get_jwks_url() == expected_url + + def test_get_issuer(self): + expected_issuer = "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0" + assert self.provider.get_issuer() == expected_issuer + + def test_get_issuer_with_different_domain(self): + # For EntraID, the domain is the tenant ID. + settings = Oauth2Settings( + provider="entra_id", + domain="other-tenant-id-xpto", + client_id="test-client", + audience="test-audience", + ) + provider = EntraIdProvider(settings) + expected_issuer = "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0" + assert provider.get_issuer() == expected_issuer + + def test_get_audience(self): + assert self.provider.get_audience() == "test-audience" + + def test_get_audience_assertion_error_when_none(self): + settings = Oauth2Settings( + provider="entra_id", + domain="test-tenant-id", + client_id="test-client-id", + audience=None, + ) + provider = EntraIdProvider(settings) + + with pytest.raises(ValueError, match="Audience is required"): + provider.get_audience() + + def test_get_client_id(self): + assert self.provider.get_client_id() == "test-client-id" + + def test_get_required_fields(self): + assert set(self.provider.get_required_fields()) == set(["scope"]) + + def test_get_oauth_scopes(self): + settings = Oauth2Settings( + provider="entra_id", + domain="tenant-id-abcdef123456", + client_id="test-client-id", + audience="test-audience", + extra={ + "scope": "api://crewai-cli-dev/read" + } + ) + provider = EntraIdProvider(settings) + assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read"] + + def test_get_oauth_scopes_with_multiple_custom_scopes(self): + settings = Oauth2Settings( + provider="entra_id", + domain="tenant-id-abcdef123456", + client_id="test-client-id", + audience="test-audience", + extra={ + "scope": "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2" + } + ) + provider = EntraIdProvider(settings) + assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read", "api://crewai-cli-dev/write", "custom-scope1", "custom-scope2"] + + def test_base_url(self): + assert self.provider._base_url() == "https://login.microsoftonline.com/tenant-id-abcdef123456" \ No newline at end of file diff --git a/lib/crewai/tests/cli/authentication/test_auth_main.py b/lib/crewai/tests/cli/authentication/test_auth_main.py index d5d309ca9..5f7308e20 100644 --- a/lib/crewai/tests/cli/authentication/test_auth_main.py +++ b/lib/crewai/tests/cli/authentication/test_auth_main.py @@ -15,6 +15,8 @@ class TestAuthenticationCommand: def setup_method(self): self.auth_command = AuthenticationCommand() + # TODO: these expectations are reading from the actual settings, we should mock them. + # E.g. if you change the client_id locally, this test will fail. @pytest.mark.parametrize( "user_provider,expected_urls", [ @@ -181,7 +183,7 @@ class TestAuthenticationCommand: ), call("Success!\n", style="bold green"), call( - "You are authenticated to the tool repository as [bold cyan]'Test Org'[/bold cyan] (test-uuid-123)", + "You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]", style="green", ), ] @@ -234,6 +236,7 @@ class TestAuthenticationCommand: "https://example.com/device" ) self.auth_command.oauth2_provider.get_audience.return_value = "test_audience" + self.auth_command.oauth2_provider.get_oauth_scopes.return_value = ["openid", "profile", "email"] result = self.auth_command._get_device_code() @@ -241,7 +244,7 @@ class TestAuthenticationCommand: url="https://example.com/device", data={ "client_id": "test_client", - "scope": "openid", + "scope": "openid profile email", "audience": "test_audience", }, timeout=20,