mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
feat: fetch and store more data about okta authorization server (#3894)
Some checks failed
Some checks failed
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||||
import webbrowser
|
import webbrowser
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -13,6 +13,8 @@ from crewai.cli.shared.token_manager import TokenManager
|
|||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
TOauth2Settings = TypeVar("TOauth2Settings", bound="Oauth2Settings")
|
||||||
|
|
||||||
|
|
||||||
class Oauth2Settings(BaseModel):
|
class Oauth2Settings(BaseModel):
|
||||||
provider: str = Field(
|
provider: str = Field(
|
||||||
@@ -28,9 +30,15 @@ class Oauth2Settings(BaseModel):
|
|||||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
extra: dict[str, Any] = Field(
|
||||||
|
description="Extra configuration for the OAuth2 provider.",
|
||||||
|
default={},
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_settings(cls):
|
def from_settings(cls: type[TOauth2Settings]) -> TOauth2Settings:
|
||||||
|
"""Create an Oauth2Settings instance from the CLI settings."""
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
@@ -38,12 +46,20 @@ class Oauth2Settings(BaseModel):
|
|||||||
domain=settings.oauth2_domain,
|
domain=settings.oauth2_domain,
|
||||||
client_id=settings.oauth2_client_id,
|
client_id=settings.oauth2_client_id,
|
||||||
audience=settings.oauth2_audience,
|
audience=settings.oauth2_audience,
|
||||||
|
extra=settings.oauth2_extra,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai.cli.authentication.providers.base_provider import BaseProvider
|
||||||
|
|
||||||
|
|
||||||
class ProviderFactory:
|
class ProviderFactory:
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_settings(cls, settings: Oauth2Settings | None = None):
|
def from_settings(
|
||||||
|
cls: type["ProviderFactory"], # noqa: UP037
|
||||||
|
settings: Oauth2Settings | None = None,
|
||||||
|
) -> "BaseProvider": # noqa: UP037
|
||||||
settings = settings or Oauth2Settings.from_settings()
|
settings = settings or Oauth2Settings.from_settings()
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
@@ -53,11 +69,11 @@ class ProviderFactory:
|
|||||||
)
|
)
|
||||||
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
|
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
|
||||||
|
|
||||||
return provider(settings)
|
return cast("BaseProvider", provider(settings))
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationCommand:
|
class AuthenticationCommand:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.token_manager = TokenManager()
|
self.token_manager = TokenManager()
|
||||||
self.oauth2_provider = ProviderFactory.from_settings()
|
self.oauth2_provider = ProviderFactory.from_settings()
|
||||||
|
|
||||||
@@ -84,7 +100,7 @@ class AuthenticationCommand:
|
|||||||
timeout=20,
|
timeout=20,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return cast(dict[str, Any], response.json())
|
||||||
|
|
||||||
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
|
def _display_auth_instructions(self, device_code_data: dict[str, str]) -> None:
|
||||||
"""Display the authentication instructions to the user."""
|
"""Display the authentication instructions to the user."""
|
||||||
|
|||||||
@@ -24,3 +24,7 @@ class BaseProvider(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_client_id(self) -> str: ...
|
def get_client_id(self) -> str: ...
|
||||||
|
|
||||||
|
def get_required_fields(self) -> list[str]:
|
||||||
|
"""Returns which provider-specific fields inside the "extra" dict will be required"""
|
||||||
|
return []
|
||||||
|
|||||||
@@ -3,16 +3,16 @@ from crewai.cli.authentication.providers.base_provider import BaseProvider
|
|||||||
|
|
||||||
class OktaProvider(BaseProvider):
|
class OktaProvider(BaseProvider):
|
||||||
def get_authorize_url(self) -> str:
|
def get_authorize_url(self) -> str:
|
||||||
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
|
return f"{self._oauth2_base_url()}/v1/device/authorize"
|
||||||
|
|
||||||
def get_token_url(self) -> str:
|
def get_token_url(self) -> str:
|
||||||
return f"https://{self.settings.domain}/oauth2/default/v1/token"
|
return f"{self._oauth2_base_url()}/v1/token"
|
||||||
|
|
||||||
def get_jwks_url(self) -> str:
|
def get_jwks_url(self) -> str:
|
||||||
return f"https://{self.settings.domain}/oauth2/default/v1/keys"
|
return f"{self._oauth2_base_url()}/v1/keys"
|
||||||
|
|
||||||
def get_issuer(self) -> str:
|
def get_issuer(self) -> str:
|
||||||
return f"https://{self.settings.domain}/oauth2/default"
|
return self._oauth2_base_url().removesuffix("/oauth2")
|
||||||
|
|
||||||
def get_audience(self) -> str:
|
def get_audience(self) -> str:
|
||||||
if self.settings.audience is None:
|
if self.settings.audience is None:
|
||||||
@@ -27,3 +27,16 @@ class OktaProvider(BaseProvider):
|
|||||||
"Client ID is required. Please set it in the configuration."
|
"Client ID is required. Please set it in the configuration."
|
||||||
)
|
)
|
||||||
return self.settings.client_id
|
return self.settings.client_id
|
||||||
|
|
||||||
|
def get_required_fields(self) -> list[str]:
|
||||||
|
return ["authorization_server_name", "using_org_auth_server"]
|
||||||
|
|
||||||
|
def _oauth2_base_url(self) -> str:
|
||||||
|
using_org_auth_server = self.settings.extra.get("using_org_auth_server", False)
|
||||||
|
|
||||||
|
if using_org_auth_server:
|
||||||
|
base_url = f"https://{self.settings.domain}/oauth2"
|
||||||
|
else:
|
||||||
|
base_url = f"https://{self.settings.domain}/oauth2/{self.settings.extra.get('authorization_server_name', 'default')}"
|
||||||
|
|
||||||
|
return f"{base_url}"
|
||||||
|
|||||||
@@ -11,18 +11,18 @@ console = Console()
|
|||||||
|
|
||||||
|
|
||||||
class BaseCommand:
|
class BaseCommand:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self._telemetry = Telemetry()
|
self._telemetry = Telemetry()
|
||||||
self._telemetry.set_tracer()
|
self._telemetry.set_tracer()
|
||||||
|
|
||||||
|
|
||||||
class PlusAPIMixin:
|
class PlusAPIMixin:
|
||||||
def __init__(self, telemetry):
|
def __init__(self, telemetry: Telemetry) -> None:
|
||||||
try:
|
try:
|
||||||
telemetry.set_tracer()
|
telemetry.set_tracer()
|
||||||
self.plus_api_client = PlusAPI(api_key=get_auth_token())
|
self.plus_api_client = PlusAPI(api_key=get_auth_token())
|
||||||
except Exception:
|
except Exception:
|
||||||
self._deploy_signup_error_span = telemetry.deploy_signup_error_span()
|
telemetry.deploy_signup_error_span()
|
||||||
console.print(
|
console.print(
|
||||||
"Please sign up/login to CrewAI+ before using the CLI.",
|
"Please sign up/login to CrewAI+ before using the CLI.",
|
||||||
style="bold red",
|
style="bold red",
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import json
|
|||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -136,7 +137,12 @@ class Settings(BaseModel):
|
|||||||
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
|
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, config_path: Path | None = None, **data):
|
oauth2_extra: dict[str, Any] = Field(
|
||||||
|
description="Extra configuration for the OAuth2 provider.",
|
||||||
|
default={},
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, config_path: Path | None = None, **data: dict[str, Any]) -> None:
|
||||||
"""Load Settings from config path with fallback support"""
|
"""Load Settings from config path with fallback support"""
|
||||||
if config_path is None:
|
if config_path is None:
|
||||||
config_path = get_writable_config_path()
|
config_path = get_writable_config_path()
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from requests.exceptions import JSONDecodeError, RequestException
|
from requests.exceptions import JSONDecodeError, RequestException
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
|
from crewai.cli.authentication.main import Oauth2Settings, ProviderFactory
|
||||||
from crewai.cli.command import BaseCommand
|
from crewai.cli.command import BaseCommand
|
||||||
from crewai.cli.settings.main import SettingsCommand
|
from crewai.cli.settings.main import SettingsCommand
|
||||||
from crewai.cli.version import get_crewai_version
|
from crewai.cli.version import get_crewai_version
|
||||||
@@ -13,7 +14,7 @@ console = Console()
|
|||||||
|
|
||||||
|
|
||||||
class EnterpriseConfigureCommand(BaseCommand):
|
class EnterpriseConfigureCommand(BaseCommand):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.settings_command = SettingsCommand()
|
self.settings_command = SettingsCommand()
|
||||||
|
|
||||||
@@ -54,25 +55,12 @@ class EnterpriseConfigureCommand(BaseCommand):
|
|||||||
except JSONDecodeError as e:
|
except JSONDecodeError as e:
|
||||||
raise ValueError(f"Invalid JSON response from {oauth_endpoint}") from e
|
raise ValueError(f"Invalid JSON response from {oauth_endpoint}") from e
|
||||||
|
|
||||||
required_fields = [
|
self._validate_oauth_config(oauth_config)
|
||||||
"audience",
|
|
||||||
"domain",
|
|
||||||
"device_authorization_client_id",
|
|
||||||
"provider",
|
|
||||||
]
|
|
||||||
missing_fields = [
|
|
||||||
field for field in required_fields if field not in oauth_config
|
|
||||||
]
|
|
||||||
|
|
||||||
if missing_fields:
|
|
||||||
raise ValueError(
|
|
||||||
f"Missing required fields in OAuth2 configuration: {', '.join(missing_fields)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
console.print(
|
console.print(
|
||||||
"✅ Successfully retrieved OAuth2 configuration", style="green"
|
"✅ Successfully retrieved OAuth2 configuration", style="green"
|
||||||
)
|
)
|
||||||
return oauth_config
|
return cast(dict[str, Any], oauth_config)
|
||||||
|
|
||||||
except RequestException as e:
|
except RequestException as e:
|
||||||
raise ValueError(f"Failed to connect to enterprise URL: {e!s}") from e
|
raise ValueError(f"Failed to connect to enterprise URL: {e!s}") from e
|
||||||
@@ -89,6 +77,7 @@ class EnterpriseConfigureCommand(BaseCommand):
|
|||||||
"oauth2_audience": oauth_config["audience"],
|
"oauth2_audience": oauth_config["audience"],
|
||||||
"oauth2_client_id": oauth_config["device_authorization_client_id"],
|
"oauth2_client_id": oauth_config["device_authorization_client_id"],
|
||||||
"oauth2_domain": oauth_config["domain"],
|
"oauth2_domain": oauth_config["domain"],
|
||||||
|
"oauth2_extra": oauth_config["extra"],
|
||||||
}
|
}
|
||||||
|
|
||||||
console.print("🔄 Updating local OAuth2 configuration...")
|
console.print("🔄 Updating local OAuth2 configuration...")
|
||||||
@@ -99,3 +88,38 @@ class EnterpriseConfigureCommand(BaseCommand):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Failed to update OAuth2 settings: {e!s}") from e
|
raise ValueError(f"Failed to update OAuth2 settings: {e!s}") from e
|
||||||
|
|
||||||
|
def _validate_oauth_config(self, oauth_config: dict[str, Any]) -> None:
|
||||||
|
required_fields = [
|
||||||
|
"audience",
|
||||||
|
"domain",
|
||||||
|
"device_authorization_client_id",
|
||||||
|
"provider",
|
||||||
|
"extra",
|
||||||
|
]
|
||||||
|
|
||||||
|
missing_basic_fields = [
|
||||||
|
field for field in required_fields if field not in oauth_config
|
||||||
|
]
|
||||||
|
missing_provider_specific_fields = [
|
||||||
|
field
|
||||||
|
for field in self._get_provider_specific_fields(oauth_config["provider"])
|
||||||
|
if field not in oauth_config.get("extra", {})
|
||||||
|
]
|
||||||
|
|
||||||
|
if missing_basic_fields:
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing required fields in OAuth2 configuration: [{', '.join(missing_basic_fields)}]"
|
||||||
|
)
|
||||||
|
|
||||||
|
if missing_provider_specific_fields:
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing authentication provider required fields in OAuth2 configuration: [{', '.join(missing_provider_specific_fields)}] (Configured provider: '{oauth_config['provider']}')"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_provider_specific_fields(self, provider_name: str) -> list[str]:
|
||||||
|
provider = ProviderFactory.from_settings(
|
||||||
|
Oauth2Settings(provider=provider_name, client_id="dummy", domain="dummy")
|
||||||
|
)
|
||||||
|
|
||||||
|
return provider.get_required_fields()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import subprocess
|
|||||||
|
|
||||||
|
|
||||||
class Repository:
|
class Repository:
|
||||||
def __init__(self, path="."):
|
def __init__(self, path: str = ".") -> None:
|
||||||
self.path = path
|
self.path = path
|
||||||
|
|
||||||
if not self.is_git_installed():
|
if not self.is_git_installed():
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from typing import Any
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@@ -36,19 +37,21 @@ class PlusAPI:
|
|||||||
str(settings.enterprise_base_url) or DEFAULT_CREWAI_ENTERPRISE_URL
|
str(settings.enterprise_base_url) or DEFAULT_CREWAI_ENTERPRISE_URL
|
||||||
)
|
)
|
||||||
|
|
||||||
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
|
def _make_request(
|
||||||
|
self, method: str, endpoint: str, **kwargs: Any
|
||||||
|
) -> requests.Response:
|
||||||
url = urljoin(self.base_url, endpoint)
|
url = urljoin(self.base_url, endpoint)
|
||||||
session = requests.Session()
|
session = requests.Session()
|
||||||
session.trust_env = False
|
session.trust_env = False
|
||||||
return session.request(method, url, headers=self.headers, **kwargs)
|
return session.request(method, url, headers=self.headers, **kwargs)
|
||||||
|
|
||||||
def login_to_tool_repository(self):
|
def login_to_tool_repository(self) -> requests.Response:
|
||||||
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
|
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
|
||||||
|
|
||||||
def get_tool(self, handle: str):
|
def get_tool(self, handle: str) -> requests.Response:
|
||||||
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
|
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
|
||||||
|
|
||||||
def get_agent(self, handle: str):
|
def get_agent(self, handle: str) -> requests.Response:
|
||||||
return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}")
|
return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}")
|
||||||
|
|
||||||
def publish_tool(
|
def publish_tool(
|
||||||
@@ -58,8 +61,8 @@ class PlusAPI:
|
|||||||
version: str,
|
version: str,
|
||||||
description: str | None,
|
description: str | None,
|
||||||
encoded_file: str,
|
encoded_file: str,
|
||||||
available_exports: list[str] | None = None,
|
available_exports: list[dict[str, Any]] | None = None,
|
||||||
):
|
) -> requests.Response:
|
||||||
params = {
|
params = {
|
||||||
"handle": handle,
|
"handle": handle,
|
||||||
"public": is_public,
|
"public": is_public,
|
||||||
@@ -111,13 +114,13 @@ class PlusAPI:
|
|||||||
def list_crews(self) -> requests.Response:
|
def list_crews(self) -> requests.Response:
|
||||||
return self._make_request("GET", self.CREWS_RESOURCE)
|
return self._make_request("GET", self.CREWS_RESOURCE)
|
||||||
|
|
||||||
def create_crew(self, payload) -> requests.Response:
|
def create_crew(self, payload: dict[str, Any]) -> requests.Response:
|
||||||
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
|
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
|
||||||
|
|
||||||
def get_organizations(self) -> requests.Response:
|
def get_organizations(self) -> requests.Response:
|
||||||
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
|
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
|
||||||
|
|
||||||
def initialize_trace_batch(self, payload) -> requests.Response:
|
def initialize_trace_batch(self, payload: dict[str, Any]) -> requests.Response:
|
||||||
return self._make_request(
|
return self._make_request(
|
||||||
"POST",
|
"POST",
|
||||||
f"{self.TRACING_RESOURCE}/batches",
|
f"{self.TRACING_RESOURCE}/batches",
|
||||||
@@ -125,14 +128,18 @@ class PlusAPI:
|
|||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
def initialize_ephemeral_trace_batch(self, payload) -> requests.Response:
|
def initialize_ephemeral_trace_batch(
|
||||||
|
self, payload: dict[str, Any]
|
||||||
|
) -> requests.Response:
|
||||||
return self._make_request(
|
return self._make_request(
|
||||||
"POST",
|
"POST",
|
||||||
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
|
f"{self.EPHEMERAL_TRACING_RESOURCE}/batches",
|
||||||
json=payload,
|
json=payload,
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_trace_events(self, trace_batch_id: str, payload) -> requests.Response:
|
def send_trace_events(
|
||||||
|
self, trace_batch_id: str, payload: dict[str, Any]
|
||||||
|
) -> requests.Response:
|
||||||
return self._make_request(
|
return self._make_request(
|
||||||
"POST",
|
"POST",
|
||||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
|
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/events",
|
||||||
@@ -141,7 +148,7 @@ class PlusAPI:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def send_ephemeral_trace_events(
|
def send_ephemeral_trace_events(
|
||||||
self, trace_batch_id: str, payload
|
self, trace_batch_id: str, payload: dict[str, Any]
|
||||||
) -> requests.Response:
|
) -> requests.Response:
|
||||||
return self._make_request(
|
return self._make_request(
|
||||||
"POST",
|
"POST",
|
||||||
@@ -150,7 +157,9 @@ class PlusAPI:
|
|||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
def finalize_trace_batch(self, trace_batch_id: str, payload) -> requests.Response:
|
def finalize_trace_batch(
|
||||||
|
self, trace_batch_id: str, payload: dict[str, Any]
|
||||||
|
) -> requests.Response:
|
||||||
return self._make_request(
|
return self._make_request(
|
||||||
"PATCH",
|
"PATCH",
|
||||||
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
|
f"{self.TRACING_RESOURCE}/batches/{trace_batch_id}/finalize",
|
||||||
@@ -159,7 +168,7 @@ class PlusAPI:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def finalize_ephemeral_trace_batch(
|
def finalize_ephemeral_trace_batch(
|
||||||
self, trace_batch_id: str, payload
|
self, trace_batch_id: str, payload: dict[str, Any]
|
||||||
) -> requests.Response:
|
) -> requests.Response:
|
||||||
return self._make_request(
|
return self._make_request(
|
||||||
"PATCH",
|
"PATCH",
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class SettingsCommand(BaseCommand):
|
|||||||
current_value = getattr(self.settings, field_name)
|
current_value = getattr(self.settings, field_name)
|
||||||
description = field_info.description or "No description available"
|
description = field_info.description or "No description available"
|
||||||
display_value = (
|
display_value = (
|
||||||
str(current_value) if current_value is not None else "Not set"
|
str(current_value) if current_value not in [None, {}] else "Not set"
|
||||||
)
|
)
|
||||||
|
|
||||||
table.add_row(field_name, display_value, description)
|
table.add_row(field_name, display_value, description)
|
||||||
|
|||||||
@@ -30,11 +30,11 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
|||||||
A class to handle tool repository related operations for CrewAI projects.
|
A class to handle tool repository related operations for CrewAI projects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
BaseCommand.__init__(self)
|
BaseCommand.__init__(self)
|
||||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||||
|
|
||||||
def create(self, handle: str):
|
def create(self, handle: str) -> None:
|
||||||
self._ensure_not_in_project()
|
self._ensure_not_in_project()
|
||||||
|
|
||||||
folder_name = handle.replace(" ", "_").replace("-", "_").lower()
|
folder_name = handle.replace(" ", "_").replace("-", "_").lower()
|
||||||
@@ -64,7 +64,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
|||||||
finally:
|
finally:
|
||||||
os.chdir(old_directory)
|
os.chdir(old_directory)
|
||||||
|
|
||||||
def publish(self, is_public: bool, force: bool = False):
|
def publish(self, is_public: bool, force: bool = False) -> None:
|
||||||
if not git.Repository().is_synced() and not force:
|
if not git.Repository().is_synced() and not force:
|
||||||
console.print(
|
console.print(
|
||||||
"[bold red]Failed to publish tool.[/bold red]\n"
|
"[bold red]Failed to publish tool.[/bold red]\n"
|
||||||
@@ -137,7 +137,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
|||||||
style="bold green",
|
style="bold green",
|
||||||
)
|
)
|
||||||
|
|
||||||
def install(self, handle: str):
|
def install(self, handle: str) -> None:
|
||||||
self._print_current_organization()
|
self._print_current_organization()
|
||||||
get_response = self.plus_api_client.get_tool(handle)
|
get_response = self.plus_api_client.get_tool(handle)
|
||||||
|
|
||||||
@@ -180,7 +180,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
|||||||
settings.org_name = login_response_json["current_organization"]["name"]
|
settings.org_name = login_response_json["current_organization"]["name"]
|
||||||
settings.dump()
|
settings.dump()
|
||||||
|
|
||||||
def _add_package(self, tool_details: dict[str, Any]):
|
def _add_package(self, tool_details: dict[str, Any]) -> None:
|
||||||
is_from_pypi = tool_details.get("source", None) == "pypi"
|
is_from_pypi = tool_details.get("source", None) == "pypi"
|
||||||
tool_handle = tool_details["handle"]
|
tool_handle = tool_details["handle"]
|
||||||
repository_handle = tool_details["repository"]["handle"]
|
repository_handle = tool_details["repository"]["handle"]
|
||||||
@@ -209,7 +209,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
|||||||
click.echo(add_package_result.stderr, err=True)
|
click.echo(add_package_result.stderr, err=True)
|
||||||
raise SystemExit
|
raise SystemExit
|
||||||
|
|
||||||
def _ensure_not_in_project(self):
|
def _ensure_not_in_project(self) -> None:
|
||||||
if os.path.isfile("./pyproject.toml"):
|
if os.path.isfile("./pyproject.toml"):
|
||||||
console.print(
|
console.print(
|
||||||
"[bold red]Oops! It looks like you're inside a project.[/bold red]"
|
"[bold red]Oops! It looks like you're inside a project.[/bold red]"
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, get_type_hints
|
from typing import Any, cast, get_type_hints
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
@@ -23,7 +23,9 @@ if sys.version_info >= (3, 11):
|
|||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
def copy_template(src, dst, name, class_name, folder_name):
|
def copy_template(
|
||||||
|
src: Path, dst: Path, name: str, class_name: str, folder_name: str
|
||||||
|
) -> None:
|
||||||
"""Copy a file from src to dst."""
|
"""Copy a file from src to dst."""
|
||||||
with open(src, "r") as file:
|
with open(src, "r") as file:
|
||||||
content = file.read()
|
content = file.read()
|
||||||
@@ -40,13 +42,13 @@ def copy_template(src, dst, name, class_name, folder_name):
|
|||||||
click.secho(f" - Created {dst}", fg="green")
|
click.secho(f" - Created {dst}", fg="green")
|
||||||
|
|
||||||
|
|
||||||
def read_toml(file_path: str = "pyproject.toml"):
|
def read_toml(file_path: str = "pyproject.toml") -> dict[str, Any]:
|
||||||
"""Read the content of a TOML file and return it as a dictionary."""
|
"""Read the content of a TOML file and return it as a dictionary."""
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
return tomli.load(f)
|
return tomli.load(f)
|
||||||
|
|
||||||
|
|
||||||
def parse_toml(content):
|
def parse_toml(content: str) -> dict[str, Any]:
|
||||||
if sys.version_info >= (3, 11):
|
if sys.version_info >= (3, 11):
|
||||||
return tomllib.loads(content)
|
return tomllib.loads(content)
|
||||||
return tomli.loads(content)
|
return tomli.loads(content)
|
||||||
@@ -103,7 +105,7 @@ def _get_project_attribute(
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle TOML decode errors for Python 3.11+
|
# Handle TOML decode errors for Python 3.11+
|
||||||
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError): # type: ignore
|
if sys.version_info >= (3, 11) and isinstance(e, tomllib.TOMLDecodeError):
|
||||||
console.print(
|
console.print(
|
||||||
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
|
f"Error: {pyproject_path} is not a valid TOML file.", style="bold red"
|
||||||
)
|
)
|
||||||
@@ -126,7 +128,7 @@ def _get_nested_value(data: dict[str, Any], keys: list[str]) -> Any:
|
|||||||
return reduce(dict.__getitem__, keys, data)
|
return reduce(dict.__getitem__, keys, data)
|
||||||
|
|
||||||
|
|
||||||
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict[str, Any]:
|
||||||
"""Fetch the environment variables from a .env file and return them as a dictionary."""
|
"""Fetch the environment variables from a .env file and return them as a dictionary."""
|
||||||
try:
|
try:
|
||||||
# Read the .env file
|
# Read the .env file
|
||||||
@@ -150,7 +152,7 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def tree_copy(source, destination):
|
def tree_copy(source: Path, destination: Path) -> None:
|
||||||
"""Copies the entire directory structure from the source to the destination."""
|
"""Copies the entire directory structure from the source to the destination."""
|
||||||
for item in os.listdir(source):
|
for item in os.listdir(source):
|
||||||
source_item = os.path.join(source, item)
|
source_item = os.path.join(source, item)
|
||||||
@@ -161,7 +163,7 @@ def tree_copy(source, destination):
|
|||||||
shutil.copy2(source_item, destination_item)
|
shutil.copy2(source_item, destination_item)
|
||||||
|
|
||||||
|
|
||||||
def tree_find_and_replace(directory, find, replace):
|
def tree_find_and_replace(directory: Path, find: str, replace: str) -> None:
|
||||||
"""Recursively searches through a directory, replacing a target string in
|
"""Recursively searches through a directory, replacing a target string in
|
||||||
both file contents and filenames with a specified replacement string.
|
both file contents and filenames with a specified replacement string.
|
||||||
"""
|
"""
|
||||||
@@ -187,7 +189,7 @@ def tree_find_and_replace(directory, find, replace):
|
|||||||
os.rename(old_dirpath, new_dirpath)
|
os.rename(old_dirpath, new_dirpath)
|
||||||
|
|
||||||
|
|
||||||
def load_env_vars(folder_path):
|
def load_env_vars(folder_path: Path) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Loads environment variables from a .env file in the specified folder path.
|
Loads environment variables from a .env file in the specified folder path.
|
||||||
|
|
||||||
@@ -208,7 +210,9 @@ def load_env_vars(folder_path):
|
|||||||
return env_vars
|
return env_vars
|
||||||
|
|
||||||
|
|
||||||
def update_env_vars(env_vars, provider, model):
|
def update_env_vars(
|
||||||
|
env_vars: dict[str, Any], provider: str, model: str
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
"""
|
"""
|
||||||
Updates environment variables with the API key for the selected provider and model.
|
Updates environment variables with the API key for the selected provider and model.
|
||||||
|
|
||||||
@@ -220,15 +224,20 @@ def update_env_vars(env_vars, provider, model):
|
|||||||
Returns:
|
Returns:
|
||||||
- None
|
- None
|
||||||
"""
|
"""
|
||||||
api_key_var = ENV_VARS.get(
|
provider_config = cast(
|
||||||
provider,
|
list[str],
|
||||||
[
|
ENV_VARS.get(
|
||||||
click.prompt(
|
provider,
|
||||||
f"Enter the environment variable name for your {provider.capitalize()} API key",
|
[
|
||||||
type=str,
|
click.prompt(
|
||||||
)
|
f"Enter the environment variable name for your {provider.capitalize()} API key",
|
||||||
],
|
type=str,
|
||||||
)[0]
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
api_key_var = provider_config[0]
|
||||||
|
|
||||||
if api_key_var not in env_vars:
|
if api_key_var not in env_vars:
|
||||||
try:
|
try:
|
||||||
@@ -246,7 +255,7 @@ def update_env_vars(env_vars, provider, model):
|
|||||||
return env_vars
|
return env_vars
|
||||||
|
|
||||||
|
|
||||||
def write_env_file(folder_path, env_vars):
|
def write_env_file(folder_path: Path, env_vars: dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
Writes environment variables to a .env file in the specified folder.
|
Writes environment variables to a .env file in the specified folder.
|
||||||
|
|
||||||
@@ -342,18 +351,18 @@ def get_crews(crew_path: str = "crew.py", require: bool = False) -> list[Crew]:
|
|||||||
return crew_instances
|
return crew_instances
|
||||||
|
|
||||||
|
|
||||||
def get_crew_instance(module_attr) -> Crew | None:
|
def get_crew_instance(module_attr: Any) -> Crew | None:
|
||||||
if (
|
if (
|
||||||
callable(module_attr)
|
callable(module_attr)
|
||||||
and hasattr(module_attr, "is_crew_class")
|
and hasattr(module_attr, "is_crew_class")
|
||||||
and module_attr.is_crew_class
|
and module_attr.is_crew_class
|
||||||
):
|
):
|
||||||
return module_attr().crew()
|
return cast(Crew, module_attr().crew())
|
||||||
try:
|
try:
|
||||||
if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints(
|
if (ismethod(module_attr) or isfunction(module_attr)) and get_type_hints(
|
||||||
module_attr
|
module_attr
|
||||||
).get("return") is Crew:
|
).get("return") is Crew:
|
||||||
return module_attr()
|
return cast(Crew, module_attr())
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -362,7 +371,7 @@ def get_crew_instance(module_attr) -> Crew | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def fetch_crews(module_attr) -> list[Crew]:
|
def fetch_crews(module_attr: Any) -> list[Crew]:
|
||||||
crew_instances: list[Crew] = []
|
crew_instances: list[Crew] = []
|
||||||
|
|
||||||
if crew_instance := get_crew_instance(module_attr):
|
if crew_instance := get_crew_instance(module_attr):
|
||||||
@@ -377,7 +386,7 @@ def fetch_crews(module_attr) -> list[Crew]:
|
|||||||
return crew_instances
|
return crew_instances
|
||||||
|
|
||||||
|
|
||||||
def is_valid_tool(obj):
|
def is_valid_tool(obj: Any) -> bool:
|
||||||
from crewai.tools.base_tool import Tool
|
from crewai.tools.base_tool import Tool
|
||||||
|
|
||||||
if isclass(obj):
|
if isclass(obj):
|
||||||
@@ -389,7 +398,7 @@ def is_valid_tool(obj):
|
|||||||
return isinstance(obj, Tool)
|
return isinstance(obj, Tool)
|
||||||
|
|
||||||
|
|
||||||
def extract_available_exports(dir_path: str = "src"):
|
def extract_available_exports(dir_path: str = "src") -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Extract available tool classes from the project's __init__.py files.
|
Extract available tool classes from the project's __init__.py files.
|
||||||
Only includes classes that inherit from BaseTool or functions decorated with @tool.
|
Only includes classes that inherit from BaseTool or functions decorated with @tool.
|
||||||
@@ -419,7 +428,9 @@ def extract_available_exports(dir_path: str = "src"):
|
|||||||
raise SystemExit(1) from e
|
raise SystemExit(1) from e
|
||||||
|
|
||||||
|
|
||||||
def build_env_with_tool_repository_credentials(repository_handle: str):
|
def build_env_with_tool_repository_credentials(
|
||||||
|
repository_handle: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
repository_handle = repository_handle.upper().replace("-", "_")
|
repository_handle = repository_handle.upper().replace("-", "_")
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
@@ -472,7 +483,7 @@ def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
|||||||
sys.modules.pop("temp_module", None)
|
sys.modules.pop("temp_module", None)
|
||||||
|
|
||||||
|
|
||||||
def _print_no_tools_warning():
|
def _print_no_tools_warning() -> None:
|
||||||
"""
|
"""
|
||||||
Display warning and usage instructions if no tools were found.
|
Display warning and usage instructions if no tools were found.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -37,6 +37,36 @@ class TestOktaProvider:
|
|||||||
provider = OktaProvider(settings)
|
provider = OktaProvider(settings)
|
||||||
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
|
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
|
||||||
assert provider.get_authorize_url() == expected_url
|
assert provider.get_authorize_url() == expected_url
|
||||||
|
|
||||||
|
def test_get_authorize_url_with_custom_authorization_server_name(self):
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="okta",
|
||||||
|
domain="test-domain.okta.com",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience=None,
|
||||||
|
extra={
|
||||||
|
"using_org_auth_server": False,
|
||||||
|
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
provider = OktaProvider(settings)
|
||||||
|
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/device/authorize"
|
||||||
|
assert provider.get_authorize_url() == expected_url
|
||||||
|
|
||||||
|
def test_get_authorize_url_when_using_org_auth_server(self):
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="okta",
|
||||||
|
domain="test-domain.okta.com",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience=None,
|
||||||
|
extra={
|
||||||
|
"using_org_auth_server": True,
|
||||||
|
"authorization_server_name": None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
provider = OktaProvider(settings)
|
||||||
|
expected_url = "https://test-domain.okta.com/oauth2/v1/device/authorize"
|
||||||
|
assert provider.get_authorize_url() == expected_url
|
||||||
|
|
||||||
def test_get_token_url(self):
|
def test_get_token_url(self):
|
||||||
expected_url = "https://test-domain.okta.com/oauth2/default/v1/token"
|
expected_url = "https://test-domain.okta.com/oauth2/default/v1/token"
|
||||||
@@ -53,6 +83,36 @@ class TestOktaProvider:
|
|||||||
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
|
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
|
||||||
assert provider.get_token_url() == expected_url
|
assert provider.get_token_url() == expected_url
|
||||||
|
|
||||||
|
def test_get_token_url_with_custom_authorization_server_name(self):
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="okta",
|
||||||
|
domain="test-domain.okta.com",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience=None,
|
||||||
|
extra={
|
||||||
|
"using_org_auth_server": False,
|
||||||
|
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
provider = OktaProvider(settings)
|
||||||
|
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/token"
|
||||||
|
assert provider.get_token_url() == expected_url
|
||||||
|
|
||||||
|
def test_get_token_url_when_using_org_auth_server(self):
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="okta",
|
||||||
|
domain="test-domain.okta.com",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience=None,
|
||||||
|
extra={
|
||||||
|
"using_org_auth_server": True,
|
||||||
|
"authorization_server_name": None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
provider = OktaProvider(settings)
|
||||||
|
expected_url = "https://test-domain.okta.com/oauth2/v1/token"
|
||||||
|
assert provider.get_token_url() == expected_url
|
||||||
|
|
||||||
def test_get_jwks_url(self):
|
def test_get_jwks_url(self):
|
||||||
expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys"
|
expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys"
|
||||||
assert self.provider.get_jwks_url() == expected_url
|
assert self.provider.get_jwks_url() == expected_url
|
||||||
@@ -68,6 +128,36 @@ class TestOktaProvider:
|
|||||||
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
|
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
|
||||||
assert provider.get_jwks_url() == expected_url
|
assert provider.get_jwks_url() == expected_url
|
||||||
|
|
||||||
|
def test_get_jwks_url_with_custom_authorization_server_name(self):
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="okta",
|
||||||
|
domain="test-domain.okta.com",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience=None,
|
||||||
|
extra={
|
||||||
|
"using_org_auth_server": False,
|
||||||
|
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
provider = OktaProvider(settings)
|
||||||
|
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/keys"
|
||||||
|
assert provider.get_jwks_url() == expected_url
|
||||||
|
|
||||||
|
def test_get_jwks_url_when_using_org_auth_server(self):
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="okta",
|
||||||
|
domain="test-domain.okta.com",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience=None,
|
||||||
|
extra={
|
||||||
|
"using_org_auth_server": True,
|
||||||
|
"authorization_server_name": None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
provider = OktaProvider(settings)
|
||||||
|
expected_url = "https://test-domain.okta.com/oauth2/v1/keys"
|
||||||
|
assert provider.get_jwks_url() == expected_url
|
||||||
|
|
||||||
def test_get_issuer(self):
|
def test_get_issuer(self):
|
||||||
expected_issuer = "https://test-domain.okta.com/oauth2/default"
|
expected_issuer = "https://test-domain.okta.com/oauth2/default"
|
||||||
assert self.provider.get_issuer() == expected_issuer
|
assert self.provider.get_issuer() == expected_issuer
|
||||||
@@ -83,6 +173,36 @@ class TestOktaProvider:
|
|||||||
expected_issuer = "https://prod.okta.com/oauth2/default"
|
expected_issuer = "https://prod.okta.com/oauth2/default"
|
||||||
assert provider.get_issuer() == expected_issuer
|
assert provider.get_issuer() == expected_issuer
|
||||||
|
|
||||||
|
def test_get_issuer_with_custom_authorization_server_name(self):
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="okta",
|
||||||
|
domain="test-domain.okta.com",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience=None,
|
||||||
|
extra={
|
||||||
|
"using_org_auth_server": False,
|
||||||
|
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
provider = OktaProvider(settings)
|
||||||
|
expected_issuer = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
|
||||||
|
assert provider.get_issuer() == expected_issuer
|
||||||
|
|
||||||
|
def test_get_issuer_when_using_org_auth_server(self):
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="okta",
|
||||||
|
domain="test-domain.okta.com",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience=None,
|
||||||
|
extra={
|
||||||
|
"using_org_auth_server": True,
|
||||||
|
"authorization_server_name": None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
provider = OktaProvider(settings)
|
||||||
|
expected_issuer = "https://test-domain.okta.com"
|
||||||
|
assert provider.get_issuer() == expected_issuer
|
||||||
|
|
||||||
def test_get_audience(self):
|
def test_get_audience(self):
|
||||||
assert self.provider.get_audience() == "test-audience"
|
assert self.provider.get_audience() == "test-audience"
|
||||||
|
|
||||||
@@ -100,3 +220,38 @@ class TestOktaProvider:
|
|||||||
|
|
||||||
def test_get_client_id(self):
|
def test_get_client_id(self):
|
||||||
assert self.provider.get_client_id() == "test-client-id"
|
assert self.provider.get_client_id() == "test-client-id"
|
||||||
|
|
||||||
|
def test_get_required_fields(self):
|
||||||
|
assert set(self.provider.get_required_fields()) == set(["authorization_server_name", "using_org_auth_server"])
|
||||||
|
|
||||||
|
def test_oauth2_base_url(self):
|
||||||
|
assert self.provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/default"
|
||||||
|
|
||||||
|
def test_oauth2_base_url_with_custom_authorization_server_name(self):
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="okta",
|
||||||
|
domain="test-domain.okta.com",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience=None,
|
||||||
|
extra={
|
||||||
|
"using_org_auth_server": False,
|
||||||
|
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = OktaProvider(settings)
|
||||||
|
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
|
||||||
|
|
||||||
|
def test_oauth2_base_url_when_using_org_auth_server(self):
|
||||||
|
settings = Oauth2Settings(
|
||||||
|
provider="okta",
|
||||||
|
domain="test-domain.okta.com",
|
||||||
|
client_id="test-client-id",
|
||||||
|
audience=None,
|
||||||
|
extra={
|
||||||
|
"using_org_auth_server": True,
|
||||||
|
"authorization_server_name": None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
provider = OktaProvider(settings)
|
||||||
|
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2"
|
||||||
@@ -37,7 +37,8 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
|
|||||||
'audience': 'test_audience',
|
'audience': 'test_audience',
|
||||||
'domain': 'test.domain.com',
|
'domain': 'test.domain.com',
|
||||||
'device_authorization_client_id': 'test_client_id',
|
'device_authorization_client_id': 'test_client_id',
|
||||||
'provider': 'workos'
|
'provider': 'workos',
|
||||||
|
'extra': {}
|
||||||
}
|
}
|
||||||
mock_requests_get.return_value = mock_response
|
mock_requests_get.return_value = mock_response
|
||||||
|
|
||||||
@@ -60,11 +61,12 @@ class TestEnterpriseConfigureCommand(unittest.TestCase):
|
|||||||
('oauth2_provider', 'workos'),
|
('oauth2_provider', 'workos'),
|
||||||
('oauth2_audience', 'test_audience'),
|
('oauth2_audience', 'test_audience'),
|
||||||
('oauth2_client_id', 'test_client_id'),
|
('oauth2_client_id', 'test_client_id'),
|
||||||
('oauth2_domain', 'test.domain.com')
|
('oauth2_domain', 'test.domain.com'),
|
||||||
|
('oauth2_extra', {})
|
||||||
]
|
]
|
||||||
|
|
||||||
actual_calls = self.mock_settings_command.set.call_args_list
|
actual_calls = self.mock_settings_command.set.call_args_list
|
||||||
self.assertEqual(len(actual_calls), 5)
|
self.assertEqual(len(actual_calls), 6)
|
||||||
|
|
||||||
for i, (key, value) in enumerate(expected_calls):
|
for i, (key, value) in enumerate(expected_calls):
|
||||||
call_args = actual_calls[i][0]
|
call_args = actual_calls[i][0]
|
||||||
|
|||||||
Reference in New Issue
Block a user