Support Device authorization with Okta (#3271)

* feat: support oauth2 config for authentication

* refactor: improve OAuth2 settings management

The CLI now supports seamless integration with other authentication providers, since the client_id, issue, domain are now manage by the user

* feat: support okta Device Authorization flow

* chore: resolve linter issues

* test: fix tests

* test: adding tests for auth providers

* test: fix broken test

* refator: adding WorkOS paramenters as default settings auth

* chore: improve oauth2 attributes description

* refactor: simplify WorkOS getting values

* fix: ensure Auth0 parameters is set when overrinding default auth provider

* chore: remove TODO Auth0 no longer provides default values

---------

Co-authored-by: Heitor Carvalho <heitor.scz@gmail.com>
This commit is contained in:
Lucas Gomide
2025-08-05 13:16:21 -03:00
committed by GitHub
parent 0b31bbe957
commit 66567bdc2f
15 changed files with 548 additions and 83 deletions

View File

@@ -1,8 +1,6 @@
ALGORITHMS = ["RS256"]
#TODO: The AUTH0 constants should be removed after WorkOS migration is completed
AUTH0_DOMAIN = "crewai.us.auth0.com"
AUTH0_CLIENT_ID = "DEVC5Fw6NlRoSzmDCcOhVq85EfLBjKa8"
AUTH0_AUDIENCE = "https://crewai.us.auth0.com/api/v2/"
WORKOS_DOMAIN = "login.crewai.com"
WORKOS_CLI_CONNECT_APP_ID = "client_01JYT06R59SP0NXYGD994NFXXX"
WORKOS_ENVIRONMENT_ID = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8"

View File

@@ -1,76 +1,92 @@
import time
import webbrowser
from typing import Any, Dict
from typing import Any, Dict, Optional
import requests
from rich.console import Console
from pydantic import BaseModel, Field
from .constants import (
AUTH0_AUDIENCE,
AUTH0_CLIENT_ID,
AUTH0_DOMAIN,
WORKOS_DOMAIN,
WORKOS_CLI_CONNECT_APP_ID,
WORKOS_ENVIRONMENT_ID,
)
from .utils import TokenManager, validate_jwt_token
from urllib.parse import quote
from crewai.cli.plus_api import PlusAPI
from crewai.cli.config import Settings
from crewai.cli.authentication.constants import (
AUTH0_AUDIENCE,
AUTH0_CLIENT_ID,
AUTH0_DOMAIN,
)
console = Console()
class Oauth2Settings(BaseModel):
provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).")
client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.")
domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.")
audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None)
@classmethod
def from_settings(cls):
settings = Settings()
return cls(
provider=settings.oauth2_provider,
domain=settings.oauth2_domain,
client_id=settings.oauth2_client_id,
audience=settings.oauth2_audience,
)
class ProviderFactory:
@classmethod
def from_settings(cls, settings: Optional[Oauth2Settings] = None):
settings = settings or Oauth2Settings.from_settings()
import importlib
module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}")
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
return provider(settings)
class AuthenticationCommand:
AUTH0_DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code"
AUTH0_TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token"
WORKOS_DEVICE_CODE_URL = f"https://{WORKOS_DOMAIN}/oauth2/device_authorization"
WORKOS_TOKEN_URL = f"https://{WORKOS_DOMAIN}/oauth2/token"
def __init__(self):
self.token_manager = TokenManager()
# TODO: WORKOS - This variable is temporary until migration to WorkOS is complete.
self.user_provider = "workos"
self.oauth2_provider = ProviderFactory.from_settings()
def login(self) -> None:
"""Sign up to CrewAI+"""
device_code_url = self.WORKOS_DEVICE_CODE_URL
token_url = self.WORKOS_TOKEN_URL
client_id = WORKOS_CLI_CONNECT_APP_ID
audience = None
console.print("Signing in to CrewAI Enterprise...\n", style="bold blue")
# TODO: WORKOS - Next line and conditional are temporary until migration to WorkOS is complete.
user_provider = self._determine_user_provider()
if user_provider == "auth0":
device_code_url = self.AUTH0_DEVICE_CODE_URL
token_url = self.AUTH0_TOKEN_URL
client_id = AUTH0_CLIENT_ID
audience = AUTH0_AUDIENCE
self.user_provider = "auth0"
settings = Oauth2Settings(
provider="auth0",
client_id=AUTH0_CLIENT_ID,
domain=AUTH0_DOMAIN,
audience=AUTH0_AUDIENCE
)
self.oauth2_provider = ProviderFactory.from_settings(settings)
# End of temporary code.
device_code_data = self._get_device_code(client_id, device_code_url, audience)
device_code_data = self._get_device_code()
self._display_auth_instructions(device_code_data)
return self._poll_for_token(device_code_data, client_id, token_url)
return self._poll_for_token(device_code_data)
def _get_device_code(
self, client_id: str, device_code_url: str, audience: str | None = None
self
) -> Dict[str, Any]:
"""Get the device code to authenticate the user."""
device_code_payload = {
"client_id": client_id,
"client_id": self.oauth2_provider.get_client_id(),
"scope": "openid",
"audience": audience,
"audience": self.oauth2_provider.get_audience(),
}
response = requests.post(
url=device_code_url, data=device_code_payload, timeout=20
url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20
)
response.raise_for_status()
return response.json()
@@ -82,21 +98,21 @@ class AuthenticationCommand:
webbrowser.open(device_code_data["verification_uri_complete"])
def _poll_for_token(
self, device_code_data: Dict[str, Any], client_id: str, token_poll_url: str
self, device_code_data: Dict[str, Any]
) -> None:
"""Polls the server for the token until it is received, or max attempts are reached."""
token_payload = {
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code_data["device_code"],
"client_id": client_id,
"client_id": self.oauth2_provider.get_client_id(),
}
console.print("\nWaiting for authentication... ", style="bold blue", end="")
attempts = 0
while True and attempts < 10:
response = requests.post(token_poll_url, data=token_payload, timeout=30)
response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30)
token_data = response.json()
if response.status_code == 200:
@@ -128,19 +144,14 @@ class AuthenticationCommand:
"""Validates the JWT token and saves the token to the token manager."""
jwt_token = token_data["access_token"]
issuer = self.oauth2_provider.get_issuer()
jwt_token_data = {
"jwt_token": jwt_token,
"jwks_url": f"https://{WORKOS_DOMAIN}/oauth2/jwks",
"issuer": f"https://{WORKOS_DOMAIN}",
"audience": WORKOS_ENVIRONMENT_ID,
"jwks_url": self.oauth2_provider.get_jwks_url(),
"issuer": issuer,
"audience": self.oauth2_provider.get_audience(),
}
# TODO: WORKOS - The following conditional is temporary until migration to WorkOS is complete.
if self.user_provider == "auth0":
jwt_token_data["jwks_url"] = f"https://{AUTH0_DOMAIN}/.well-known/jwks.json"
jwt_token_data["issuer"] = f"https://{AUTH0_DOMAIN}/"
jwt_token_data["audience"] = AUTH0_AUDIENCE
decoded_token = validate_jwt_token(**jwt_token_data)
expires_at = decoded_token.get("exp", 0)

