diff --git a/lib/crewai/src/crewai/cli/authentication/main.py b/lib/crewai/src/crewai/cli/authentication/main.py index b23fe9114..7bda8fe08 100644 --- a/lib/crewai/src/crewai/cli/authentication/main.py +++ b/lib/crewai/src/crewai/cli/authentication/main.py @@ -1,5 +1,5 @@ import time -from typing import Any +from typing import TYPE_CHECKING, Any, TypeVar, cast import webbrowser from pydantic import BaseModel, Field @@ -13,6 +13,8 @@ from crewai.cli.shared.token_manager import TokenManager console = Console() +TOauth2Settings = TypeVar("TOauth2Settings", bound="Oauth2Settings") + class Oauth2Settings(BaseModel): provider: str = Field( @@ -28,9 +30,15 @@ class Oauth2Settings(BaseModel): description="OAuth2 audience value, typically used to identify the target API or resource.", default=None, ) + extra: dict[str, Any] = Field( + description="Extra configuration for the OAuth2 provider.", + default={}, + ) @classmethod - def from_settings(cls): + def from_settings(cls: type[TOauth2Settings]) -> TOauth2Settings: + """Create an Oauth2Settings instance from the CLI settings.""" + settings = Settings() return cls( @@ -38,12 +46,20 @@ class Oauth2Settings(BaseModel): domain=settings.oauth2_domain, client_id=settings.oauth2_client_id, audience=settings.oauth2_audience, + extra=settings.oauth2_extra, ) +if TYPE_CHECKING: + from crewai.cli.authentication.providers.base_provider import BaseProvider + + class ProviderFactory: @classmethod - def from_settings(cls, settings: Oauth2Settings | None = None): + def from_settings( + cls: type["ProviderFactory"], # noqa: UP037 + settings: Oauth2Settings | None = None, + ) -> "BaseProvider": # noqa: UP037 settings = settings or Oauth2Settings.from_settings() import importlib @@ -53,11 +69,11 @@ class ProviderFactory: ) provider = getattr(module, f"{settings.provider.capitalize()}Provider") - return provider(settings) + return cast("BaseProvider", provider(settings)) class AuthenticationCommand: - def __init__(self): + def __init__(self) -> None: self.token_manager = TokenManager() self.oauth2_provider = ProviderFactory.from_settings() @@ -84,7 +100,7 @@ class AuthenticationCommand: timeout=20, ) response.raise_for_status() - return response.json() + return cast(dict[str, Any], response.json()) def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None: """Display the authentication instructions to the user.""" diff --git a/lib/crewai/src/crewai/cli/authentication/providers/base_provider.py b/lib/crewai/src/crewai/cli/authentication/providers/base_provider.py index 2b7a0140e..0c8057b4d 100644 --- a/lib/crewai/src/crewai/cli/authentication/providers/base_provider.py +++ b/lib/crewai/src/crewai/cli/authentication/providers/base_provider.py @@ -24,3 +24,7 @@ class BaseProvider(ABC): @abstractmethod def get_client_id(self) -> str: ... + + def get_required_fields(self) -> list[str]: + """Returns which provider-specific fields inside the "extra" dict will be required""" + return [] diff --git a/lib/crewai/src/crewai/cli/authentication/providers/okta.py b/lib/crewai/src/crewai/cli/authentication/providers/okta.py index d13087e7d..90f5e2908 100644 --- a/lib/crewai/src/crewai/cli/authentication/providers/okta.py +++ b/lib/crewai/src/crewai/cli/authentication/providers/okta.py @@ -3,16 +3,16 @@ from crewai.cli.authentication.providers.base_provider import BaseProvider class OktaProvider(BaseProvider): def get_authorize_url(self) -> str: - return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize" + return f"{self._oauth2_base_url()}/v1/device/authorize" def get_token_url(self) -> str: - return f"https://{self.settings.domain}/oauth2/default/v1/token" + return f"{self._oauth2_base_url()}/v1/token" def get_jwks_url(self) -> str: - return f"https://{self.settings.domain}/oauth2/default/v1/keys" + return f"{self._oauth2_base_url()}/v1/keys" def get_issuer(self) -> str: - return f"https://{self.settings.domain}/oauth2/default" + return self._oauth2_base_url().removesuffix("/oauth2") def get_audience(self) -> str: if self.settings.audience is None: @@ -27,3 +27,16 @@ class OktaProvider(BaseProvider): "Client ID is required. Please set it in the configuration." ) return self.settings.client_id + + def get_required_fields(self) -> list[str]: + return ["authorization_server_name", "using_org_auth_server"] + + def _oauth2_base_url(self) -> str: + using_org_auth_server = self.settings.extra.get("using_org_auth_server", False) + + if using_org_auth_server: + base_url = f"https://{self.settings.domain}/oauth2" + else: + base_url = f"https://{self.settings.domain}/oauth2/{self.settings.extra.get('authorization_server_name', 'default')}" + + return f"{base_url}" diff --git a/lib/crewai/src/crewai/cli/command.py b/lib/crewai/src/crewai/cli/command.py index e889b7125..3f85318fb 100644 --- a/lib/crewai/src/crewai/cli/command.py +++ b/lib/crewai/src/crewai/cli/command.py @@ -11,18 +11,18 @@ console = Console() class BaseCommand: - def __init__(self): + def __init__(self) -> None: self._telemetry = Telemetry() self._telemetry.set_tracer() class PlusAPIMixin: - def __init__(self, telemetry): + def __init__(self, telemetry: Telemetry) -> None: try: telemetry.set_tracer() self.plus_api_client = PlusAPI(api_key=get_auth_token()) except Exception: - self._deploy_signup_error_span = telemetry.deploy_signup_error_span() + telemetry.deploy_signup_error_span() console.print( "Please sign up/login to CrewAI+ before using the CLI.", style="bold red", diff --git a/lib/crewai/src/crewai/cli/config.py b/lib/crewai/src/crewai/cli/config.py index dea3691ae..7af9904e0 100644 --- a/lib/crewai/src/crewai/cli/config.py +++ b/lib/crewai/src/crewai/cli/config.py @@ -2,6 +2,7 @@ import json from logging import getLogger from pathlib import Path import tempfile +from typing import Any from pydantic import BaseModel, Field @@ -136,7 +137,12 @@ class Settings(BaseModel): default=DEFAULT_CLI_SETTINGS["oauth2_domain"], ) - def __init__(self, config_path: Path | None = None, **data): + oauth2_extra: dict[str, Any] = Field( + description="Extra configuration for the OAuth2 provider.", + default={}, + ) + + def __init__(self, config_path: Path | None = None, **data: dict[str, Any]) -> None: """Load Settings from config path with fallback support""" if config_path is None: config_path = get_writable_config_path() diff --git a/lib/crewai/src/crewai/cli/enterprise/main.py b/lib/crewai/src/crewai/cli/enterprise/main.py index 62002608e..2a73f1ae0 100644 --- a/lib/crewai/src/crewai/cli/enterprise/main.py +++ b/lib/crewai/src/crewai/cli/enterprise/main.py @@ -1,9 +1,10 @@ -from typing import Any +from typing import Any, cast import requests from requests.exceptions import JSONDecodeError, RequestException from rich.console import Console +from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory from crewai.cli.command import BaseCommand from crewai.cli.settings.main import SettingsCommand from crewai.cli.version import get_crewai_version @@ -13,7 +14,7 @@ console = Console() class EnterpriseConfigureCommand(BaseCommand): - def __init__(self): + def __init__(self) -> None: super().__init__() self.settings_command = SettingsCommand() @@ -54,25 +55,12 @@ class EnterpriseConfigureCommand(BaseCommand): except JSONDecodeError as e: raise ValueError(f"Invalid JSON response from {oauth_endpoint}") from e - required_fields = [ - "audience", - "domain", - "device_authorization_client_id", - "provider", - ] - missing_fields = [ - field for field in required_fields if field not in oauth_config - ] - - if missing_fields: - raise ValueError( - f"Missing required fields in OAuth2 configuration: {', '.join(missing_fields)}" - ) + self._validate_oauth_config(oauth_config) console.print( "✅ Successfully retrieved OAuth2 configuration", style="green" ) - return oauth_config + return cast(dict[str, Any], oauth_config) except RequestException as e: raise ValueError(f"Failed to connect to enterprise URL: {e!s}") from e @@ -89,6 +77,7 @@ class EnterpriseConfigureCommand(BaseCommand): "oauth2_audience": oauth_config["audience"], "oauth2_client_id": oauth_config["device_authorization_client_id"], "oauth2_domain": oauth_config["domain"], + "oauth2_extra": oauth_config["extra"], } console.print("🔄 Updating local OAuth2 configuration...") @@ -99,3 +88,38 @@ class EnterpriseConfigureCommand(BaseCommand): except Exception as e: raise ValueError(f"Failed to update OAuth2 settings: {e!s}") from e + + def _validate_oauth_config(self, oauth_config: dict[str, Any]) -> None: + required_fields = [ + "audience", + "domain", + "device_authorization_client_id", + "provider", + "extra", + ] + + missing_basic_fields = [ + field for field in required_fields if field not in oauth_config + ] + missing_provider_specific_fields = [ + field + for field in self._get_provider_specific_fields(oauth_config["provider"]) + if field not in oauth_config.get("extra", {}) + ] + + if missing_basic_fields: + raise ValueError( + f"Missing required fields in OAuth2 configuration: [{', '.join(missing_basic_fields)}]" + ) + + if missing_provider_specific_fields: + raise ValueError( + f"Missing authentication provider required fields in OAuth2 configuration: [{', '.join(missing_provider_specific_fields)}] (Configured provider: '{oauth_config['provider']}')" + ) + + def _get_provider_specific_fields(self, provider_name: str) -> list[str]: + provider = ProviderFactory.from_settings( + Oauth2Settings(provider=provider_name, client_id="dummy", domain="dummy") + ) + + return provider.get_required_fields() diff --git a/lib/crewai/src/crewai/cli/git.py b/lib/crewai/src/crewai/cli/git.py index b493e88c0..fb08c391a 100644 --- a/lib/crewai/src/crewai/cli/git.py +++ b/lib/crewai/src/crewai/cli/git.py @@ -3,7 +3,7 @@ import subprocess class Repository: - def __init__(self, path="."): + def __init__(self, path: str = ".") -> None: self.path = path if not self.is_git_installed(): diff --git a/lib/crewai/src/crewai/cli/plus_api.py b/lib/crewai/src/crewai/cli/plus_api.py index 6121dd718..5d7141179 100644 --- a/lib/crewai/src/crewai/cli/plus_api.py +++ b/lib/crewai/src/crewai/cli/plus_api.py @@ -1,3 +1,4 @@ +from typing import Any from urllib.parse import urljoin import requests @@ -36,19 +37,21 @@ class PlusAPI: str(settings.enterprise_base_url) or DEFAULT_CREWAI_ENTERPRISE_URL ) - def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response: + def _make_request( + self, method: str, endpoint: str, **kwargs: Any + ) -> requests.Response: url = urljoin(self.base_url, endpoint) session = requests.Session() session.trust_env = False return session.request(method, url, headers=self.headers, **kwargs) - def login_to_tool_repository(self): + def login_to_tool_repository(self) -> requests.Response: return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login") - def get_tool(self, handle: str): + def get_tool(self, handle: str) -> requests.Response: return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}") - def get_agent(self, handle: str): + def get_agent(self, handle: str) -> requests.Response: return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}") def publish_tool( @@ -58,8 +61,8 @@ class PlusAPI: version: str, description: str | None, encoded_file: str, - available_exports: list[str] | None = None, - ): + available_exports: list[dict[str, Any]] | None = None, + ) -> requests.Response: params = { "handle": handle, "public": is_public, @@ -111,13 +114,13 @@ class PlusAPI: def list_crews(self) -> requests.Response: return self._make_request("GET", self.CREWS_RESOURCE) - def create_crew(self, payload) -> requests.Response: + def create_crew(self, payload: dict[str, Any]) -> requests.Response: return self._make_request("POST", self.CREWS_RESOURCE, json=payload) def get_organizations(self) -> requests.Response: return self._make_request("GET", self.ORGANIZATIONS_RESOURCE) - def initialize_trace_batch(self, payload) -> requests.Response: + def initialize_trace_batch(self, payload: dict[str, Any]) -> requests.Response: return self._make_request( "POST", f"{self.TRACING_RESOURCE}/batches", @@ -125,14 +128,18 @@ class PlusAPI: timeout=30, ) - def initialize_ephemeral_trace_batch(self, payload) -> requests.Response: + def initialize_ephemeral_trace_batch( + self, payload: dict[str, Any] + ) -> requests.Response: return self._make_request( "POST", f"{self.EPHEMERAL_TRACING_RESOURCE}/batches", json=payload, ) - def send_trace_events(self, trace_batch_id: str, payload) -> requests.Response: + def send_trace_events( + self, trace_batch_id: str, payload: dict[str, Any] + ) -> requests.Response: return self._make_request( "POST", f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events", @@ -141,7 +148,7 @@ class PlusAPI: ) def send_ephemeral_trace_events( - self, trace_batch_id: str, payload + self, trace_batch_id: str, payload: dict[str, Any] ) -> requests.Response: return self._make_request( "POST", @@ -150,7 +157,9 @@ class PlusAPI: timeout=30, ) - def finalize_trace_batch(self, trace_batch_id: str, payload) -> requests.Response: + def finalize_trace_batch( + self, trace_batch_id: str, payload: dict[str, Any] + ) -> requests.Response: return self._make_request( "PATCH", f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize", @@ -159,7 +168,7 @@ class PlusAPI: ) def finalize_ephemeral_trace_batch( - self, trace_batch_id: str, payload + self, trace_batch_id: str, payload: dict[str, Any] ) -> requests.Response: return self._make_request( "PATCH", diff --git a/lib/crewai/src/crewai/cli/settings/main.py b/lib/crewai/src/crewai/cli/settings/main.py index 3fa4f2af0..83a50c2fe 100644 --- a/lib/crewai/src/crewai/cli/settings/main.py +++ b/lib/crewai/src/crewai/cli/settings/main.py @@ -34,7 +34,7 @@ class SettingsCommand(BaseCommand): current_value = getattr(self.settings, field_name) description = field_info.description or "No description available" display_value = ( - str(current_value) if current_value is not None else "Not set" + str(current_value) if current_value not in [None, {}] else "Not set" ) table.add_row(field_name, display_value, description) diff --git a/lib/crewai/src/crewai/cli/tools/main.py b/lib/crewai/src/crewai/cli/tools/main.py index 09bc927d3..2705388c5 100644 --- a/lib/crewai/src/crewai/cli/tools/main.py +++ b/lib/crewai/src/crewai/cli/tools/main.py @@ -30,11 +30,11 @@ class ToolCommand(BaseCommand, PlusAPIMixin): A class to handle tool repository related operations for CrewAI projects. """ - def __init__(self): + def __init__(self) -> None: BaseCommand.__init__(self) PlusAPIMixin.__init__(self, telemetry=self._telemetry) - def create(self, handle: str): + def create(self, handle: str) -> None: self._ensure_not_in_project() folder_name = handle.replace(" ", "_").replace("-", "_").lower() @@ -64,7 +64,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin): finally: os.chdir(old_directory) - def publish(self, is_public: bool, force: bool = False): + def publish(self, is_public: bool, force: bool = False) -> None: if not git.Repository().is_synced() and not force: console.print( "[bold red]Failed to publish tool.[/bold red]\n" @@ -137,7 +137,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin): style="bold green", ) - def install(self, handle: str): + def install(self, handle: str) -> None: self._print_current_organization() get_response = self.plus_api_client.get_tool(handle) @@ -180,7 +180,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin): settings.org_name = login_response_json["current_organization"]["name"] settings.dump() - def _add_package(self, tool_details: dict[str, Any]): + def _add_package(self, tool_details: dict[str, Any]) -> None: is_from_pypi = tool_details.get("source", None) == "pypi" tool_handle = tool_details["handle"] repository_handle = tool_details["repository"]["handle"] @@ -209,7 +209,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin): click.echo(add_package_result.stderr, err=True) raise SystemExit - def _ensure_not_in_project(self): + def _ensure_not_in_project(self) -> None: if os.path.isfile("./pyproject.toml"): console.print( "[bold red]Oops! It looks like you're inside a project.[/bold red]" diff --git a/lib/crewai/src/crewai/cli/utils.py b/lib/crewai/src/crewai/cli/utils.py index 041bc4e9d..b73f9f76b 100644 --- a/lib/crewai/src/crewai/cli/utils.py +++ b/lib/crewai/src/crewai/cli/utils.py @@ -5,7 +5,7 @@ import os from pathlib import Path import shutil import sys -from typing import Any, get_type_hints +from typing import Any, cast, get_type_hints import click from rich.console import Console @@ -23,7 +23,9 @@ if sys.version_info >= (3, 11): console = Console() -def copy_template(src, dst, name, class_name, folder_name): +def copy_template( + src: Path, dst: Path, name: str, class_name: str, folder_name: str +) -> None: """Copy a file from src to dst.""" with open(src, "r") as file: content = file.read() @@ -40,13 +42,13 @@ def copy_template(src, dst, name, class_name, folder_name): click.secho(f" - Created {dst}", fg="green") -def read_toml(file_path: str = "pyproject.toml"): +def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]: """Read the content of a TOML file and return it as a dictionary.""" with open(file_path, "rb") as f: return tomli.load(f) -def parse_toml(content): +def parse_toml(content: str) -> dict[str, Any]: if sys.version_info >= (3, 11): return tomllib.loads(content) return tomli.loads(content) @@ -103,7 +105,7 @@ def _get_project_attribute( ) except Exception as e: # Handle TOML decode errors for Python 3.11+ - if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError): # type: ignore + if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError): console.print( f"Error: {pyproject_path} is not a valid TOML file.", style="bold red" ) @@ -126,7 +128,7 @@ def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any: return reduce(dict.__getitem__, keys, data) -def fetch_and_json_env_file(env_file_path: str = ".env") -> dict: +def fetch_and_json_env_file(env_file_path: str = ".env") -> dict[str, Any]: """Fetch the environment variables from a .env file and return them as a dictionary.""" try: # Read the .env file @@ -150,7 +152,7 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict: return {} -def tree_copy(source, destination): +def tree_copy(source: Path, destination: Path) -> None: """Copies the entire directory structure from the source to the destination.""" for item in os.listdir(source): source_item = os.path.join(source, item) @@ -161,7 +163,7 @@ def tree_copy(source, destination): shutil.copy2(source_item, destination_item) -def tree_find_and_replace(directory, find, replace): +def tree_find_and_replace(directory: Path, find: str, replace: str) -> None: """Recursively searches through a directory, replacing a target string in both file contents and filenames with a specified replacement string. """ @@ -187,7 +189,7 @@ def tree_find_and_replace(directory, find, replace): os.rename(old_dirpath, new_dirpath) -def load_env_vars(folder_path): +def load_env_vars(folder_path: Path) -> dict[str, Any]: """ Loads environment variables from a .env file in the specified folder path. @@ -208,7 +210,9 @@ def load_env_vars(folder_path): return env_vars -def update_env_vars(env_vars, provider, model): +def update_env_vars( + env_vars: dict[str, Any], provider: str, model: str +) -> dict[str, Any] | None: """ Updates environment variables with the API key for the selected provider and model. @@ -220,15 +224,20 @@ def update_env_vars(env_vars, provider, model): Returns: - None """ - api_key_var = ENV_VARS.get( - provider, - [ - click.prompt( - f"Enter the environment variable name for your {provider.capitalize()} API key", - type=str, - ) - ], - )[0] + provider_config = cast( + list[str], + ENV_VARS.get( + provider, + [ + click.prompt( + f"Enter the environment variable name for your {provider.capitalize()} API key", + type=str, + ) + ], + ), + ) + + api_key_var = provider_config[0] if api_key_var not in env_vars: try: @@ -246,7 +255,7 @@ def update_env_vars(env_vars, provider, model): return env_vars -def write_env_file(folder_path, env_vars): +def write_env_file(folder_path: Path, env_vars: dict[str, Any]) -> None: """ Writes environment variables to a .env file in the specified folder. @@ -342,18 +351,18 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]: return crew_instances -def get_crew_instance(module_attr) -> Crew | None: +def get_crew_instance(module_attr: Any) -> Crew | None: if ( callable(module_attr) and hasattr(module_attr, "is_crew_class") and module_attr.is_crew_class ): - return module_attr().crew() + return cast(Crew, module_attr().crew()) try: if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints( module_attr ).get("return") is Crew: - return module_attr() + return cast(Crew, module_attr()) except Exception: return None @@ -362,7 +371,7 @@ def get_crew_instance(module_attr) -> Crew | None: return None -def fetch_crews(module_attr) -> list[Crew]: +def fetch_crews(module_attr: Any) -> list[Crew]: crew_instances: list[Crew] = [] if crew_instance := get_crew_instance(module_attr): @@ -377,7 +386,7 @@ def fetch_crews(module_attr) -> list[Crew]: return crew_instances -def is_valid_tool(obj): +def is_valid_tool(obj: Any) -> bool: from crewai.tools.base_tool import Tool if isclass(obj): @@ -389,7 +398,7 @@ def is_valid_tool(obj): return isinstance(obj, Tool) -def extract_available_exports(dir_path: str = "src"): +def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]: """ Extract available tool classes from the project's __init__.py files. Only includes classes that inherit from BaseTool or functions decorated with @tool. @@ -419,7 +428,9 @@ def extract_available_exports(dir_path: str = "src"): raise SystemExit(1) from e -def build_env_with_tool_repository_credentials(repository_handle: str): +def build_env_with_tool_repository_credentials( + repository_handle: str, +) -> dict[str, Any]: repository_handle = repository_handle.upper().replace("-", "_") settings = Settings() @@ -472,7 +483,7 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]: sys.modules.pop("temp_module", None) -def _print_no_tools_warning(): +def _print_no_tools_warning() -> None: """ Display warning and usage instructions if no tools were found. """ diff --git a/lib/crewai/tests/cli/authentication/providers/test_okta.py b/lib/crewai/tests/cli/authentication/providers/test_okta.py index 5ceb441bf..5108b1bb6 100644 --- a/lib/crewai/tests/cli/authentication/providers/test_okta.py +++ b/lib/crewai/tests/cli/authentication/providers/test_okta.py @@ -37,6 +37,36 @@ class TestOktaProvider: provider = OktaProvider(settings) expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize" assert provider.get_authorize_url() == expected_url + + def test_get_authorize_url_with_custom_authorization_server_name(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": False, + "authorization_server_name": "my_auth_server_xxxAAA777" + } + ) + provider = OktaProvider(settings) + expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/device/authorize" + assert provider.get_authorize_url() == expected_url + + def test_get_authorize_url_when_using_org_auth_server(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": True, + "authorization_server_name": None + } + ) + provider = OktaProvider(settings) + expected_url = "https://test-domain.okta.com/oauth2/v1/device/authorize" + assert provider.get_authorize_url() == expected_url def test_get_token_url(self): expected_url = "https://test-domain.okta.com/oauth2/default/v1/token" @@ -53,6 +83,36 @@ class TestOktaProvider: expected_url = "https://another-domain.okta.com/oauth2/default/v1/token" assert provider.get_token_url() == expected_url + def test_get_token_url_with_custom_authorization_server_name(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": False, + "authorization_server_name": "my_auth_server_xxxAAA777" + } + ) + provider = OktaProvider(settings) + expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/token" + assert provider.get_token_url() == expected_url + + def test_get_token_url_when_using_org_auth_server(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": True, + "authorization_server_name": None + } + ) + provider = OktaProvider(settings) + expected_url = "https://test-domain.okta.com/oauth2/v1/token" + assert provider.get_token_url() == expected_url + def test_get_jwks_url(self): expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys" assert self.provider.get_jwks_url() == expected_url @@ -68,6 +128,36 @@ class TestOktaProvider: expected_url = "https://dev.okta.com/oauth2/default/v1/keys" assert provider.get_jwks_url() == expected_url + def test_get_jwks_url_with_custom_authorization_server_name(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": False, + "authorization_server_name": "my_auth_server_xxxAAA777" + } + ) + provider = OktaProvider(settings) + expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/keys" + assert provider.get_jwks_url() == expected_url + + def test_get_jwks_url_when_using_org_auth_server(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": True, + "authorization_server_name": None + } + ) + provider = OktaProvider(settings) + expected_url = "https://test-domain.okta.com/oauth2/v1/keys" + assert provider.get_jwks_url() == expected_url + def test_get_issuer(self): expected_issuer = "https://test-domain.okta.com/oauth2/default" assert self.provider.get_issuer() == expected_issuer @@ -83,6 +173,36 @@ class TestOktaProvider: expected_issuer = "https://prod.okta.com/oauth2/default" assert provider.get_issuer() == expected_issuer + def test_get_issuer_with_custom_authorization_server_name(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": False, + "authorization_server_name": "my_auth_server_xxxAAA777" + } + ) + provider = OktaProvider(settings) + expected_issuer = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777" + assert provider.get_issuer() == expected_issuer + + def test_get_issuer_when_using_org_auth_server(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": True, + "authorization_server_name": None + } + ) + provider = OktaProvider(settings) + expected_issuer = "https://test-domain.okta.com" + assert provider.get_issuer() == expected_issuer + def test_get_audience(self): assert self.provider.get_audience() == "test-audience" @@ -100,3 +220,38 @@ class TestOktaProvider: def test_get_client_id(self): assert self.provider.get_client_id() == "test-client-id" + + def test_get_required_fields(self): + assert set(self.provider.get_required_fields()) == set(["authorization_server_name", "using_org_auth_server"]) + + def test_oauth2_base_url(self): + assert self.provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/default" + + def test_oauth2_base_url_with_custom_authorization_server_name(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": False, + "authorization_server_name": "my_auth_server_xxxAAA777" + } + ) + + provider = OktaProvider(settings) + assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777" + + def test_oauth2_base_url_when_using_org_auth_server(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": True, + "authorization_server_name": None + } + ) + provider = OktaProvider(settings) + assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2" \ No newline at end of file diff --git a/lib/crewai/tests/cli/enterprise/test_main.py b/lib/crewai/tests/cli/enterprise/test_main.py index 559aaaa14..e6be4e006 100644 --- a/lib/crewai/tests/cli/enterprise/test_main.py +++ b/lib/crewai/tests/cli/enterprise/test_main.py @@ -37,7 +37,8 @@ class TestEnterpriseConfigureCommand(unittest.TestCase): 'audience': 'test_audience', 'domain': 'test.domain.com', 'device_authorization_client_id': 'test_client_id', - 'provider': 'workos' + 'provider': 'workos', + 'extra': {} } mock_requests_get.return_value = mock_response @@ -60,11 +61,12 @@ class TestEnterpriseConfigureCommand(unittest.TestCase): ('oauth2_provider', 'workos'), ('oauth2_audience', 'test_audience'), ('oauth2_client_id', 'test_client_id'), - ('oauth2_domain', 'test.domain.com') + ('oauth2_domain', 'test.domain.com'), + ('oauth2_extra', {}) ] actual_calls = self.mock_settings_command.set.call_args_list - self.assertEqual(len(actual_calls), 5) + self.assertEqual(len(actual_calls), 6) for i, (key, value) in enumerate(expected_calls): call_args = actual_calls[i][0]