mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-07 10:12:38 +00:00
Rewrite auth to use PKCE flow
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -1,9 +1,22 @@
|
|||||||
from .utils import TokenManager
|
from .utils import TokenManager, get_auth_token_with_refresh_token
|
||||||
|
|
||||||
|
|
||||||
def get_auth_token() -> str:
|
def get_auth_token() -> str:
|
||||||
"""Get the authentication token."""
|
"""Get the authentication token. Uses refresh token to fetch a new token if current one is expired."""
|
||||||
access_token = TokenManager().get_token()
|
access_token = TokenManager().get_token("access_token")
|
||||||
|
refresh_token = TokenManager().get_token("refresh_token")
|
||||||
|
|
||||||
|
# Token could be expired, so we use the refresh token to fetch a new one.
|
||||||
|
# Skip if refresh token is not available.
|
||||||
|
if not access_token and refresh_token:
|
||||||
|
data = get_auth_token_with_refresh_token(refresh_token)
|
||||||
|
access_token = data.get("access_token")
|
||||||
|
refresh_token = data.get("refresh_token")
|
||||||
|
|
||||||
|
if access_token and refresh_token:
|
||||||
|
TokenManager().save_access_token(access_token, data["expires_in"])
|
||||||
|
TokenManager().save_refresh_token(refresh_token)
|
||||||
|
|
||||||
if not access_token:
|
if not access_token:
|
||||||
raise Exception()
|
raise Exception("Access token could not be obtained. Please sign in again.")
|
||||||
return access_token
|
return access_token
|
||||||
|
|||||||
@@ -5,28 +5,100 @@ from datetime import datetime, timedelta
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from auth0.authentication.token_verifier import (
|
import jwt
|
||||||
AsymmetricSignatureVerifier,
|
import requests
|
||||||
TokenVerifier,
|
|
||||||
)
|
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
|
from jwt import PyJWKClient
|
||||||
|
|
||||||
from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
|
from .constants import WORKOS_CLIENT_ID, WORKOS_DOMAIN, WORKOS_ENVIRONMENT_ID
|
||||||
|
|
||||||
|
|
||||||
def validate_token(id_token: str) -> None:
|
def get_auth_token_with_refresh_token(refresh_token: str) -> dict:
|
||||||
"""
|
"""
|
||||||
Verify the token and its precedence
|
Get an access token using a refresh token.
|
||||||
|
|
||||||
:param id_token:
|
:param refresh_token: The refresh token to use.
|
||||||
|
:return: A dictionary containing the access token, its expiration time, and a new refresh token, or an empty dictionary if the attempt to get a new access token failed.
|
||||||
"""
|
"""
|
||||||
jwks_url = f"https://{AUTH0_DOMAIN}/.well-known/jwks.json"
|
|
||||||
issuer = f"https://{AUTH0_DOMAIN}/"
|
response = requests.post(
|
||||||
signature_verifier = AsymmetricSignatureVerifier(jwks_url)
|
WORKOS_TOKEN_URL,
|
||||||
token_verifier = TokenVerifier(
|
data={
|
||||||
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": refresh_token,
|
||||||
|
"client_id": WORKOS_CLIENT_ID,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
token_verifier.verify(id_token)
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
try:
|
||||||
|
validate_token(data.get("access_token"))
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"access_token": data.get("access_token"),
|
||||||
|
"refresh_token": data.get("refresh_token"),
|
||||||
|
"expires_in": data.get("expires_in"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def validate_token(jwt_token: str, token_type: str = "access_token") -> dict:
|
||||||
|
"""
|
||||||
|
Verify the token's signature and claims using PyJWT.
|
||||||
|
|
||||||
|
:param jwt_token: The JWT (JWS) string to validate.
|
||||||
|
:return: The decoded token.
|
||||||
|
:raises Exception: If the token is invalid for any reason (e.g., signature mismatch,
|
||||||
|
expired, incorrect issuer/audience, JWKS fetching error,
|
||||||
|
missing required claims).
|
||||||
|
"""
|
||||||
|
|
||||||
|
supported_audiences = {
|
||||||
|
"access_token": WORKOS_ENVIRONMENT_ID,
|
||||||
|
"id_token": WORKOS_CLIENT_ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
jwks_url = f"https://{WORKOS_DOMAIN}/oauth2/jwks"
|
||||||
|
expected_issuer = f"https://{WORKOS_DOMAIN}"
|
||||||
|
expected_audience = supported_audiences[token_type]
|
||||||
|
decoded_token = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
jwk_client = PyJWKClient(jwks_url)
|
||||||
|
signing_key = jwk_client.get_signing_key_from_jwt(jwt_token)
|
||||||
|
|
||||||
|
decoded_token = jwt.decode(
|
||||||
|
jwt_token,
|
||||||
|
signing_key.key,
|
||||||
|
algorithms=["RS256"],
|
||||||
|
audience=expected_audience,
|
||||||
|
issuer=expected_issuer,
|
||||||
|
options={
|
||||||
|
"verify_signature": True,
|
||||||
|
"verify_exp": True,
|
||||||
|
"verify_nbf": True,
|
||||||
|
"verify_iat": True,
|
||||||
|
"require": ["exp", "iat", "iss", "aud", "sub"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return decoded_token
|
||||||
|
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
raise Exception("Token has expired.")
|
||||||
|
except jwt.InvalidAudienceError:
|
||||||
|
raise Exception(f"Invalid token audience. Expected: '{expected_audience}'")
|
||||||
|
except jwt.InvalidIssuerError:
|
||||||
|
raise Exception(f"Invalid token issuer. Expected: '{expected_issuer}'")
|
||||||
|
except jwt.MissingRequiredClaimError as e:
|
||||||
|
raise Exception(f"Token is missing required claims: {str(e)}")
|
||||||
|
except jwt.exceptions.PyJWKClientError as e:
|
||||||
|
raise Exception(f"JWKS or key processing error: {str(e)}")
|
||||||
|
except jwt.InvalidTokenError as e:
|
||||||
|
raise Exception(f"Invalid token: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
class TokenManager:
|
class TokenManager:
|
||||||
@@ -56,37 +128,43 @@ class TokenManager:
|
|||||||
self.save_secure_file(key_filename, new_key)
|
self.save_secure_file(key_filename, new_key)
|
||||||
return new_key
|
return new_key
|
||||||
|
|
||||||
def save_tokens(self, access_token: str, expires_in: int) -> None:
|
def save_access_token(self, access_token: str, expires_in: int) -> None:
|
||||||
"""
|
"""
|
||||||
Save the access token and its expiration time.
|
Save the access token and its expiration time.
|
||||||
|
|
||||||
:param access_token: The access token to save.
|
:param access_token: The access token to save.
|
||||||
:param expires_in: The expiration time of the access token in seconds.
|
:param expires_in: The expiration time of the access token in seconds.
|
||||||
"""
|
"""
|
||||||
expiration_time = datetime.now() + timedelta(seconds=expires_in)
|
self._save_token("access_token", access_token, expires_in)
|
||||||
data = {
|
|
||||||
"access_token": access_token,
|
|
||||||
"expiration": expiration_time.isoformat(),
|
|
||||||
}
|
|
||||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
|
||||||
self.save_secure_file(self.file_path, encrypted_data)
|
|
||||||
|
|
||||||
def get_token(self) -> Optional[str]:
|
def save_refresh_token(self, refresh_token: str) -> None:
|
||||||
"""
|
"""
|
||||||
Get the access token if it is valid and not expired.
|
Save the refresh token and its expiration time.
|
||||||
|
|
||||||
:return: The access token if valid and not expired, otherwise None.
|
:param refresh_token: The refresh token to save.
|
||||||
|
|
||||||
|
Refresh tokens don't have an expiration time, so the expiration time is set to 100 years from now.
|
||||||
|
"""
|
||||||
|
self._save_token("refresh_token", refresh_token, 3153600000)
|
||||||
|
|
||||||
|
def get_token(self, token_type: str = "access_token") -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get the specified token if it exists and is valid (not expired).
|
||||||
|
|
||||||
|
:return: The specified token if it exists and hasn't expired, otherwise None.
|
||||||
"""
|
"""
|
||||||
encrypted_data = self.read_secure_file(self.file_path)
|
encrypted_data = self.read_secure_file(self.file_path)
|
||||||
|
|
||||||
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
|
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
|
||||||
data = json.loads(decrypted_data)
|
all_tokens = json.loads(decrypted_data)
|
||||||
|
if not (token_data := all_tokens.get(token_type)):
|
||||||
|
return None
|
||||||
|
|
||||||
expiration = datetime.fromisoformat(data["expiration"])
|
expiration = datetime.fromisoformat(token_data["expiration"])
|
||||||
if expiration <= datetime.now():
|
if expiration <= datetime.now():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return data["access_token"]
|
return token_data["value"]
|
||||||
|
|
||||||
def get_secure_storage_path(self) -> Path:
|
def get_secure_storage_path(self) -> Path:
|
||||||
"""
|
"""
|
||||||
@@ -142,3 +220,28 @@ class TokenManager:
|
|||||||
|
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
|
|
||||||
|
def _save_token(self, token_type: str, token: str, expires_in: int) -> None:
|
||||||
|
"""
|
||||||
|
Save the token and its expiration time, updating the existing token file.
|
||||||
|
"""
|
||||||
|
all_tokens = {}
|
||||||
|
raw_existing_data = self.read_secure_file(self.file_path)
|
||||||
|
|
||||||
|
if raw_existing_data:
|
||||||
|
try:
|
||||||
|
decrypted_data = self.fernet.decrypt(raw_existing_data)
|
||||||
|
all_tokens = json.loads(decrypted_data.decode())
|
||||||
|
except Exception:
|
||||||
|
print("Error decrypting existing token file. Creating new file.")
|
||||||
|
all_tokens = {}
|
||||||
|
|
||||||
|
expiration_time = datetime.now() + timedelta(seconds=expires_in)
|
||||||
|
|
||||||
|
all_tokens[token_type] = {
|
||||||
|
"value": token,
|
||||||
|
"expiration": expiration_time.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
updated_encrypted_data = self.fernet.encrypt(json.dumps(all_tokens).encode())
|
||||||
|
self.save_secure_file(self.file_path, updated_encrypted_data)
|
||||||
|
|||||||
Reference in New Issue
Block a user