View File

@@ -0,0 +1,26 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
class Auth0Provider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth/device/code"
def get_token_url(self) -> str:
return f"https://{self._get_domain()}/oauth/token"
def get_jwks_url(self) -> str:
return f"https://{self._get_domain()}/.well-known/jwks.json"
def get_issuer(self) -> str:
return f"https://{self._get_domain()}/"
def get_audience(self) -> str:
assert self.settings.audience is not None, "Audience is required"
return self.settings.audience
def get_client_id(self) -> str:
assert self.settings.client_id is not None, "Client ID is required"
return self.settings.client_id
def _get_domain(self) -> str:
assert self.settings.domain is not None, "Domain is required"
return self.settings.domain

View File

@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod
from crewai.cli.authentication.main import Oauth2Settings
class BaseProvider(ABC):
def __init__(self, settings: Oauth2Settings):
self.settings = settings
@abstractmethod
def get_authorize_url(self) -> str:
...
@abstractmethod
def get_token_url(self) -> str:
...
@abstractmethod
def get_jwks_url(self) -> str:
...
@abstractmethod
def get_issuer(self) -> str:
...
@abstractmethod
def get_audience(self) -> str:
...
@abstractmethod
def get_client_id(self) -> str:
...

View File

@@ -0,0 +1,22 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
class OktaProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
def get_token_url(self) -> str:
return f"https://{self.settings.domain}/oauth2/default/v1/token"
def get_jwks_url(self) -> str:
return f"https://{self.settings.domain}/oauth2/default/v1/keys"
def get_issuer(self) -> str:
return f"https://{self.settings.domain}/oauth2/default"
def get_audience(self) -> str:
assert self.settings.audience is not None
return self.settings.audience
def get_client_id(self) -> str:
assert self.settings.client_id is not None
return self.settings.client_id

