feat: fetch and store more data about okta authorization server (#3894)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

This commit is contained in:
Heitor Carvalho
2025-11-12 15:28:00 -03:00
committed by GitHub
parent c205d2e8de
commit fbe4aa4bd1
13 changed files with 323 additions and 83 deletions

View File

@@ -1,5 +1,5 @@
import time
from typing import Any
from typing import TYPE_CHECKING, Any, TypeVar, cast
import webbrowser
from pydantic import BaseModel, Field
@@ -13,6 +13,8 @@ from crewai.cli.shared.token_manager import TokenManager
console = Console()
TOauth2Settings = TypeVar("TOauth2Settings", bound="Oauth2Settings")
class Oauth2Settings(BaseModel):
provider: str = Field(
@@ -28,9 +30,15 @@ class Oauth2Settings(BaseModel):
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=None,
)
extra: dict[str, Any] = Field(
description="Extra configuration for the OAuth2 provider.",
default={},
)
@classmethod
def from_settings(cls):
def from_settings(cls: type[TOauth2Settings]) -> TOauth2Settings:
"""Create an Oauth2Settings instance from the CLI settings."""
settings = Settings()
return cls(
@@ -38,12 +46,20 @@ class Oauth2Settings(BaseModel):
domain=settings.oauth2_domain,
client_id=settings.oauth2_client_id,
audience=settings.oauth2_audience,
extra=settings.oauth2_extra,
)
if TYPE_CHECKING:
from crewai.cli.authentication.providers.base_provider import BaseProvider
class ProviderFactory:
@classmethod
def from_settings(cls, settings: Oauth2Settings | None = None):
def from_settings(
cls: type["ProviderFactory"], # noqa: UP037
settings: Oauth2Settings | None = None,
) -> "BaseProvider": # noqa: UP037
settings = settings or Oauth2Settings.from_settings()
import importlib
@@ -53,11 +69,11 @@ class ProviderFactory:
)
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
return provider(settings)
return cast("BaseProvider", provider(settings))
class AuthenticationCommand:
def __init__(self):
def __init__(self) -> None:
self.token_manager = TokenManager()
self.oauth2_provider = ProviderFactory.from_settings()
@@ -84,7 +100,7 @@ class AuthenticationCommand:
timeout=20,
)
response.raise_for_status()
return response.json()
return cast(dict[str, Any], response.json())
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
"""Display the authentication instructions to the user."""

View File

@@ -24,3 +24,7 @@ class BaseProvider(ABC):
@abstractmethod
def get_client_id(self) -> str: ...
def get_required_fields(self) -> list[str]:
"""Returns which provider-specific fields inside the "extra" dict will be required"""
return []

View File

@@ -3,16 +3,16 @@ from crewai.cli.authentication.providers.base_provider import BaseProvider
class OktaProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
return f"{self._oauth2_base_url()}/v1/device/authorize"
def get_token_url(self) -> str:
return f"https://{self.settings.domain}/oauth2/default/v1/token"
return f"{self._oauth2_base_url()}/v1/token"
def get_jwks_url(self) -> str:
return f"https://{self.settings.domain}/oauth2/default/v1/keys"
return f"{self._oauth2_base_url()}/v1/keys"
def get_issuer(self) -> str:
return f"https://{self.settings.domain}/oauth2/default"
return self._oauth2_base_url().removesuffix("/oauth2")
def get_audience(self) -> str:
if self.settings.audience is None:
@@ -27,3 +27,16 @@ class OktaProvider(BaseProvider):
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def get_required_fields(self) -> list[str]:
return ["authorization_server_name", "using_org_auth_server"]
def _oauth2_base_url(self) -> str:
using_org_auth_server = self.settings.extra.get("using_org_auth_server", False)
if using_org_auth_server:
base_url = f"https://{self.settings.domain}/oauth2"
else:
base_url = f"https://{self.settings.domain}/oauth2/{self.settings.extra.get('authorization_server_name', 'default')}"
return f"{base_url}"

View File

