diff --git a/src/crewai/cli/authentication/constants.py b/src/crewai/cli/authentication/constants.py index 0616d5407..22a212822 100644 --- a/src/crewai/cli/authentication/constants.py +++ b/src/crewai/cli/authentication/constants.py @@ -1,8 +1,6 @@ ALGORITHMS = ["RS256"] + +#TODO: The AUTH0 constants should be removed after WorkOS migration is completed AUTH0_DOMAIN = "crewai.us.auth0.com" AUTH0_CLIENT_ID = "DEVC5Fw6NlRoSzmDCcOhVq85EfLBjKa8" AUTH0_AUDIENCE = "https://crewai.us.auth0.com/api/v2/" - -WORKOS_DOMAIN = "login.crewai.com" -WORKOS_CLI_CONNECT_APP_ID = "client_01JYT06R59SP0NXYGD994NFXXX" -WORKOS_ENVIRONMENT_ID = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8" diff --git a/src/crewai/cli/authentication/main.py b/src/crewai/cli/authentication/main.py index 303c7a1fe..26a354aea 100644 --- a/src/crewai/cli/authentication/main.py +++ b/src/crewai/cli/authentication/main.py @@ -1,76 +1,92 @@ import time import webbrowser -from typing import Any, Dict +from typing import Any, Dict, Optional import requests from rich.console import Console +from pydantic import BaseModel, Field -from .constants import ( - AUTH0_AUDIENCE, - AUTH0_CLIENT_ID, - AUTH0_DOMAIN, - WORKOS_DOMAIN, - WORKOS_CLI_CONNECT_APP_ID, - WORKOS_ENVIRONMENT_ID, -) from .utils import TokenManager, validate_jwt_token from urllib.parse import quote from crewai.cli.plus_api import PlusAPI from crewai.cli.config import Settings +from crewai.cli.authentication.constants import ( + AUTH0_AUDIENCE, + AUTH0_CLIENT_ID, + AUTH0_DOMAIN, +) console = Console() +class Oauth2Settings(BaseModel): + provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).") + client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.") + domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.") + audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None) + + @classmethod + def from_settings(cls): + settings = Settings() + + return cls( + provider=settings.oauth2_provider, + domain=settings.oauth2_domain, + client_id=settings.oauth2_client_id, + audience=settings.oauth2_audience, + ) + + +class ProviderFactory: + @classmethod + def from_settings(cls, settings: Optional[Oauth2Settings] = None): + settings = settings or Oauth2Settings.from_settings() + + import importlib + module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}") + provider = getattr(module, f"{settings.provider.capitalize()}Provider") + + return provider(settings) + class AuthenticationCommand: - AUTH0_DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code" - AUTH0_TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token" - - WORKOS_DEVICE_CODE_URL = f"https://{WORKOS_DOMAIN}/oauth2/device_authorization" - WORKOS_TOKEN_URL = f"https://{WORKOS_DOMAIN}/oauth2/token" - def __init__(self): self.token_manager = TokenManager() - # TODO: WORKOS - This variable is temporary until migration to WorkOS is complete. - self.user_provider = "workos" + self.oauth2_provider = ProviderFactory.from_settings() def login(self) -> None: """Sign up to CrewAI+""" - - device_code_url = self.WORKOS_DEVICE_CODE_URL - token_url = self.WORKOS_TOKEN_URL - client_id = WORKOS_CLI_CONNECT_APP_ID - audience = None - console.print("Signing in to CrewAI Enterprise...\n", style="bold blue") # TODO: WORKOS - Next line and conditional are temporary until migration to WorkOS is complete. user_provider = self._determine_user_provider() if user_provider == "auth0": - device_code_url = self.AUTH0_DEVICE_CODE_URL - token_url = self.AUTH0_TOKEN_URL - client_id = AUTH0_CLIENT_ID - audience = AUTH0_AUDIENCE - self.user_provider = "auth0" + settings = Oauth2Settings( + provider="auth0", + client_id=AUTH0_CLIENT_ID, + domain=AUTH0_DOMAIN, + audience=AUTH0_AUDIENCE + ) + self.oauth2_provider = ProviderFactory.from_settings(settings) # End of temporary code. - device_code_data = self._get_device_code(client_id, device_code_url, audience) + device_code_data = self._get_device_code() self._display_auth_instructions(device_code_data) - return self._poll_for_token(device_code_data, client_id, token_url) + return self._poll_for_token(device_code_data) def _get_device_code( - self, client_id: str, device_code_url: str, audience: str | None = None + self ) -> Dict[str, Any]: """Get the device code to authenticate the user.""" device_code_payload = { - "client_id": client_id, + "client_id": self.oauth2_provider.get_client_id(), "scope": "openid", - "audience": audience, + "audience": self.oauth2_provider.get_audience(), } response = requests.post( - url=device_code_url, data=device_code_payload, timeout=20 + url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20 ) response.raise_for_status() return response.json() @@ -82,21 +98,21 @@ class AuthenticationCommand: webbrowser.open(device_code_data["verification_uri_complete"]) def _poll_for_token( - self, device_code_data: Dict[str, Any], client_id: str, token_poll_url: str + self, device_code_data: Dict[str, Any] ) -> None: """Polls the server for the token until it is received, or max attempts are reached.""" token_payload = { "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": device_code_data["device_code"], - "client_id": client_id, + "client_id": self.oauth2_provider.get_client_id(), } console.print("\nWaiting for authentication... ", style="bold blue", end="") attempts = 0 while True and attempts < 10: - response = requests.post(token_poll_url, data=token_payload, timeout=30) + response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30) token_data = response.json() if response.status_code == 200: @@ -128,19 +144,14 @@ class AuthenticationCommand: """Validates the JWT token and saves the token to the token manager.""" jwt_token = token_data["access_token"] + issuer = self.oauth2_provider.get_issuer() jwt_token_data = { "jwt_token": jwt_token, - "jwks_url": f"https://{WORKOS_DOMAIN}/oauth2/jwks", - "issuer": f"https://{WORKOS_DOMAIN}", - "audience": WORKOS_ENVIRONMENT_ID, + "jwks_url": self.oauth2_provider.get_jwks_url(), + "issuer": issuer, + "audience": self.oauth2_provider.get_audience(), } - # TODO: WORKOS - The following conditional is temporary until migration to WorkOS is complete. - if self.user_provider == "auth0": - jwt_token_data["jwks_url"] = f"https://{AUTH0_DOMAIN}/.well-known/jwks.json" - jwt_token_data["issuer"] = f"https://{AUTH0_DOMAIN}/" - jwt_token_data["audience"] = AUTH0_AUDIENCE - decoded_token = validate_jwt_token(**jwt_token_data) expires_at = decoded_token.get("exp", 0) diff --git a/src/crewai/cli/authentication/providers/auth0.py b/src/crewai/cli/authentication/providers/auth0.py new file mode 100644 index 000000000..8538550db --- /dev/null +++ b/src/crewai/cli/authentication/providers/auth0.py @@ -0,0 +1,26 @@ +from crewai.cli.authentication.providers.base_provider import BaseProvider + +class Auth0Provider(BaseProvider): + def get_authorize_url(self) -> str: + return f"https://{self._get_domain()}/oauth/device/code" + + def get_token_url(self) -> str: + return f"https://{self._get_domain()}/oauth/token" + + def get_jwks_url(self) -> str: + return f"https://{self._get_domain()}/.well-known/jwks.json" + + def get_issuer(self) -> str: + return f"https://{self._get_domain()}/" + + def get_audience(self) -> str: + assert self.settings.audience is not None, "Audience is required" + return self.settings.audience + + def get_client_id(self) -> str: + assert self.settings.client_id is not None, "Client ID is required" + return self.settings.client_id + + def _get_domain(self) -> str: + assert self.settings.domain is not None, "Domain is required" + return self.settings.domain diff --git a/src/crewai/cli/authentication/providers/base_provider.py b/src/crewai/cli/authentication/providers/base_provider.py new file mode 100644 index 000000000..c321de9f7 --- /dev/null +++ b/src/crewai/cli/authentication/providers/base_provider.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod +from crewai.cli.authentication.main import Oauth2Settings + +class BaseProvider(ABC): + def __init__(self, settings: Oauth2Settings): + self.settings = settings + + @abstractmethod + def get_authorize_url(self) -> str: + ... + + @abstractmethod + def get_token_url(self) -> str: + ... + + @abstractmethod + def get_jwks_url(self) -> str: + ... + + @abstractmethod + def get_issuer(self) -> str: + ... + + @abstractmethod + def get_audience(self) -> str: + ... + + @abstractmethod + def get_client_id(self) -> str: + ... diff --git a/src/crewai/cli/authentication/providers/okta.py b/src/crewai/cli/authentication/providers/okta.py new file mode 100644 index 000000000..14227ae2b --- /dev/null +++ b/src/crewai/cli/authentication/providers/okta.py @@ -0,0 +1,22 @@ +from crewai.cli.authentication.providers.base_provider import BaseProvider + +class OktaProvider(BaseProvider): + def get_authorize_url(self) -> str: + return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize" + + def get_token_url(self) -> str: + return f"https://{self.settings.domain}/oauth2/default/v1/token" + + def get_jwks_url(self) -> str: + return f"https://{self.settings.domain}/oauth2/default/v1/keys" + + def get_issuer(self) -> str: + return f"https://{self.settings.domain}/oauth2/default" + + def get_audience(self) -> str: + assert self.settings.audience is not None + return self.settings.audience + + def get_client_id(self) -> str: + assert self.settings.client_id is not None + return self.settings.client_id diff --git a/src/crewai/cli/authentication/providers/workos.py b/src/crewai/cli/authentication/providers/workos.py new file mode 100644 index 000000000..8cf475a4d --- /dev/null +++ b/src/crewai/cli/authentication/providers/workos.py @@ -0,0 +1,25 @@ +from crewai.cli.authentication.providers.base_provider import BaseProvider + +class WorkosProvider(BaseProvider): + def get_authorize_url(self) -> str: + return f"https://{self._get_domain()}/oauth2/device_authorization" + + def get_token_url(self) -> str: + return f"https://{self._get_domain()}/oauth2/token" + + def get_jwks_url(self) -> str: + return f"https://{self._get_domain()}/oauth2/jwks" + + def get_issuer(self) -> str: + return f"https://{self._get_domain()}" + + def get_audience(self) -> str: + return self.settings.audience or "" + + def get_client_id(self) -> str: + assert self.settings.client_id is not None, "Client ID is required" + return self.settings.client_id + + def _get_domain(self) -> str: + assert self.settings.domain is not None, "Domain is required" + return self.settings.domain diff --git a/src/crewai/cli/config.py b/src/crewai/cli/config.py index f2a87792e..a8da400a8 100644 --- a/src/crewai/cli/config.py +++ b/src/crewai/cli/config.py @@ -4,7 +4,13 @@ from typing import Optional from pydantic import BaseModel, Field -from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL +from crewai.cli.constants import ( + DEFAULT_CREWAI_ENTERPRISE_URL, + CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER, + CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, + CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, + CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, +) DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json" @@ -19,11 +25,19 @@ USER_SETTINGS_KEYS = [ # Settings that are related to the CLI CLI_SETTINGS_KEYS = [ "enterprise_base_url", + "oauth2_provider", + "oauth2_audience", + "oauth2_client_id", + "oauth2_domain", ] # Default values for CLI settings DEFAULT_CLI_SETTINGS = { "enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL, + "oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER, + "oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, + "oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, + "oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, } # Readonly settings - cannot be set by the user @@ -39,10 +53,9 @@ HIDDEN_SETTINGS_KEYS = [ "tool_repository_password", ] - class Settings(BaseModel): enterprise_base_url: Optional[str] = Field( - default=DEFAULT_CREWAI_ENTERPRISE_URL, + default=DEFAULT_CLI_SETTINGS["enterprise_base_url"], description="Base URL of the CrewAI Enterprise instance", ) tool_repository_username: Optional[str] = Field( @@ -59,6 +72,26 @@ class Settings(BaseModel): ) config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True) + oauth2_provider: str = Field( + description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).", + default=DEFAULT_CLI_SETTINGS["oauth2_provider"] + ) + + oauth2_audience: Optional[str] = Field( + description="OAuth2 audience value, typically used to identify the target API or resource.", + default=DEFAULT_CLI_SETTINGS["oauth2_audience"] + ) + + oauth2_client_id: str = Field( + default=DEFAULT_CLI_SETTINGS["oauth2_client_id"], + description="OAuth2 client ID issued by the provider, used during authentication requests.", + ) + + oauth2_domain: str = Field( + description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.", + default=DEFAULT_CLI_SETTINGS["oauth2_domain"] + ) + def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data): """Load Settings from config path""" config_path.parent.mkdir(parents=True, exist_ok=True) @@ -105,4 +138,4 @@ class Settings(BaseModel): def _reset_cli_settings(self) -> None: """Reset all CLI settings to default values""" for key in CLI_SETTINGS_KEYS: - setattr(self, key, DEFAULT_CLI_SETTINGS[key]) + setattr(self, key, DEFAULT_CLI_SETTINGS.get(key)) diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py index 06a02bee5..d0e867c41 100644 --- a/src/crewai/cli/constants.py +++ b/src/crewai/cli/constants.py @@ -1,4 +1,8 @@ DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com" +CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos" +CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8" +CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX" +CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com" ENV_VARS = { "openai": [ diff --git a/tests/cli/authentication/providers/__init__.py b/tests/cli/authentication/providers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/cli/authentication/providers/test_auth0.py b/tests/cli/authentication/providers/test_auth0.py new file mode 100644 index 000000000..e513a1fb7 --- /dev/null +++ b/tests/cli/authentication/providers/test_auth0.py @@ -0,0 +1,91 @@ +import pytest +from crewai.cli.authentication.main import Oauth2Settings +from crewai.cli.authentication.providers.auth0 import Auth0Provider + + + +class TestAuth0Provider: + + @pytest.fixture(autouse=True) + def setup_method(self): + self.valid_settings = Oauth2Settings( + provider="auth0", + domain="test-domain.auth0.com", + client_id="test-client-id", + audience="test-audience" + ) + self.provider = Auth0Provider(self.valid_settings) + + def test_initialization_with_valid_settings(self): + provider = Auth0Provider(self.valid_settings) + assert provider.settings == self.valid_settings + assert provider.settings.provider == "auth0" + assert provider.settings.domain == "test-domain.auth0.com" + assert provider.settings.client_id == "test-client-id" + assert provider.settings.audience == "test-audience" + + def test_get_authorize_url(self): + expected_url = "https://test-domain.auth0.com/oauth/device/code" + assert self.provider.get_authorize_url() == expected_url + + def test_get_authorize_url_with_different_domain(self): + settings = Oauth2Settings( + provider="auth0", + domain="my-company.auth0.com", + client_id="test-client", + audience="test-audience" + ) + provider = Auth0Provider(settings) + expected_url = "https://my-company.auth0.com/oauth/device/code" + assert provider.get_authorize_url() == expected_url + + def test_get_token_url(self): + expected_url = "https://test-domain.auth0.com/oauth/token" + assert self.provider.get_token_url() == expected_url + + def test_get_token_url_with_different_domain(self): + settings = Oauth2Settings( + provider="auth0", + domain="another-domain.auth0.com", + client_id="test-client", + audience="test-audience" + ) + provider = Auth0Provider(settings) + expected_url = "https://another-domain.auth0.com/oauth/token" + assert provider.get_token_url() == expected_url + + def test_get_jwks_url(self): + expected_url = "https://test-domain.auth0.com/.well-known/jwks.json" + assert self.provider.get_jwks_url() == expected_url + + def test_get_jwks_url_with_different_domain(self): + settings = Oauth2Settings( + provider="auth0", + domain="dev.auth0.com", + client_id="test-client", + audience="test-audience" + ) + provider = Auth0Provider(settings) + expected_url = "https://dev.auth0.com/.well-known/jwks.json" + assert provider.get_jwks_url() == expected_url + + def test_get_issuer(self): + expected_issuer = "https://test-domain.auth0.com/" + assert self.provider.get_issuer() == expected_issuer + + def test_get_issuer_with_different_domain(self): + settings = Oauth2Settings( + provider="auth0", + domain="prod.auth0.com", + client_id="test-client", + audience="test-audience" + ) + provider = Auth0Provider(settings) + expected_issuer = "https://prod.auth0.com/" + assert provider.get_issuer() == expected_issuer + + def test_get_audience(self): + assert self.provider.get_audience() == "test-audience" + + def test_get_client_id(self): + assert self.provider.get_client_id() == "test-client-id" diff --git a/tests/cli/authentication/providers/test_okta.py b/tests/cli/authentication/providers/test_okta.py new file mode 100644 index 000000000..b952464ba --- /dev/null +++ b/tests/cli/authentication/providers/test_okta.py @@ -0,0 +1,102 @@ +import pytest +from crewai.cli.authentication.main import Oauth2Settings +from crewai.cli.authentication.providers.okta import OktaProvider + + +class TestOktaProvider: + + @pytest.fixture(autouse=True) + def setup_method(self): + self.valid_settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience="test-audience" + ) + self.provider = OktaProvider(self.valid_settings) + + def test_initialization_with_valid_settings(self): + provider = OktaProvider(self.valid_settings) + assert provider.settings == self.valid_settings + assert provider.settings.provider == "okta" + assert provider.settings.domain == "test-domain.okta.com" + assert provider.settings.client_id == "test-client-id" + assert provider.settings.audience == "test-audience" + + def test_get_authorize_url(self): + expected_url = "https://test-domain.okta.com/oauth2/default/v1/device/authorize" + assert self.provider.get_authorize_url() == expected_url + + def test_get_authorize_url_with_different_domain(self): + settings = Oauth2Settings( + provider="okta", + domain="my-company.okta.com", + client_id="test-client", + audience="test-audience" + ) + provider = OktaProvider(settings) + expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize" + assert provider.get_authorize_url() == expected_url + + def test_get_token_url(self): + expected_url = "https://test-domain.okta.com/oauth2/default/v1/token" + assert self.provider.get_token_url() == expected_url + + def test_get_token_url_with_different_domain(self): + settings = Oauth2Settings( + provider="okta", + domain="another-domain.okta.com", + client_id="test-client", + audience="test-audience" + ) + provider = OktaProvider(settings) + expected_url = "https://another-domain.okta.com/oauth2/default/v1/token" + assert provider.get_token_url() == expected_url + + def test_get_jwks_url(self): + expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys" + assert self.provider.get_jwks_url() == expected_url + + def test_get_jwks_url_with_different_domain(self): + settings = Oauth2Settings( + provider="okta", + domain="dev.okta.com", + client_id="test-client", + audience="test-audience" + ) + provider = OktaProvider(settings) + expected_url = "https://dev.okta.com/oauth2/default/v1/keys" + assert provider.get_jwks_url() == expected_url + + def test_get_issuer(self): + expected_issuer = "https://test-domain.okta.com/oauth2/default" + assert self.provider.get_issuer() == expected_issuer + + def test_get_issuer_with_different_domain(self): + settings = Oauth2Settings( + provider="okta", + domain="prod.okta.com", + client_id="test-client", + audience="test-audience" + ) + provider = OktaProvider(settings) + expected_issuer = "https://prod.okta.com/oauth2/default" + assert provider.get_issuer() == expected_issuer + + def test_get_audience(self): + assert self.provider.get_audience() == "test-audience" + + def test_get_audience_assertion_error_when_none(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None + ) + provider = OktaProvider(settings) + + with pytest.raises(AssertionError): + provider.get_audience() + + def test_get_client_id(self): + assert self.provider.get_client_id() == "test-client-id" diff --git a/tests/cli/authentication/providers/test_workos.py b/tests/cli/authentication/providers/test_workos.py new file mode 100644 index 000000000..7eda774d6 --- /dev/null +++ b/tests/cli/authentication/providers/test_workos.py @@ -0,0 +1,100 @@ +import pytest +from crewai.cli.authentication.main import Oauth2Settings +from crewai.cli.authentication.providers.workos import WorkosProvider + + +class TestWorkosProvider: + + @pytest.fixture(autouse=True) + def setup_method(self): + self.valid_settings = Oauth2Settings( + provider="workos", + domain="login.company.com", + client_id="test-client-id", + audience="test-audience" + ) + self.provider = WorkosProvider(self.valid_settings) + + def test_initialization_with_valid_settings(self): + provider = WorkosProvider(self.valid_settings) + assert provider.settings == self.valid_settings + assert provider.settings.provider == "workos" + assert provider.settings.domain == "login.company.com" + assert provider.settings.client_id == "test-client-id" + assert provider.settings.audience == "test-audience" + + def test_get_authorize_url(self): + expected_url = "https://login.company.com/oauth2/device_authorization" + assert self.provider.get_authorize_url() == expected_url + + def test_get_authorize_url_with_different_domain(self): + settings = Oauth2Settings( + provider="workos", + domain="login.example.com", + client_id="test-client", + audience="test-audience" + ) + provider = WorkosProvider(settings) + expected_url = "https://login.example.com/oauth2/device_authorization" + assert provider.get_authorize_url() == expected_url + + def test_get_token_url(self): + expected_url = "https://login.company.com/oauth2/token" + assert self.provider.get_token_url() == expected_url + + def test_get_token_url_with_different_domain(self): + settings = Oauth2Settings( + provider="workos", + domain="api.workos.com", + client_id="test-client", + audience="test-audience" + ) + provider = WorkosProvider(settings) + expected_url = "https://api.workos.com/oauth2/token" + assert provider.get_token_url() == expected_url + + def test_get_jwks_url(self): + expected_url = "https://login.company.com/oauth2/jwks" + assert self.provider.get_jwks_url() == expected_url + + def test_get_jwks_url_with_different_domain(self): + settings = Oauth2Settings( + provider="workos", + domain="auth.enterprise.com", + client_id="test-client", + audience="test-audience" + ) + provider = WorkosProvider(settings) + expected_url = "https://auth.enterprise.com/oauth2/jwks" + assert provider.get_jwks_url() == expected_url + + def test_get_issuer(self): + expected_issuer = "https://login.company.com" + assert self.provider.get_issuer() == expected_issuer + + def test_get_issuer_with_different_domain(self): + settings = Oauth2Settings( + provider="workos", + domain="sso.company.com", + client_id="test-client", + audience="test-audience" + ) + provider = WorkosProvider(settings) + expected_issuer = "https://sso.company.com" + assert provider.get_issuer() == expected_issuer + + def test_get_audience(self): + assert self.provider.get_audience() == "test-audience" + + def test_get_audience_fallback_to_default(self): + settings = Oauth2Settings( + provider="workos", + domain="login.company.com", + client_id="test-client-id", + audience=None + ) + provider = WorkosProvider(settings) + assert provider.get_audience() == "" + + def test_get_client_id(self): + assert self.provider.get_client_id() == "test-client-id" diff --git a/tests/cli/authentication/test_auth_main.py b/tests/cli/authentication/test_auth_main.py index 61511b5a1..d608c9ba4 100644 --- a/tests/cli/authentication/test_auth_main.py +++ b/tests/cli/authentication/test_auth_main.py @@ -6,10 +6,12 @@ from crewai.cli.authentication.main import AuthenticationCommand from crewai.cli.authentication.constants import ( AUTH0_AUDIENCE, AUTH0_CLIENT_ID, - AUTH0_DOMAIN, - WORKOS_DOMAIN, - WORKOS_CLI_CONNECT_APP_ID, - WORKOS_ENVIRONMENT_ID, + AUTH0_DOMAIN +) +from crewai.cli.constants import ( + CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, + CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, + CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, ) @@ -27,14 +29,17 @@ class TestAuthenticationCommand: "token_url": f"https://{AUTH0_DOMAIN}/oauth/token", "client_id": AUTH0_CLIENT_ID, "audience": AUTH0_AUDIENCE, + "domain": AUTH0_DOMAIN, }, ), ( "workos", { - "device_code_url": f"https://{WORKOS_DOMAIN}/oauth2/device_authorization", - "token_url": f"https://{WORKOS_DOMAIN}/oauth2/token", - "client_id": WORKOS_CLI_CONNECT_APP_ID, + "device_code_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/device_authorization", + "token_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/token", + "client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, + "audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, + "domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, }, ), ], @@ -70,19 +75,16 @@ class TestAuthenticationCommand: "Signing in to CrewAI Enterprise...\n", style="bold blue" ) mock_determine_provider.assert_called_once() - mock_get_device.assert_called_once_with( - expected_urls["client_id"], - expected_urls["device_code_url"], - expected_urls.get("audience", None), - ) + mock_get_device.assert_called_once() mock_display.assert_called_once_with( {"device_code": "test_code", "user_code": "123456"} ) mock_poll.assert_called_once_with( {"device_code": "test_code", "user_code": "123456"}, - expected_urls["client_id"], - expected_urls["token_url"], ) + assert self.auth_command.oauth2_provider.get_client_id() == expected_urls["client_id"] + assert self.auth_command.oauth2_provider.get_audience() == expected_urls["audience"] + assert self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"] @patch("crewai.cli.authentication.main.webbrowser") @patch("crewai.cli.authentication.main.console.print") @@ -115,9 +117,9 @@ class TestAuthenticationCommand: ( "workos", { - "jwks_url": f"https://{WORKOS_DOMAIN}/oauth2/jwks", - "issuer": f"https://{WORKOS_DOMAIN}", - "audience": WORKOS_ENVIRONMENT_ID, + "jwks_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/jwks", + "issuer": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}", + "audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, }, ), ], @@ -133,7 +135,15 @@ class TestAuthenticationCommand: jwt_config, has_expiration, ): - self.auth_command.user_provider = user_provider + from crewai.cli.authentication.providers.auth0 import Auth0Provider + from crewai.cli.authentication.providers.workos import WorkosProvider + from crewai.cli.authentication.main import Oauth2Settings + + if user_provider == "auth0": + self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=AUTH0_DOMAIN, audience=jwt_config["audience"])) + elif user_provider == "workos": + self.auth_command.oauth2_provider = WorkosProvider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, audience=jwt_config["audience"])) + token_data = {"access_token": "test_access_token", "id_token": "test_id_token"} if has_expiration: @@ -311,11 +321,12 @@ class TestAuthenticationCommand: } mock_post.return_value = mock_response - result = self.auth_command._get_device_code( - client_id="test_client", - device_code_url="https://example.com/device", - audience="test_audience", - ) + self.auth_command.oauth2_provider = MagicMock() + self.auth_command.oauth2_provider.get_client_id.return_value = "test_client" + self.auth_command.oauth2_provider.get_authorize_url.return_value = "https://example.com/device" + self.auth_command.oauth2_provider.get_audience.return_value = "test_audience" + + result = self.auth_command._get_device_code() mock_post.assert_called_once_with( url="https://example.com/device", @@ -354,8 +365,12 @@ class TestAuthenticationCommand: self.auth_command, "_login_to_tool_repository" ) as mock_tool_login, ): + self.auth_command.oauth2_provider = MagicMock() + self.auth_command.oauth2_provider.get_token_url.return_value = "https://example.com/token" + self.auth_command.oauth2_provider.get_client_id.return_value = "test_client" + self.auth_command._poll_for_token( - device_code_data, "test_client", "https://example.com/token" + device_code_data ) mock_post.assert_called_once_with( @@ -392,7 +407,7 @@ class TestAuthenticationCommand: } self.auth_command._poll_for_token( - device_code_data, "test_client", "https://example.com/token" + device_code_data ) mock_console_print.assert_any_call( @@ -415,5 +430,14 @@ class TestAuthenticationCommand: with pytest.raises(requests.HTTPError): self.auth_command._poll_for_token( - device_code_data, "test_client", "https://example.com/token" + device_code_data ) + # @patch( + # "crewai.cli.authentication.main.AuthenticationCommand._determine_user_provider" + # ) + # def test_login_with_auth0(self, mock_determine_provider): + # from crewai.cli.authentication.providers.auth0 import Auth0Provider + # from crewai.cli.authentication.main import Oauth2Settings + + # self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider="auth0", client_id=AUTH0_CLIENT_ID, domain=AUTH0_DOMAIN, audience=AUTH0_AUDIENCE)) + # self.auth_command.login() diff --git a/tests/cli/config_test.py b/tests/cli/config_test.py index 06cbfcf2c..a492da54a 100644 --- a/tests/cli/config_test.py +++ b/tests/cli/config_test.py @@ -79,7 +79,7 @@ class TestSettings(unittest.TestCase): for key in user_settings.keys(): self.assertEqual(getattr(settings, key), None) for key in cli_settings.keys(): - self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS[key]) + self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key)) def test_dump_new_settings(self): settings = Settings( diff --git a/tests/cli/test_settings_command.py b/tests/cli/test_settings_command.py index 71d016a52..f15deb821 100644 --- a/tests/cli/test_settings_command.py +++ b/tests/cli/test_settings_command.py @@ -81,11 +81,10 @@ class TestSettingsCommand(unittest.TestCase): self.settings_command.reset_all_settings() - print(USER_SETTINGS_KEYS) for key in USER_SETTINGS_KEYS: self.assertEqual(getattr(self.settings_command.settings, key), None) for key in CLI_SETTINGS_KEYS: self.assertEqual( - getattr(self.settings_command.settings, key), DEFAULT_CLI_SETTINGS[key] + getattr(self.settings_command.settings, key), DEFAULT_CLI_SETTINGS.get(key) )