View File

@@ -0,0 +1,25 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
class WorkosProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/device_authorization"
def get_token_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/token"
def get_jwks_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/jwks"
def get_issuer(self) -> str:
return f"https://{self._get_domain()}"
def get_audience(self) -> str:
return self.settings.audience or ""
def get_client_id(self) -> str:
assert self.settings.client_id is not None, "Client ID is required"
return self.settings.client_id
def _get_domain(self) -> str:
assert self.settings.domain is not None, "Domain is required"
return self.settings.domain

View File

@@ -4,7 +4,13 @@ from typing import Optional
from pydantic import BaseModel, Field
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
from crewai.cli.constants import (
DEFAULT_CREWAI_ENTERPRISE_URL,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
)
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
@@ -19,11 +25,19 @@ USER_SETTINGS_KEYS = [
# Settings that are related to the CLI
CLI_SETTINGS_KEYS = [
"enterprise_base_url",
"oauth2_provider",
"oauth2_audience",
"oauth2_client_id",
"oauth2_domain",
]
# Default values for CLI settings
DEFAULT_CLI_SETTINGS = {
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
"oauth2_provider": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER,
"oauth2_audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
"oauth2_client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
"oauth2_domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
}
# Readonly settings - cannot be set by the user
@@ -39,10 +53,9 @@ HIDDEN_SETTINGS_KEYS = [
"tool_repository_password",
]
class Settings(BaseModel):
enterprise_base_url: Optional[str] = Field(
default=DEFAULT_CREWAI_ENTERPRISE_URL,
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
description="Base URL of the CrewAI Enterprise instance",
)
tool_repository_username: Optional[str] = Field(
@@ -59,6 +72,26 @@ class Settings(BaseModel):
)
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
oauth2_provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
default=DEFAULT_CLI_SETTINGS["oauth2_provider"]
)
oauth2_audience: Optional[str] = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=DEFAULT_CLI_SETTINGS["oauth2_audience"]
)
oauth2_client_id: str = Field(
default=DEFAULT_CLI_SETTINGS["oauth2_client_id"],
description="OAuth2 client ID issued by the provider, used during authentication requests.",
)
oauth2_domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
default=DEFAULT_CLI_SETTINGS["oauth2_domain"]
)
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
"""Load Settings from config path"""
config_path.parent.mkdir(parents=True, exist_ok=True)
@@ -105,4 +138,4 @@ class Settings(BaseModel):
def _reset_cli_settings(self) -> None:
"""Reset all CLI settings to default values"""
for key in CLI_SETTINGS_KEYS:
setattr(self, key, DEFAULT_CLI_SETTINGS[key])
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))

View File

