chore: apply ruff linting fixes to CLI module

fix: apply ruff fixes to CLI and update Okta provider test
This commit is contained in:
Greyson LaLonde
2025-09-19 19:55:55 -04:00
committed by GitHub
parent de5d3c3ad1
commit f4abc41235
18 changed files with 207 additions and 168 deletions

View File

@@ -1,5 +1,6 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
class Auth0Provider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth/device/code"
@@ -14,13 +15,20 @@ class Auth0Provider(BaseProvider):
return f"https://{self._get_domain()}/"
def get_audience(self) -> str:
assert self.settings.audience is not None, "Audience is required"
if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience
def get_client_id(self) -> str:
assert self.settings.client_id is not None, "Client ID is required"
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def _get_domain(self) -> str:
assert self.settings.domain is not None, "Domain is required"
if self.settings.domain is None:
raise ValueError("Domain is required. Please set it in the configuration.")
return self.settings.domain

View File

@@ -1,30 +1,26 @@
from abc import ABC, abstractmethod
from crewai.cli.authentication.main import Oauth2Settings
class BaseProvider(ABC):
def __init__(self, settings: Oauth2Settings):
self.settings = settings
@abstractmethod
def get_authorize_url(self) -> str:
...
def get_authorize_url(self) -> str: ...
@abstractmethod
def get_token_url(self) -> str:
...
def get_token_url(self) -> str: ...
@abstractmethod
def get_jwks_url(self) -> str:
...
def get_jwks_url(self) -> str: ...
@abstractmethod
def get_issuer(self) -> str:
...
def get_issuer(self) -> str: ...
@abstractmethod
def get_audience(self) -> str:
...
def get_audience(self) -> str: ...
@abstractmethod
def get_client_id(self) -> str:
...
def get_client_id(self) -> str: ...

View File

@@ -1,5 +1,6 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
class OktaProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self.settings.domain}/oauth2/default/v1/device/authorize"
@@ -14,9 +15,15 @@ class OktaProvider(BaseProvider):
return f"https://{self.settings.domain}/oauth2/default"
def get_audience(self) -> str:
assert self.settings.audience is not None
if self.settings.audience is None:
raise ValueError(
"Audience is required. Please set it in the configuration."
)
return self.settings.audience
def get_client_id(self) -> str:
assert self.settings.client_id is not None
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id

View File

@@ -1,5 +1,6 @@
from crewai.cli.authentication.providers.base_provider import BaseProvider
class WorkosProvider(BaseProvider):
def get_authorize_url(self) -> str:
return f"https://{self._get_domain()}/oauth2/device_authorization"
@@ -17,9 +18,13 @@ class WorkosProvider(BaseProvider):
return self.settings.audience or ""
def get_client_id(self) -> str:
assert self.settings.client_id is not None, "Client ID is required"
if self.settings.client_id is None:
raise ValueError(
"Client ID is required. Please set it in the configuration."
)
return self.settings.client_id
def _get_domain(self) -> str:
assert self.settings.domain is not None, "Domain is required"
if self.settings.domain is None:
raise ValueError("Domain is required. Please set it in the configuration.")
return self.settings.domain

View File

@@ -17,8 +17,6 @@ def validate_jwt_token(
missing required claims).
"""
decoded_token = None
try:
jwk_client = PyJWKClient(jwks_url)
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
@@ -26,7 +24,7 @@ def validate_jwt_token(
_unverified_decoded_token = jwt.decode(
jwt_token, options={"verify_signature": False}
)
decoded_token = jwt.decode(
return jwt.decode(
jwt_token,
signing_key.key,
algorithms=["RS256"],
@@ -40,23 +38,22 @@ def validate_jwt_token(
"require": ["exp", "iat", "iss", "aud", "sub"],
},
)
return decoded_token
except jwt.ExpiredSignatureError:
raise Exception("Token has expired.")
except jwt.InvalidAudienceError:
except jwt.ExpiredSignatureError as e:
raise Exception("Token has expired.") from e
except jwt.InvalidAudienceError as e:
actual_audience = _unverified_decoded_token.get("aud", "[no audience found]")
raise Exception(
f"Invalid token audience. Got: '{actual_audience}'. Expected: '{audience}'"
)
except jwt.InvalidIssuerError:
) from e
except jwt.InvalidIssuerError as e:
actual_issuer = _unverified_decoded_token.get("iss", "[no issuer found]")
raise Exception(
f"Invalid token issuer. Got: '{actual_issuer}'. Expected: '{issuer}'"
)
) from e
except jwt.MissingRequiredClaimError as e:
raise Exception(f"Token is missing required claims: {str(e)}")
raise Exception(f"Token is missing required claims: {e!s}") from e
except jwt.exceptions.PyJWKClientError as e:
raise Exception(f"JWKS or key processing error: {str(e)}")
raise Exception(f"JWKS or key processing error: {e!s}") from e
except jwt.InvalidTokenError as e:
raise Exception(f"Invalid token: {str(e)}")
raise Exception(f"Invalid token: {e!s}") from e