@@ -11,18 +11,18 @@ console = Console()
class BaseCommand:
def __init__(self):
def __init__(self) -> None:
self._telemetry = Telemetry()
self._telemetry.set_tracer()
class PlusAPIMixin:
def __init__(self, telemetry):
def __init__(self, telemetry: Telemetry) -> None:
try:
telemetry.set_tracer()
self.plus_api_client = PlusAPI(api_key=get_auth_token())
except Exception:
self._deploy_signup_error_span = telemetry.deploy_signup_error_span()
telemetry.deploy_signup_error_span()
console.print(
"Please sign up/login to CrewAI+ before using the CLI.",
style="bold red",

View File

@@ -2,6 +2,7 @@ import json
from logging import getLogger
from pathlib import Path
import tempfile
from typing import Any
from pydantic import BaseModel, Field
@@ -136,7 +137,12 @@ class Settings(BaseModel):
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
)
def __init__(self, config_path: Path | None = None, **data):
oauth2_extra: dict[str, Any] = Field(
description="Extra configuration for the OAuth2 provider.",
default={},
)
def __init__(self, config_path: Path | None = None, **data: dict[str, Any]) -> None:
"""Load Settings from config path with fallback support"""
if config_path is None:
config_path = get_writable_config_path()

View File

@@ -1,9 +1,10 @@
from typing import Any
from typing import Any, cast
import requests
from requests.exceptions import JSONDecodeError, RequestException
from rich.console import Console
from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory
from crewai.cli.command import BaseCommand
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.version import get_crewai_version
@@ -13,7 +14,7 @@ console = Console()
class EnterpriseConfigureCommand(BaseCommand):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.settings_command = SettingsCommand()
@@ -54,25 +55,12 @@ class EnterpriseConfigureCommand(BaseCommand):
except JSONDecodeError as e:
raise ValueError(f"Invalid JSON response from {oauth_endpoint}") from e
required_fields = [
"audience",
"domain",
"device_authorization_client_id",
"provider",
]
missing_fields = [
field for field in required_fields if field not in oauth_config
]
if missing_fields:
raise ValueError(
f"Missing required fields in OAuth2 configuration: {', '.join(missing_fields)}"
)
self._validate_oauth_config(oauth_config)
console.print(
"✅ Successfully retrieved OAuth2 configuration", style="green"
)
return oauth_config
return cast(dict[str, Any], oauth_config)
except RequestException as e:
raise ValueError(f"Failed to connect to enterprise URL: {e!s}") from e
@@ -89,6 +77,7 @@ class EnterpriseConfigureCommand(BaseCommand):
"oauth2_audience": oauth_config["audience"],
"oauth2_client_id": oauth_config["device_authorization_client_id"],
"oauth2_domain": oauth_config["domain"],
"oauth2_extra": oauth_config["extra"],
}
console.print("🔄 Updating local OAuth2 configuration...")
@@ -99,3 +88,38 @@ class EnterpriseConfigureCommand(BaseCommand):
except Exception as e:
raise ValueError(f"Failed to update OAuth2 settings: {e!s}") from e
def _validate_oauth_config(self, oauth_config: dict[str, Any]) -> None:
required_fields = [
"audience",
"domain",
"device_authorization_client_id",
"provider",
"extra",
]
missing_basic_fields = [
field for field in required_fields if field not in oauth_config
]
missing_provider_specific_fields = [
field
for field in self._get_provider_specific_fields(oauth_config["provider"])
if field not in oauth_config.get("extra", {})
]
if missing_basic_fields:
raise ValueError(
f"Missing required fields in OAuth2 configuration: [{', '.join(missing_basic_fields)}]"
)
if missing_provider_specific_fields:
raise ValueError(
f"Missing authentication provider required fields in OAuth2 configuration: [{', '.join(missing_provider_specific_fields)}] (Configured provider: '{oauth_config['provider']}')"
)
def _get_provider_specific_fields(self, provider_name: str) -> list[str]:
provider = ProviderFactory.from_settings(
Oauth2Settings(provider=provider_name, client_id="dummy", domain="dummy")
)
return provider.get_required_fields()

View File

@@ -3,7 +3,7 @@ import subprocess
class Repository:
def __init__(self, path="."):
def __init__(self, path: str = ".") -> None:
self.path = path
if not self.is_git_installed():

View File