@@ -1,4 +1,8 @@
DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com"
ENV_VARS = {
"openai": [

View File

@@ -0,0 +1,91 @@
import pytest
from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.auth0 import Auth0Provider
class TestAuth0Provider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="auth0",
domain="test-domain.auth0.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = Auth0Provider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = Auth0Provider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "auth0"
assert provider.settings.domain == "test-domain.auth0.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://test-domain.auth0.com/oauth/device/code"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="my-company.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://my-company.auth0.com/oauth/device/code"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://test-domain.auth0.com/oauth/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="another-domain.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://another-domain.auth0.com/oauth/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://test-domain.auth0.com/.well-known/jwks.json"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="dev.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_url = "https://dev.auth0.com/.well-known/jwks.json"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://test-domain.auth0.com/"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="auth0",
domain="prod.auth0.com",
client_id="test-client",
audience="test-audience"
)
provider = Auth0Provider(settings)
expected_issuer = "https://prod.auth0.com/"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -0,0 +1,102 @@
import pytest
from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.okta import OktaProvider
class TestOktaProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = OktaProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = OktaProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "okta"
assert provider.settings.domain == "test-domain.okta.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/device/authorize"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="my-company.okta.com",
client_id="test-client",
audience="test-audience"
)
provider = OktaProvider(settings)
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="another-domain.okta.com",
client_id="test-client",
audience="test-audience"
)
provider = OktaProvider(settings)
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="dev.okta.com",
client_id="test-client",
audience="test-audience"
)
provider = OktaProvider(settings)
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://test-domain.okta.com/oauth2/default"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="okta",
domain="prod.okta.com",
client_id="test-client",
audience="test-audience"
)
provider = OktaProvider(settings)
expected_issuer = "https://prod.okta.com/oauth2/default"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_assertion_error_when_none(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None
)
provider = OktaProvider(settings)
with pytest.raises(AssertionError):
provider.get_audience()
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -0,0 +1,100 @@
import pytest
from crewai.cli.authentication.main import Oauth2Settings
from crewai.cli.authentication.providers.workos import WorkosProvider
class TestWorkosProvider:
@pytest.fixture(autouse=True)
def setup_method(self):
self.valid_settings = Oauth2Settings(
provider="workos",
domain="login.company.com",
client_id="test-client-id",
audience="test-audience"
)
self.provider = WorkosProvider(self.valid_settings)
def test_initialization_with_valid_settings(self):
provider = WorkosProvider(self.valid_settings)
assert provider.settings == self.valid_settings
assert provider.settings.provider == "workos"
assert provider.settings.domain == "login.company.com"
assert provider.settings.client_id == "test-client-id"
assert provider.settings.audience == "test-audience"
def test_get_authorize_url(self):
expected_url = "https://login.company.com/oauth2/device_authorization"
assert self.provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="login.example.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://login.example.com/oauth2/device_authorization"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://login.company.com/oauth2/token"
assert self.provider.get_token_url() == expected_url
def test_get_token_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="api.workos.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://api.workos.com/oauth2/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://login.company.com/oauth2/jwks"
assert self.provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="auth.enterprise.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_url = "https://auth.enterprise.com/oauth2/jwks"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://login.company.com"
assert self.provider.get_issuer() == expected_issuer
def test_get_issuer_with_different_domain(self):
settings = Oauth2Settings(
provider="workos",
domain="sso.company.com",
client_id="test-client",
audience="test-audience"
)
provider = WorkosProvider(settings)
expected_issuer = "https://sso.company.com"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
def test_get_audience_fallback_to_default(self):
settings = Oauth2Settings(
provider="workos",
domain="login.company.com",
client_id="test-client-id",
audience=None
)
provider = WorkosProvider(settings)
assert provider.get_audience() == ""
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"

View File

