Support Device authorization with Okta (#3271)

* feat: support oauth2 config for authentication

* refactor: improve OAuth2 settings management

The CLI now supports seamless integration with other authentication providers, since the client_id, issue, domain are now manage by the user

* feat: support okta Device Authorization flow

* chore: resolve linter issues

* test: fix tests

* test: adding tests for auth providers

* test: fix broken test

* refator: adding WorkOS paramenters as default settings auth

* chore: improve oauth2 attributes description

* refactor: simplify WorkOS getting values

* fix: ensure Auth0 parameters is set when overrinding default auth provider

* chore: remove TODO Auth0 no longer provides default values

---------

Co-authored-by: Heitor Carvalho <heitor.scz@gmail.com>
This commit is contained in:
Lucas Gomide
2025-08-05 13:16:21 -03:00
committed by GitHub
parent 0b31bbe957
commit 66567bdc2f
15 changed files with 548 additions and 83 deletions

View File

@@ -1,8 +1,6 @@
ALGORITHMS = ["RS256"] ALGORITHMS = ["RS256"]
#TODO: The AUTH0 constants should be removed after WorkOS migration is completed
AUTH0_DOMAIN = "crewai.us.auth0.com" AUTH0_DOMAIN = "crewai.us.auth0.com"
AUTH0_CLIENT_ID = "DEVC5Fw6NlRoSzmDCcOhVq85EfLBjKa8" AUTH0_CLIENT_ID = "DEVC5Fw6NlRoSzmDCcOhVq85EfLBjKa8"
AUTH0_AUDIENCE = "https://crewai.us.auth0.com/api/v2/" AUTH0_AUDIENCE = "https://crewai.us.auth0.com/api/v2/"
WORKOS_DOMAIN = "login.crewai.com"
WORKOS_CLI_CONNECT_APP_ID = "client_01JYT06R59SP0NXYGD994NFXXX"
WORKOS_ENVIRONMENT_ID = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8"

View File

@@ -1,76 +1,92 @@
import time import time
import webbrowser import webbrowser
from typing import Any, Dict from typing import Any, Dict, Optional
import requests import requests
from rich.console import Console from rich.console import Console
from pydantic import BaseModel, Field
from .constants import (
AUTH0_AUDIENCE,
AUTH0_CLIENT_ID,
AUTH0_DOMAIN,
WORKOS_DOMAIN,
WORKOS_CLI_CONNECT_APP_ID,
WORKOS_ENVIRONMENT_ID,
)
from .utils import TokenManager, validate_jwt_token from .utils import TokenManager, validate_jwt_token
from urllib.parse import quote from urllib.parse import quote
from crewai.cli.plus_api import PlusAPI from crewai.cli.plus_api import PlusAPI
from crewai.cli.config import Settings from crewai.cli.config import Settings
from crewai.cli.authentication.constants import (
AUTH0_AUDIENCE,
AUTH0_CLIENT_ID,
AUTH0_DOMAIN,
)
console = Console() console = Console()
class Oauth2Settings(BaseModel):
provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).")
client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.")
domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.")
audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None)
@classmethod
def from_settings(cls):
settings = Settings()
return cls(
provider=settings.oauth2_provider,
domain=settings.oauth2_domain,
client_id=settings.oauth2_client_id,
audience=settings.oauth2_audience,
)
class ProviderFactory:
@classmethod
def from_settings(cls, settings: Optional[Oauth2Settings] = None):
settings = settings or Oauth2Settings.from_settings()
import importlib
module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}")
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
return provider(settings)
class AuthenticationCommand: class AuthenticationCommand:
AUTH0_DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code"
AUTH0_TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token"
WORKOS_DEVICE_CODE_URL = f"https://{WORKOS_DOMAIN}/oauth2/device_authorization"
WORKOS_TOKEN_URL = f"https://{WORKOS_DOMAIN}/oauth2/token"
def __init__(self): def __init__(self):
self.token_manager = TokenManager() self.token_manager = TokenManager()
# TODO: WORKOS - This variable is temporary until migration to WorkOS is complete. self.oauth2_provider = ProviderFactory.from_settings()
self.user_provider = "workos"
def login(self) -> None: def login(self) -> None:
"""Sign up to CrewAI+""" """Sign up to CrewAI+"""
device_code_url = self.WORKOS_DEVICE_CODE_URL
token_url = self.WORKOS_TOKEN_URL
client_id = WORKOS_CLI_CONNECT_APP_ID
audience = None
console.print("Signing in to CrewAI Enterprise...\n", style="bold blue") console.print("Signing in to CrewAI Enterprise...\n", style="bold blue")
# TODO: WORKOS - Next line and conditional are temporary until migration to WorkOS is complete. # TODO: WORKOS - Next line and conditional are temporary until migration to WorkOS is complete.
user_provider = self._determine_user_provider() user_provider = self._determine_user_provider()
if user_provider == "auth0": if user_provider == "auth0":
device_code_url = self.AUTH0_DEVICE_CODE_URL settings = Oauth2Settings(
token_url = self.AUTH0_TOKEN_URL provider="auth0",
client_id = AUTH0_CLIENT_ID client_id=AUTH0_CLIENT_ID,
audience = AUTH0_AUDIENCE domain=AUTH0_DOMAIN,
self.user_provider = "auth0" audience=AUTH0_AUDIENCE
)
self.oauth2_provider = ProviderFactory.from_settings(settings)
# End of temporary code. # End of temporary code.
device_code_data = self._get_device_code(client_id, device_code_url, audience) device_code_data = self._get_device_code()
self._display_auth_instructions(device_code_data) self._display_auth_instructions(device_code_data)
return self._poll_for_token(device_code_data, client_id, token_url) return self._poll_for_token(device_code_data)
def _get_device_code( def _get_device_code(
self, client_id: str, device_code_url: str, audience: str | None = None self
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Get the device code to authenticate the user.""" """Get the device code to authenticate the user."""
device_code_payload = { device_code_payload = {
"client_id": client_id, "client_id": self.oauth2_provider.get_client_id(),
"scope": "openid", "scope": "openid",
"audience": audience, "audience": self.oauth2_provider.get_audience(),
} }
response = requests.post( response = requests.post(
url=device_code_url, data=device_code_payload, timeout=20 url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
@@ -82,21 +98,21 @@ class AuthenticationCommand:
webbrowser.open(device_code_data["verification_uri_complete"]) webbrowser.open(device_code_data["verification_uri_complete"])
def _poll_for_token( def _poll_for_token(
self, device_code_data: Dict[str, Any], client_id: str, token_poll_url: str self, device_code_data: Dict[str, Any]
) -> None: ) -> None:
"""Polls the server for the token until it is received, or max attempts are reached.""" """Polls the server for the token until it is received, or max attempts are reached."""
token_payload = { token_payload = {
"grant_type": "urn:ietf:params:oauth:grant-type:device_code", "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code_data["device_code"], "device_code": device_code_data["device_code"],
"client_id": client_id, "client_id": self.oauth2_provider.get_client_id(),
} }
console.print("\nWaiting for authentication... ", style="bold blue", end="") console.print("\nWaiting for authentication... ", style="bold blue", end="")
attempts = 0 attempts = 0
while True and attempts < 10: while True and attempts < 10:
response = requests.post(token_poll_url, data=token_payload, timeout=30) response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30)
token_data = response.json() token_data = response.json()
if response.status_code == 200: if response.status_code == 200:
@@ -128,19 +144,14 @@ class AuthenticationCommand:
"""Validates the JWT token and saves the token to the token manager.""" """Validates the JWT token and saves the token to the token manager."""
jwt_token = token_data["access_token"] jwt_token = token_data["access_token"]
issuer = self.oauth2_provider.get_issuer()
jwt_token_data = { jwt_token_data = {
"jwt_token": jwt_token, "jwt_token": jwt_token,
"jwks_url": f"https://{WORKOS_DOMAIN}/oauth2/jwks", "jwks_url": self.oauth2_provider.get_jwks_url(),
"issuer": f"https://{WORKOS_DOMAIN}", "issuer": issuer,
"audience": WORKOS_ENVIRONMENT_ID, "audience": self.oauth2_provider.get_audience(),
} }
# TODO: WORKOS - The following conditional is temporary until migration to WorkOS is complete.
if self.user_provider == "auth0":
jwt_token_data["jwks_url"] = f"https://{AUTH0_DOMAIN}/.well-known/jwks.json"
jwt_token_data["issuer"] = f"https://{AUTH0_DOMAIN}/"
jwt_token_data["audience"] = AUTH0_AUDIENCE
decoded_token = validate_jwt_token(**jwt_token_data) decoded_token = validate_jwt_token(**jwt_token_data)
expires_at = decoded_token.get("exp", 0) expires_at = decoded_token.get("exp", 0)

View File

@@ -0,0 +1,26 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
class Auth0Provider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth/device/code"
def get_token_url(self) -> str:
return f"https://{self._get_domain()}/oauth/token"
def get_jwks_url(self) -> str:
return f"https://{self._get_domain()}/.well-known/jwks.json"
def get_issuer(self) -> str:
return f"https://{self._get_domain()}/"
def get_audience(self) -> str:
assert self.settings.audience is not None, "Audience is required"
return self.settings.audience
def get_client_id(self) -> str:
assert self.settings.client_id is not None, "Client ID is required"
return self.settings.client_id
def _get_domain(self) -> str:
assert self.settings.domain is not None, "Domain is required"
return self.settings.domain

View File

@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod
from crewai.cli.authentication.main import Oauth2Settings
class BaseProvider(ABC):
def __init__(self, settings: Oauth2Settings):
self.settings = settings
@abstractmethod
def get_authorize_url(self) -> str:
...
@abstractmethod
def get_token_url(self) -> str:
...
@abstractmethod
def get_jwks_url(self) -> str:
...
@abstractmethod
def get_issuer(self) -> str:
...
@abstractmethod
def get_audience(self) -> str:
...
@abstractmethod
def get_client_id(self) -> str:
...

View File

@@ -0,0 +1,22 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
class OktaProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
def get_token_url(self) -> str:
return f"https://{self.settings.domain}/oauth2/default/v1/token"
def get_jwks_url(self) -> str:
return f"https://{self.settings.domain}/oauth2/default/v1/keys"
def get_issuer(self) -> str:
return f"https://{self.settings.domain}/oauth2/default"
def get_audience(self) -> str:
assert self.settings.audience is not None
return self.settings.audience
def get_client_id(self) -> str:
assert self.settings.client_id is not None
return self.settings.client_id

View File

@@ -0,0 +1,25 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
class WorkosProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/device_authorization"
def get_token_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/token"
def get_jwks_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/jwks"
def get_issuer(self) -> str:
return f"https://{self._get_domain()}"
def get_audience(self) -> str:
return self.settings.audience or ""
def get_client_id(self) -> str:
assert self.settings.client_id is not None, "Client ID is required"
return self.settings.client_id
def _get_domain(self) -> str:
assert self.settings.domain is not None, "Domain is required"
return self.settings.domain

View File

@@ -4,7 +4,13 @@ from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL from crewai.cli.constants import (
DEFAULT_CREWAI_ENTERPRISE_URL,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
)
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json" DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
@@ -19,11 +25,19 @@ USER_SETTINGS_KEYS = [
# Settings that are related to the CLI # Settings that are related to the CLI
CLI_SETTINGS_KEYS = [ CLI_SETTINGS_KEYS = [
"enterprise_base_url", "enterprise_base_url",
"oauth2_provider",
"oauth2_audience",
"oauth2_client_id",
"oauth2_domain",
] ]
# Default values for CLI settings # Default values for CLI settings
DEFAULT_CLI_SETTINGS = { DEFAULT_CLI_SETTINGS = {
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL, "enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
"oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
"oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
"oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
} }
# Readonly settings - cannot be set by the user # Readonly settings - cannot be set by the user
@@ -39,10 +53,9 @@ HIDDEN_SETTINGS_KEYS = [
"tool_repository_password", "tool_repository_password",
] ]
class Settings(BaseModel): class Settings(BaseModel):
enterprise_base_url: Optional[str] = Field( enterprise_base_url: Optional[str] = Field(
default=DEFAULT_CREWAI_ENTERPRISE_URL, default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
description="Base URL of the CrewAI Enterprise instance", description="Base URL of the CrewAI Enterprise instance",
) )
tool_repository_username: Optional[str] = Field( tool_repository_username: Optional[str] = Field(
@@ -59,6 +72,26 @@ class Settings(BaseModel):
) )
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True) config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
oauth2_provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
default=DEFAULT_CLI_SETTINGS["oauth2_provider"]
)
oauth2_audience: Optional[str] = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=DEFAULT_CLI_SETTINGS["oauth2_audience"]
)
oauth2_client_id: str = Field(
default=DEFAULT_CLI_SETTINGS["oauth2_client_id"],
description="OAuth2 client ID issued by the provider, used during authentication requests.",
)
oauth2_domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
default=DEFAULT_CLI_SETTINGS["oauth2_domain"]
)
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data): def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
"""Load Settings from config path""" """Load Settings from config path"""
config_path.parent.mkdir(parents=True, exist_ok=True) config_path.parent.mkdir(parents=True, exist_ok=True)
@@ -105,4 +138,4 @@ class Settings(BaseModel):
def _reset_cli_settings(self) -> None: def _reset_cli_settings(self) -> None:
"""Reset all CLI settings to default values""" """Reset all CLI settings to default values"""
for key in CLI_SETTINGS_KEYS: for key in CLI_SETTINGS_KEYS:
setattr(self, key, DEFAULT_CLI_SETTINGS[key]) setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))

View File

@@ -1,4 +1,8 @@
DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com" DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com"
ENV_VARS = { ENV_VARS = {
"openai": [ "openai": [

View File

@@ -0,0 +1,91 @@
import pytest
from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.auth0 import Auth0Provider
class TestAuth0Provider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="auth0",
domain="test-domain.auth0.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = Auth0Provider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = Auth0Provider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "auth0"
assert provider.settings.domain == "test-domain.auth0.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://test-domain.auth0.com/oauth/device/code"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="my-company.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://my-company.auth0.com/oauth/device/code"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://test-domain.auth0.com/oauth/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="another-domain.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://another-domain.auth0.com/oauth/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://test-domain.auth0.com/.well-known/jwks.json"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="dev.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://dev.auth0.com/.well-known/jwks.json"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://test-domain.auth0.com/"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="prod.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_issuer = "https://prod.auth0.com/"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -0,0 +1,102 @@
import pytest
from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.okta import OktaProvider
class TestOktaProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = OktaProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = OktaProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "okta"
assert provider.settings.domain == "test-domain.okta.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/device/authorize"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="my-company.okta.com",
client_id="test-client",
audience="test-audience"
)
provider = OktaProvider(settings)
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="another-domain.okta.com",
client_id="test-client",
audience="test-audience"
)
provider = OktaProvider(settings)
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="dev.okta.com",
client_id="test-client",
audience="test-audience"
)
provider = OktaProvider(settings)
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://test-domain.okta.com/oauth2/default"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="prod.okta.com",
client_id="test-client",
audience="test-audience"
)
provider = OktaProvider(settings)
expected_issuer = "https://prod.okta.com/oauth2/default"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_assertion_error_when_none(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None
)
provider = OktaProvider(settings)
with pytest.raises(AssertionError):
provider.get_audience()
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -0,0 +1,100 @@
import pytest
from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.workos import WorkosProvider
class TestWorkosProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="workos",
domain="login.company.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = WorkosProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = WorkosProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "workos"
assert provider.settings.domain == "login.company.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://login.company.com/oauth2/device_authorization"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="login.example.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://login.example.com/oauth2/device_authorization"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://login.company.com/oauth2/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="api.workos.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://api.workos.com/oauth2/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://login.company.com/oauth2/jwks"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="auth.enterprise.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://auth.enterprise.com/oauth2/jwks"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://login.company.com"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="sso.company.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_issuer = "https://sso.company.com"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_fallback_to_default(self):
settings = Oauth2Settings(
provider="workos",
domain="login.company.com",
client_id="test-client-id",
audience=None
)
provider = WorkosProvider(settings)
assert provider.get_audience() == ""
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -6,10 +6,12 @@ from crewai.cli.authentication.main import AuthenticationCommand
from crewai.cli.authentication.constants import ( from crewai.cli.authentication.constants import (
AUTH0_AUDIENCE, AUTH0_AUDIENCE,
AUTH0_CLIENT_ID, AUTH0_CLIENT_ID,
AUTH0_DOMAIN, AUTH0_DOMAIN
WORKOS_DOMAIN, )
WORKOS_CLI_CONNECT_APP_ID, from crewai.cli.constants import (
WORKOS_ENVIRONMENT_ID, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
) )
@@ -27,14 +29,17 @@ class TestAuthenticationCommand:
"token_url": f"https://{AUTH0_DOMAIN}/oauth/token", "token_url": f"https://{AUTH0_DOMAIN}/oauth/token",
"client_id": AUTH0_CLIENT_ID, "client_id": AUTH0_CLIENT_ID,
"audience": AUTH0_AUDIENCE, "audience": AUTH0_AUDIENCE,
"domain": AUTH0_DOMAIN,
}, },
), ),
( (
"workos", "workos",
{ {
"device_code_url": f"https://{WORKOS_DOMAIN}/oauth2/device_authorization", "device_code_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/device_authorization",
"token_url": f"https://{WORKOS_DOMAIN}/oauth2/token", "token_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/token",
"client_id": WORKOS_CLI_CONNECT_APP_ID, "client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
"domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
}, },
), ),
], ],
@@ -70,19 +75,16 @@ class TestAuthenticationCommand:
"Signing in to CrewAI Enterprise...\n", style="bold blue" "Signing in to CrewAI Enterprise...\n", style="bold blue"
) )
mock_determine_provider.assert_called_once() mock_determine_provider.assert_called_once()
mock_get_device.assert_called_once_with( mock_get_device.assert_called_once()
expected_urls["client_id"],
expected_urls["device_code_url"],
expected_urls.get("audience", None),
)
mock_display.assert_called_once_with( mock_display.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"} {"device_code": "test_code", "user_code": "123456"}
) )
mock_poll.assert_called_once_with( mock_poll.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"}, {"device_code": "test_code", "user_code": "123456"},
expected_urls["client_id"],
expected_urls["token_url"],
) )
assert self.auth_command.oauth2_provider.get_client_id() == expected_urls["client_id"]
assert self.auth_command.oauth2_provider.get_audience() == expected_urls["audience"]
assert self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
@patch("crewai.cli.authentication.main.webbrowser") @patch("crewai.cli.authentication.main.webbrowser")
@patch("crewai.cli.authentication.main.console.print") @patch("crewai.cli.authentication.main.console.print")
@@ -115,9 +117,9 @@ class TestAuthenticationCommand:
( (
"workos", "workos",
{ {
"jwks_url": f"https://{WORKOS_DOMAIN}/oauth2/jwks", "jwks_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/jwks",
"issuer": f"https://{WORKOS_DOMAIN}", "issuer": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}",
"audience": WORKOS_ENVIRONMENT_ID, "audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
}, },
), ),
], ],
@@ -133,7 +135,15 @@ class TestAuthenticationCommand:
jwt_config, jwt_config,
has_expiration, has_expiration,
): ):
self.auth_command.user_provider = user_provider from crewai.cli.authentication.providers.auth0 import Auth0Provider
from crewai.cli.authentication.providers.workos import WorkosProvider
from crewai.cli.authentication.main import Oauth2Settings
if user_provider == "auth0":
self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=AUTH0_DOMAIN, audience=jwt_config["audience"]))
elif user_provider == "workos":
self.auth_command.oauth2_provider = WorkosProvider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, audience=jwt_config["audience"]))
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"} token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
if has_expiration: if has_expiration:
@@ -311,11 +321,12 @@ class TestAuthenticationCommand:
} }
mock_post.return_value = mock_response mock_post.return_value = mock_response
result = self.auth_command._get_device_code( self.auth_command.oauth2_provider = MagicMock()
client_id="test_client", self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
device_code_url="https://example.com/device", self.auth_command.oauth2_provider.get_authorize_url.return_value = "https://example.com/device"
audience="test_audience", self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
)
result = self.auth_command._get_device_code()
mock_post.assert_called_once_with( mock_post.assert_called_once_with(
url="https://example.com/device", url="https://example.com/device",
@@ -354,8 +365,12 @@ class TestAuthenticationCommand:
self.auth_command, "_login_to_tool_repository" self.auth_command, "_login_to_tool_repository"
) as mock_tool_login, ) as mock_tool_login,
): ):
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_token_url.return_value = "https://example.com/token"
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command._poll_for_token( self.auth_command._poll_for_token(
device_code_data, "test_client", "https://example.com/token" device_code_data
) )
mock_post.assert_called_once_with( mock_post.assert_called_once_with(
@@ -392,7 +407,7 @@ class TestAuthenticationCommand:
} }
self.auth_command._poll_for_token( self.auth_command._poll_for_token(
device_code_data, "test_client", "https://example.com/token" device_code_data
) )
mock_console_print.assert_any_call( mock_console_print.assert_any_call(
@@ -415,5 +430,14 @@ class TestAuthenticationCommand:
with pytest.raises(requests.HTTPError): with pytest.raises(requests.HTTPError):
self.auth_command._poll_for_token( self.auth_command._poll_for_token(
device_code_data, "test_client", "https://example.com/token" device_code_data
) )
# @patch(
# "crewai.cli.authentication.main.AuthenticationCommand._determine_user_provider"
# )
# def test_login_with_auth0(self, mock_determine_provider):
# from crewai.cli.authentication.providers.auth0 import Auth0Provider
# from crewai.cli.authentication.main import Oauth2Settings
# self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider="auth0", client_id=AUTH0_CLIENT_ID, domain=AUTH0_DOMAIN, audience=AUTH0_AUDIENCE))
# self.auth_command.login()

View File

@@ -79,7 +79,7 @@ class TestSettings(unittest.TestCase):
for key in user_settings.keys(): for key in user_settings.keys():
self.assertEqual(getattr(settings, key), None) self.assertEqual(getattr(settings, key), None)
for key in cli_settings.keys(): for key in cli_settings.keys():
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS[key]) self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key))
def test_dump_new_settings(self): def test_dump_new_settings(self):
settings = Settings( settings = Settings(

View File

@@ -81,11 +81,10 @@ class TestSettingsCommand(unittest.TestCase):
self.settings_command.reset_all_settings() self.settings_command.reset_all_settings()
print(USER_SETTINGS_KEYS)
for key in USER_SETTINGS_KEYS: for key in USER_SETTINGS_KEYS:
self.assertEqual(getattr(self.settings_command.settings, key), None) self.assertEqual(getattr(self.settings_command.settings, key), None)
for key in CLI_SETTINGS_KEYS: for key in CLI_SETTINGS_KEYS:
self.assertEqual( self.assertEqual(
getattr(self.settings_command.settings, key), DEFAULT_CLI_SETTINGS[key] getattr(self.settings_command.settings, key), DEFAULT_CLI_SETTINGS.get(key)
) )