@@ -1,3 +1,4 @@
from typing import Any
from urllib.parse import urljoin
import requests
@@ -36,19 +37,21 @@ class PlusAPI:
str(settings.enterprise_base_url) or DEFAULT_CREWAI_ENTERPRISE_URL
)
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
def _make_request(
self, method: str, endpoint: str, **kwargs: Any
) -> requests.Response:
url = urljoin(self.base_url, endpoint)
session = requests.Session()
session.trust_env = False
return session.request(method, url, headers=self.headers, **kwargs)
def login_to_tool_repository(self):
def login_to_tool_repository(self) -> requests.Response:
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
def get_tool(self, handle: str):
def get_tool(self, handle: str) -> requests.Response:
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
def get_agent(self, handle: str):
def get_agent(self, handle: str) -> requests.Response:
return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}")
def publish_tool(
@@ -58,8 +61,8 @@ class PlusAPI:
version: str,
description: str | None,
encoded_file: str,
available_exports: list[str] | None = None,
):
available_exports: list[dict[str, Any]] | None = None,
) -> requests.Response:
params = {
"handle": handle,
"public": is_public,
@@ -111,13 +114,13 @@ class PlusAPI:
def list_crews(self) -> requests.Response:
return self._make_request("GET", self.CREWS_RESOURCE)
def create_crew(self, payload) -> requests.Response:
def create_crew(self, payload: dict[str, Any]) -> requests.Response:
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
def get_organizations(self) -> requests.Response:
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
def initialize_trace_batch(self, payload) -> requests.Response:
def initialize_trace_batch(self, payload: dict[str, Any]) -> requests.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches",
@@ -125,14 +128,18 @@ class PlusAPI:
timeout=30,
)
def initialize_ephemeral_trace_batch(self, payload) -> requests.Response:
def initialize_ephemeral_trace_batch(
self, payload: dict[str, Any]
) -> requests.Response:
return self._make_request(
"POST",
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
json=payload,
)
def send_trace_events(self, trace_batch_id: str, payload) -> requests.Response:
def send_trace_events(
self, trace_batch_id: str, payload: dict[str, Any]
) -> requests.Response:
return self._make_request(
"POST",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
@@ -141,7 +148,7 @@ class PlusAPI:
)
def send_ephemeral_trace_events(
self, trace_batch_id: str, payload
self, trace_batch_id: str, payload: dict[str, Any]
) -> requests.Response:
return self._make_request(
"POST",
@@ -150,7 +157,9 @@ class PlusAPI:
timeout=30,
)
def finalize_trace_batch(self, trace_batch_id: str, payload) -> requests.Response:
def finalize_trace_batch(
self, trace_batch_id: str, payload: dict[str, Any]
) -> requests.Response:
return self._make_request(
"PATCH",
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
@@ -159,7 +168,7 @@ class PlusAPI:
)
def finalize_ephemeral_trace_batch(
self, trace_batch_id: str, payload
self, trace_batch_id: str, payload: dict[str, Any]
) -> requests.Response:
return self._make_request(
"PATCH",

View File

@@ -34,7 +34,7 @@ class SettingsCommand(BaseCommand):
current_value = getattr(self.settings, field_name)
description = field_info.description or "No description available"
display_value = (
str(current_value) if current_value is not None else "Not set"
str(current_value) if current_value not in [None, {}] else "Not set"
)
table.add_row(field_name, display_value, description)

View File

@@ -30,11 +30,11 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
A class to handle tool repository related operations for CrewAI projects.
"""
def __init__(self):
def __init__(self) -> None:
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
def create(self, handle: str):
def create(self, handle: str) -> None:
self._ensure_not_in_project()
folder_name = handle.replace(" ", "_").replace("-", "_").lower()
@@ -64,7 +64,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
finally:
os.chdir(old_directory)
def publish(self, is_public: bool, force: bool = False):
def publish(self, is_public: bool, force: bool = False) -> None:
if not git.Repository().is_synced() and not force:
console.print(
"[bold red]Failed to publish tool.[/bold red]\n"
@@ -137,7 +137,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
style="bold green",
)
def install(self, handle: str):
def install(self, handle: str) -> None:
self._print_current_organization()
get_response = self.plus_api_client.get_tool(handle)
@@ -180,7 +180,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
settings.org_name = login_response_json["current_organization"]["name"]
settings.dump()
def _add_package(self, tool_details: dict[str, Any]):
def _add_package(self, tool_details: dict[str, Any]) -> None:
is_from_pypi = tool_details.get("source", None) == "pypi"
tool_handle = tool_details["handle"]
repository_handle = tool_details["repository"]["handle"]
@@ -209,7 +209,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
click.echo(add_package_result.stderr, err=True)
raise SystemExit
def _ensure_not_in_project(self):
def _ensure_not_in_project(self) -> None:
if os.path.isfile("./pyproject.toml"):
console.print(
"[bold red]Oops! It looks like you're inside a project.[/bold red]"

View File

@@ -5,7 +5,7 @@ import os
from pathlib import Path
import shutil
import sys
from typing import Any, get_type_hints
from typing import Any, cast, get_type_hints
import click
from rich.console import Console
@@ -23,7 +23,9 @@ if sys.version_info >= (3, 11):
console = Console()
def copy_template(src, dst, name, class_name, folder_name):
def copy_template(
src: Path, dst: Path, name: str, class_name: str, folder_name: str
) -> None:
"""Copy a file from src to dst."""
with open(src, "r") as file:
content = file.read()
@@ -40,13 +42,13 @@ def copy_template(src, dst, name, class_name, folder_name):
click.secho(f" - Created {dst}", fg="green")
def read_toml(file_path: str = "pyproject.toml"):
def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]:
"""Read the content of a TOML file and return it as a dictionary."""
with open(file_path, "rb") as f:
return tomli.load(f)
def parse_toml(content):
def parse_toml(content: str) -> dict[str, Any]:
if sys.version_info >= (3, 11):
return tomllib.loads(content)
return tomli.loads(content)
@@ -103,7 +105,7 @@ def _get_project_attribute(
)
except Exception as e:
# Handle TOML decode errors for Python 3.11+
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError): # type: ignore
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError):
console.print(
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
)
@@ -126,7 +128,7 @@ def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
return reduce(dict.__getitem__, keys, data)
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict[str, Any]:
"""Fetch the environment variables from a .env file and return them as a dictionary."""
try:
# Read the .env file
@@ -150,7 +152,7 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
return {}
def tree_copy(source, destination):
def tree_copy(source: Path, destination: Path) -> None:
"""Copies the entire directory structure from the source to the destination."""
for item in os.listdir(source):
source_item = os.path.join(source, item)
@@ -161,7 +163,7 @@ def tree_copy(source, destination):
shutil.copy2(source_item, destination_item)
def tree_find_and_replace(directory, find, replace):
def tree_find_and_replace(directory: Path, find: str, replace: str) -> None:
"""Recursively searches through a directory, replacing a target string in
both file contents and filenames with a specified replacement string.
"""
@@ -187,7 +189,7 @@ def tree_find_and_replace(directory, find, replace):
os.rename(old_dirpath, new_dirpath)
def load_env_vars(folder_path):
def load_env_vars(folder_path: Path) -> dict[str, Any]:
"""
Loads environment variables from a .env file in the specified folder path.
@@ -208,7 +210,9 @@ def load_env_vars(folder_path):
return env_vars
def update_env_vars(env_vars, provider, model):
def update_env_vars(
env_vars: dict[str, Any], provider: str, model: str
) -> dict[str, Any] | None:
"""
Updates environment variables with the API key for the selected provider and model.
@@ -220,15 +224,20 @@ def update_env_vars(env_vars, provider, model):
Returns:
- None
"""
api_key_var = ENV_VARS.get(
provider,
[
click.prompt(
f"Enter the environment variable name for your {provider.capitalize()} API key",
type=str,
)
],
)[0]
provider_config = cast(
list[str],
ENV_VARS.get(
provider,
[
click.prompt(
f"Enter the environment variable name for your {provider.capitalize()} API key",
type=str,
)
],
),
)
api_key_var = provider_config[0]
if api_key_var not in env_vars:
try:
@@ -246,7 +255,7 @@ def update_env_vars(env_vars, provider, model):
return env_vars
def write_env_file(folder_path, env_vars):
def write_env_file(folder_path: Path, env_vars: dict[str, Any]) -> None:
"""
Writes environment variables to a .env file in the specified folder.
@@ -342,18 +351,18 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
return crew_instances
def get_crew_instance(module_attr) -> Crew | None:
def get_crew_instance(module_attr: Any) -> Crew | None:
if (
callable(module_attr)
and hasattr(module_attr, "is_crew_class")
and module_attr.is_crew_class
):
return module_attr().crew()
return cast(Crew, module_attr().crew())
try:
if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints(
module_attr
).get("return") is Crew:
return module_attr()
return cast(Crew, module_attr())
except Exception:
return None
@@ -362,7 +371,7 @@ def get_crew_instance(module_attr) -> Crew | None:
return None
def fetch_crews(module_attr) -> list[Crew]:
def fetch_crews(module_attr: Any) -> list[Crew]:
crew_instances: list[Crew] = []
if crew_instance := get_crew_instance(module_attr):
@@ -377,7 +386,7 @@ def fetch_crews(module_attr) -> list[Crew]:
return crew_instances
def is_valid_tool(obj):
def is_valid_tool(obj: Any) -> bool:
from crewai.tools.base_tool import Tool
if isclass(obj):
@@ -389,7 +398,7 @@ def is_valid_tool(obj):
return isinstance(obj, Tool)
def extract_available_exports(dir_path: str = "src"):
def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]:
"""
Extract available tool classes from the project's __init__.py files.
Only includes classes that inherit from BaseTool or functions decorated with @tool.
@@ -419,7 +428,9 @@ def extract_available_exports(dir_path: str = "src"):
raise SystemExit(1) from e
def build_env_with_tool_repository_credentials(repository_handle: str):
def build_env_with_tool_repository_credentials(
repository_handle: str,
) -> dict[str, Any]:
repository_handle = repository_handle.upper().replace("-", "_")
settings = Settings()
@@ -472,7 +483,7 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
sys.modules.pop("temp_module", None)
def _print_no_tools_warning():
def _print_no_tools_warning() -> None:
"""
Display warning and usage instructions if no tools were found.
"""

View File

@@ -37,6 +37,36 @@ class TestOktaProvider:
provider = OktaProvider(settings)
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_authorize_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_authorize_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/device/authorize"
assert provider.get_authorize_url() == expected_url
def test_get_token_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/token"
@@ -53,6 +83,36 @@ class TestOktaProvider:
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
assert provider.get_token_url() == expected_url
def test_get_token_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/token"
assert provider.get_token_url() == expected_url
def test_get_token_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/token"
assert provider.get_token_url() == expected_url
def test_get_jwks_url(self):
expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys"
assert self.provider.get_jwks_url() == expected_url
@@ -68,6 +128,36 @@ class TestOktaProvider:
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_jwks_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_jwks_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_url = "https://test-domain.okta.com/oauth2/v1/keys"
assert provider.get_jwks_url() == expected_url
def test_get_issuer(self):
expected_issuer = "https://test-domain.okta.com/oauth2/default"
assert self.provider.get_issuer() == expected_issuer
@@ -83,6 +173,36 @@ class TestOktaProvider:
expected_issuer = "https://prod.okta.com/oauth2/default"
assert provider.get_issuer() == expected_issuer
def test_get_issuer_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
expected_issuer = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
assert provider.get_issuer() == expected_issuer
def test_get_issuer_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
expected_issuer = "https://test-domain.okta.com"
assert provider.get_issuer() == expected_issuer
def test_get_audience(self):
assert self.provider.get_audience() == "test-audience"
@@ -100,3 +220,38 @@ class TestOktaProvider:
def test_get_client_id(self):
assert self.provider.get_client_id() == "test-client-id"
def test_get_required_fields(self):
assert set(self.provider.get_required_fields()) == set(["authorization_server_name", "using_org_auth_server"])
def test_oauth2_base_url(self):
assert self.provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/default"
def test_oauth2_base_url_with_custom_authorization_server_name(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": False,
"authorization_server_name": "my_auth_server_xxxAAA777"
}
)
provider = OktaProvider(settings)
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
def test_oauth2_base_url_when_using_org_auth_server(self):
settings = Oauth2Settings(
provider="okta",
domain="test-domain.okta.com",
client_id="test-client-id",
audience=None,
extra={
"using_org_auth_server": True,
"authorization_server_name": None
}
)
provider = OktaProvider(settings)
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2"

View File

@@ -37,7 +37,8 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
'audience': 'test_audience',
'domain': 'test.domain.com',
'device_authorization_client_id': 'test_client_id',
'provider': 'workos'
'provider': 'workos',
'extra': {}
}
mock_requests_get.return_value = mock_response
@@ -60,11 +61,12 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
('oauth2_provider', 'workos'),
('oauth2_audience', 'test_audience'),
('oauth2_client_id', 'test_client_id'),
('oauth2_domain', 'test.domain.com')
('oauth2_domain', 'test.domain.com'),
('oauth2_extra', {})
]
actual_calls = self.mock_settings_command.set.call_args_list
self.assertEqual(len(actual_calls), 5)
self.assertEqual(len(actual_calls), 6)
for i, (key, value) in enumerate(expected_calls):
call_args = actual_calls[i][0]