@@ -6,10 +6,12 @@ from crewai.cli.authentication.main import AuthenticationCommand
from crewai.cli.authentication.constants import (
AUTH0_AUDIENCE,
AUTH0_CLIENT_ID,
AUTH0_DOMAIN,
WORKOS_DOMAIN,
WORKOS_CLI_CONNECT_APP_ID,
WORKOS_ENVIRONMENT_ID,
AUTH0_DOMAIN
)
from crewai.cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
)
@@ -27,14 +29,17 @@ class TestAuthenticationCommand:
"token_url": f"https://{AUTH0_DOMAIN}/oauth/token",
"client_id": AUTH0_CLIENT_ID,
"audience": AUTH0_AUDIENCE,
"domain": AUTH0_DOMAIN,
},
),
(
"workos",
{
"device_code_url": f"https://{WORKOS_DOMAIN}/oauth2/device_authorization",
"token_url": f"https://{WORKOS_DOMAIN}/oauth2/token",
"client_id": WORKOS_CLI_CONNECT_APP_ID,
"device_code_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/device_authorization",
"token_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/token",
"client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
"domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
},
),
],
@@ -70,19 +75,16 @@ class TestAuthenticationCommand:
"Signing in to CrewAI Enterprise...\n", style="bold blue"
)
mock_determine_provider.assert_called_once()
mock_get_device.assert_called_once_with(
expected_urls["client_id"],
expected_urls["device_code_url"],
expected_urls.get("audience", None),
)
mock_get_device.assert_called_once()
mock_display.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"}
)
mock_poll.assert_called_once_with(
{"device_code": "test_code", "user_code": "123456"},
expected_urls["client_id"],
expected_urls["token_url"],
)
assert self.auth_command.oauth2_provider.get_client_id() == expected_urls["client_id"]
assert self.auth_command.oauth2_provider.get_audience() == expected_urls["audience"]
assert self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
@patch("crewai.cli.authentication.main.webbrowser")
@patch("crewai.cli.authentication.main.console.print")
@@ -115,9 +117,9 @@ class TestAuthenticationCommand:
(
"workos",
{
"jwks_url": f"https://{WORKOS_DOMAIN}/oauth2/jwks",
"issuer": f"https://{WORKOS_DOMAIN}",
"audience": WORKOS_ENVIRONMENT_ID,
"jwks_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/jwks",
"issuer": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}",
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
},
),
],
@@ -133,7 +135,15 @@ class TestAuthenticationCommand:
jwt_config,
has_expiration,
):
self.auth_command.user_provider = user_provider
from crewai.cli.authentication.providers.auth0 import Auth0Provider
from crewai.cli.authentication.providers.workos import WorkosProvider
from crewai.cli.authentication.main import Oauth2Settings
if user_provider == "auth0":
self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=AUTH0_DOMAIN, audience=jwt_config["audience"]))
elif user_provider == "workos":
self.auth_command.oauth2_provider = WorkosProvider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, audience=jwt_config["audience"]))
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
if has_expiration:
@@ -311,11 +321,12 @@ class TestAuthenticationCommand:
}
mock_post.return_value = mock_response
result = self.auth_command._get_device_code(
client_id="test_client",
device_code_url="https://example.com/device",
audience="test_audience",
)
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command.oauth2_provider.get_authorize_url.return_value = "https://example.com/device"
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
result = self.auth_command._get_device_code()
mock_post.assert_called_once_with(
url="https://example.com/device",
@@ -354,8 +365,12 @@ class TestAuthenticationCommand:
self.auth_command, "_login_to_tool_repository"
) as mock_tool_login,
):
self.auth_command.oauth2_provider = MagicMock()
self.auth_command.oauth2_provider.get_token_url.return_value = "https://example.com/token"
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
self.auth_command._poll_for_token(
device_code_data, "test_client", "https://example.com/token"
device_code_data
)
mock_post.assert_called_once_with(
@@ -392,7 +407,7 @@ class TestAuthenticationCommand:
}
self.auth_command._poll_for_token(
device_code_data, "test_client", "https://example.com/token"
device_code_data
)
mock_console_print.assert_any_call(
@@ -415,5 +430,14 @@ class TestAuthenticationCommand:
with pytest.raises(requests.HTTPError):
self.auth_command._poll_for_token(
device_code_data, "test_client", "https://example.com/token"
device_code_data
)
# @patch(
# "crewai.cli.authentication.main.AuthenticationCommand._determine_user_provider"
# )
# def test_login_with_auth0(self, mock_determine_provider):
# from crewai.cli.authentication.providers.auth0 import Auth0Provider
# from crewai.cli.authentication.main import Oauth2Settings
# self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider="auth0", client_id=AUTH0_CLIENT_ID, domain=AUTH0_DOMAIN, audience=AUTH0_AUDIENCE))
# self.auth_command.login()

View File

@@ -79,7 +79,7 @@ class TestSettings(unittest.TestCase):
for key in user_settings.keys():
self.assertEqual(getattr(settings, key), None)
for key in cli_settings.keys():
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS[key])
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key))
def test_dump_new_settings(self):
settings = Settings(

View File

@@ -81,11 +81,10 @@ class TestSettingsCommand(unittest.TestCase):
self.settings_command.reset_all_settings()
print(USER_SETTINGS_KEYS)
for key in USER_SETTINGS_KEYS:
self.assertEqual(getattr(self.settings_command.settings, key), None)
for key in CLI_SETTINGS_KEYS:
self.assertEqual(
getattr(self.settings_command.settings, key), DEFAULT_CLI_SETTINGS[key]
getattr(self.settings_command.settings, key), DEFAULT_CLI_SETTINGS.get(key)
)