diff --git a/src/crewai/cli/authentication/main.py b/src/crewai/cli/authentication/main.py index 2fd9c632b..72f176655 100644 --- a/src/crewai/cli/authentication/main.py +++ b/src/crewai/cli/authentication/main.py @@ -1,4 +1,7 @@ -import time +import base64 +import hashlib +import secrets +import textwrap import webbrowser from typing import Any, Dict @@ -7,87 +10,216 @@ from rich.console import Console from crewai.cli.tools.main import ToolCommand -from .constants import AUTH0_AUDIENCE, AUTH0_CLIENT_ID, AUTH0_DOMAIN +from .constants import ( + WORKOS_AUTHORIZE_URL, + WORKOS_CLIENT_ID, + WORKOS_DOMAIN, + WORKOS_TOKEN_URL, +) from .utils import TokenManager, validate_token console = Console() +import socket +from urllib.parse import parse_qs, urlparse + class AuthenticationCommand: - DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code" - TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token" + CODE_VERIFIER = secrets.token_urlsafe(64) + CODE_CHALLENGE = ( + base64.urlsafe_b64encode(hashlib.sha256(CODE_VERIFIER.encode()).digest()) + .rstrip(b"=") + .decode("utf-8") + ) + NONCE = secrets.token_hex(6) + STATE = secrets.token_hex(9) + SOCKET_HOST = "0.0.0.0" + SOCKET_PORT = 49152 def __init__(self): self.token_manager = TokenManager() + self.auth_url = self._get_auth_url() def login(self) -> None: """Login or Sign Up to CrewAI Enterprise""" - return self._poll_for_token(device_code_data) + console.print("Signing in to CrewAI enterprise... \n", style="bold blue") - def _get_device_code(self) -> Dict[str, Any]: - """Get the device code to authenticate the user.""" - - device_code_payload = { - "client_id": AUTH0_CLIENT_ID, - "scope": "openid", - "audience": AUTH0_AUDIENCE, - } - response = requests.post( - url=self.DEVICE_CODE_URL, data=device_code_payload, timeout=20 + # 1. Get the auth URL. Upon successful authentication, browser will redirect back to CLI with the 'code' parameter. + console.print( + f"1. Navigate to [bold blue][link={self.auth_url}]this link.[/link][/bold blue] (it should open automatically in a few seconds...)", + style="bold", ) - response.raise_for_status() + webbrowser.open(self.auth_url) + + # 2. Listen for the auth response from the browser, and upon receiving the 'code' parameter, authenticate the user. + redirect_url_params = self._listen_for_auth_response() + console.print( + "2. Login successful. Retrieving your [bold blue]access tokens[/bold blue]...", + style="bold", + ) + auth_response = self._authenticate(redirect_url_params) + + # 3. Validate the JWT token signature, extract the access and refresh tokens and save them to the token manager. + access_token, refresh_token, user_info = self._validate_and_extract_tokens( + auth_response + ) + self.token_manager.save_access_token(access_token, auth_response["expires_in"]) + self.token_manager.save_refresh_token(refresh_token) + + # 4. Sign in to the tool repository. + console.print( + "3. All good. Now signing you in to [bold blue]tool repository[/bold blue]...", + style="bold", + ) + self._sign_in_to_tool_repository() + + # 5. Wrap up. + console.print( + f"4. Done! You are now signed in to CrewAI enterprise. Welcome, [bold cyan]{user_info.get('name')}[/bold cyan].", + style="bold green", + ) + return None + + def _get_auth_url(self) -> str: + return ( + f"{WORKOS_AUTHORIZE_URL}?" + f"response_type=code&" + f"client_id={WORKOS_CLIENT_ID}&" + f"redirect_uri=http://localhost:{self.SOCKET_PORT}&" + f"scope=openid+profile+email+offline_access&" + f"code_challenge={self.CODE_CHALLENGE}&" + f"code_challenge_method=S256&" + f"nonce={self.NONCE}&" + f"state={self.STATE}" + ) + + def _listen_for_auth_response(self) -> dict[str, str]: + """ + Listen for the authentication response from the browser. + + Returns: + dict[str, str]: The URL parameters passed in the querystring of the redirect URL. + """ + + redirect_url_params = {} + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: + server_socket.bind((self.SOCKET_HOST, self.SOCKET_PORT)) + server_socket.listen(1) + console.print("> Waiting for browser login and approval...", style="yellow") + + conn, addr = server_socket.accept() + with conn: + request = conn.recv(1024).decode("utf-8") + + # Extract the request line (first line of the HTTP request) + request_line = request.splitlines()[0] + method, path, _ = request_line.split() + + # Parse the URL path to get query string parameters + parsed_url = urlparse(path) + redirect_url_params = parse_qs(parsed_url.query) + + # Convert values from lists to single values if appropriate + redirect_url_params = { + k: v[0] if len(v) == 1 else v + for k, v in redirect_url_params.items() + } + + # Prepare the HTTP response with success message and JS that attempts to close the tab. + html_body = self._html_response_body() + http_response = f"HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {len(html_body.encode('utf-8'))}\r\nConnection: close\r\n\r\n{html_body}" + conn.sendall(http_response.encode("utf-8")) + + server_socket.close() + console.print("> Response received. Proceeding to login...", style="green") + + return redirect_url_params + + def _authenticate(self, params) -> dict[str, str]: + response = requests.post( + WORKOS_TOKEN_URL, + data={ + "grant_type": "authorization_code", + "client_id": WORKOS_CLIENT_ID, + "code": params["code"], + "redirect_uri": f"http://localhost:{self.SOCKET_PORT}", + "code_verifier": self.CODE_VERIFIER, + }, + ) + + if response.status_code != 200: + console.print( + f"❌ Failed to sign in to CrewAI enterprise. \nRun [bold]crewai login[/bold] and try logging in again.\n", + style="red", + ) + raise SystemExit + return response.json() - def _display_auth_instructions(self, device_code_data: Dict[str, str]) -> None: - """Display the authentication instructions to the user.""" - console.print("1. Navigate to: ", device_code_data["verification_uri_complete"]) - console.print("2. Enter the following code: ", device_code_data["user_code"]) - webbrowser.open(device_code_data["verification_uri_complete"]) + def _validate_and_extract_tokens( + self, response_dict: dict[str, str] + ) -> [str, str, dict[str, str]]: + user_info = {} + try: + validate_token(response_dict["access_token"]) + user_info = validate_token(response_dict["id_token"], "id_token") + except Exception as e: + console.print( + f"❌ Failure validating JWT token signature, login failed. \nRun [bold]crewai login[/bold] to try logging in again.\n\n Error: {e}", + style="red", + ) + raise SystemExit - def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None: - """Poll the server for the token.""" - token_payload = { - "grant_type": "urn:ietf:params:oauth:grant-type:device_code", - "device_code": device_code_data["device_code"], - "client_id": AUTH0_CLIENT_ID, - } + return response_dict["access_token"], response_dict["refresh_token"], user_info - attempts = 0 - while True and attempts < 5: - response = requests.post(self.TOKEN_URL, data=token_payload, timeout=30) - token_data = response.json() + def _sign_in_to_tool_repository(self) -> None: + try: + ToolCommand().login() + except Exception as e: + console.print( + "\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.", + style="yellow", + ) + console.print( + "Other features will work normally, but you may experience limitations " + "with downloading and publishing tools." + "\nRun [bold]crewai login[/bold] to try logging in again.\n", + style="yellow", + ) + console.print(f"Error: {e}", style="red") - if response.status_code == 200: - validate_token(token_data["id_token"]) - expires_in = 360000 # Token expiration time in seconds - self.token_manager.save_tokens(token_data["access_token"], expires_in) - - try: - ToolCommand().login() - except Exception: - console.print( - "\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.", - style="yellow", - ) - console.print( - "Other features will work normally, but you may experience limitations " - "with downloading and publishing tools." - "\nRun [bold]crewai login[/bold] to try logging in again.\n", - style="yellow", - ) - - console.print( - "\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n" - ) - return - - if token_data["error"] not in ("authorization_pending", "slow_down"): - raise requests.HTTPError(token_data["error_description"]) - - time.sleep(device_code_data["interval"]) - attempts += 1 - - console.print( - "Timeout: Failed to get the token. Please try again.", style="bold red" - ) + def _html_response_body(self) -> str: + html_body = textwrap.dedent("""\ + + + + + + Authentication Successful + + + + +
+ +

Authentication Successful!

+

Your authentication through CLI was successful. This browser tab should close automatically.

+

If it doesn't close in a few seconds, you may close it manually.

+
+ + """) + return html_body diff --git a/src/crewai/cli/authentication/token.py b/src/crewai/cli/authentication/token.py index 30a33b4ba..566812375 100644 --- a/src/crewai/cli/authentication/token.py +++ b/src/crewai/cli/authentication/token.py @@ -1,9 +1,22 @@ -from .utils import TokenManager +from .utils import TokenManager, get_auth_token_with_refresh_token def get_auth_token() -> str: - """Get the authentication token.""" - access_token = TokenManager().get_token() + """Get the authentication token. Uses refresh token to fetch a new token if current one is expired.""" + access_token = TokenManager().get_token("access_token") + refresh_token = TokenManager().get_token("refresh_token") + + # Token could be expired, so we use the refresh token to fetch a new one. + # Skip if refresh token is not available. + if not access_token and refresh_token: + data = get_auth_token_with_refresh_token(refresh_token) + access_token = data.get("access_token") + refresh_token = data.get("refresh_token") + + if access_token and refresh_token: + TokenManager().save_access_token(access_token, data["expires_in"]) + TokenManager().save_refresh_token(refresh_token) + if not access_token: - raise Exception() + raise Exception("Access token could not be obtained. Please sign in again.") return access_token diff --git a/src/crewai/cli/authentication/utils.py b/src/crewai/cli/authentication/utils.py index 2f5fc183f..715918362 100644 --- a/src/crewai/cli/authentication/utils.py +++ b/src/crewai/cli/authentication/utils.py @@ -5,28 +5,100 @@ from datetime import datetime, timedelta from pathlib import Path from typing import Optional -from auth0.authentication.token_verifier import ( - AsymmetricSignatureVerifier, - TokenVerifier, -) +import jwt +import requests from cryptography.fernet import Fernet +from jwt import PyJWKClient -from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN +from .constants import WORKOS_CLIENT_ID, WORKOS_DOMAIN, WORKOS_ENVIRONMENT_ID -def validate_token(id_token: str) -> None: +def get_auth_token_with_refresh_token(refresh_token: str) -> dict: """ - Verify the token and its precedence + Get an access token using a refresh token. - :param id_token: + :param refresh_token: The refresh token to use. + :return: A dictionary containing the access token, its expiration time, and a new refresh token, or an empty dictionary if the attempt to get a new access token failed. """ - jwks_url = f"https://{AUTH0_DOMAIN}/.well-known/jwks.json" - issuer = f"https://{AUTH0_DOMAIN}/" - signature_verifier = AsymmetricSignatureVerifier(jwks_url) - token_verifier = TokenVerifier( - signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID + + response = requests.post( + WORKOS_TOKEN_URL, + data={ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": WORKOS_CLIENT_ID, + }, ) - token_verifier.verify(id_token) + + if response.status_code != 200: + return {} + + data = response.json() + try: + validate_token(data.get("access_token")) + except Exception: + return {} + + return { + "access_token": data.get("access_token"), + "refresh_token": data.get("refresh_token"), + "expires_in": data.get("expires_in"), + } + + +def validate_token(jwt_token: str, token_type: str = "access_token") -> dict: + """ + Verify the token's signature and claims using PyJWT. + + :param jwt_token: The JWT (JWS) string to validate. + :return: The decoded token. + :raises Exception: If the token is invalid for any reason (e.g., signature mismatch, + expired, incorrect issuer/audience, JWKS fetching error, + missing required claims). + """ + + supported_audiences = { + "access_token": WORKOS_ENVIRONMENT_ID, + "id_token": WORKOS_CLIENT_ID, + } + + jwks_url = f"https://{WORKOS_DOMAIN}/oauth2/jwks" + expected_issuer = f"https://{WORKOS_DOMAIN}" + expected_audience = supported_audiences[token_type] + decoded_token = None + + try: + jwk_client = PyJWKClient(jwks_url) + signing_key = jwk_client.get_signing_key_from_jwt(jwt_token) + + decoded_token = jwt.decode( + jwt_token, + signing_key.key, + algorithms=["RS256"], + audience=expected_audience, + issuer=expected_issuer, + options={ + "verify_signature": True, + "verify_exp": True, + "verify_nbf": True, + "verify_iat": True, + "require": ["exp", "iat", "iss", "aud", "sub"], + }, + ) + return decoded_token + + except jwt.ExpiredSignatureError: + raise Exception("Token has expired.") + except jwt.InvalidAudienceError: + raise Exception(f"Invalid token audience. Expected: '{expected_audience}'") + except jwt.InvalidIssuerError: + raise Exception(f"Invalid token issuer. Expected: '{expected_issuer}'") + except jwt.MissingRequiredClaimError as e: + raise Exception(f"Token is missing required claims: {str(e)}") + except jwt.exceptions.PyJWKClientError as e: + raise Exception(f"JWKS or key processing error: {str(e)}") + except jwt.InvalidTokenError as e: + raise Exception(f"Invalid token: {str(e)}") class TokenManager: @@ -56,37 +128,43 @@ class TokenManager: self.save_secure_file(key_filename, new_key) return new_key - def save_tokens(self, access_token: str, expires_in: int) -> None: + def save_access_token(self, access_token: str, expires_in: int) -> None: """ Save the access token and its expiration time. :param access_token: The access token to save. :param expires_in: The expiration time of the access token in seconds. """ - expiration_time = datetime.now() + timedelta(seconds=expires_in) - data = { - "access_token": access_token, - "expiration": expiration_time.isoformat(), - } - encrypted_data = self.fernet.encrypt(json.dumps(data).encode()) - self.save_secure_file(self.file_path, encrypted_data) + self._save_token("access_token", access_token, expires_in) - def get_token(self) -> Optional[str]: + def save_refresh_token(self, refresh_token: str) -> None: """ - Get the access token if it is valid and not expired. + Save the refresh token and its expiration time. - :return: The access token if valid and not expired, otherwise None. + :param refresh_token: The refresh token to save. + + Refresh tokens don't have an expiration time, so the expiration time is set to 100 years from now. + """ + self._save_token("refresh_token", refresh_token, 3153600000) + + def get_token(self, token_type: str = "access_token") -> Optional[str]: + """ + Get the specified token if it exists and is valid (not expired). + + :return: The specified token if it exists and hasn't expired, otherwise None. """ encrypted_data = self.read_secure_file(self.file_path) decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore - data = json.loads(decrypted_data) + all_tokens = json.loads(decrypted_data) + if not (token_data := all_tokens.get(token_type)): + return None - expiration = datetime.fromisoformat(data["expiration"]) + expiration = datetime.fromisoformat(token_data["expiration"]) if expiration <= datetime.now(): return None - return data["access_token"] + return token_data["value"] def get_secure_storage_path(self) -> Path: """ @@ -142,3 +220,28 @@ class TokenManager: with open(file_path, "rb") as f: return f.read() + + def _save_token(self, token_type: str, token: str, expires_in: int) -> None: + """ + Save the token and its expiration time, updating the existing token file. + """ + all_tokens = {} + raw_existing_data = self.read_secure_file(self.file_path) + + if raw_existing_data: + try: + decrypted_data = self.fernet.decrypt(raw_existing_data) + all_tokens = json.loads(decrypted_data.decode()) + except Exception: + print("Error decrypting existing token file. Creating new file.") + all_tokens = {} + + expiration_time = datetime.now() + timedelta(seconds=expires_in) + + all_tokens[token_type] = { + "value": token, + "expiration": expiration_time.isoformat(), + } + + updated_encrypted_data = self.fernet.encrypt(json.dumps(all_tokens).encode()) + self.save_secure_file(self.file_path, updated_encrypted_data)