From bcc050b793531c40ac9bbc6d1b8d1c5ed10cdbbf Mon Sep 17 00:00:00 2001 From: Eduardo Chiarotti Date: Wed, 21 Aug 2024 19:32:10 -0300 Subject: [PATCH] feat: Add token manager to encrypt access token and get and save tokens --- src/crewai/cli/authentication/main.py | 27 +++++++------- src/crewai/cli/authentication/utils.py | 49 ++++++++++++++++++++++++++ src/crewai/cli/cli.py | 12 ++++--- src/crewai/cli/deploy/main.py | 12 ++++++- src/crewai/cli/deploy/utils.py | 10 ++++-- 5 files changed, 88 insertions(+), 22 deletions(-) diff --git a/src/crewai/cli/authentication/main.py b/src/crewai/cli/authentication/main.py index e18b9bbcf..77a2787a5 100644 --- a/src/crewai/cli/authentication/main.py +++ b/src/crewai/cli/authentication/main.py @@ -1,24 +1,26 @@ import time import webbrowser -from typing import Any, Dict, Optional +from typing import Any, Dict import requests from rich.console import Console from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN -from .utils import validate_token +from .utils import TokenManager, validate_token console = Console() -class Authentication: +class AuthenticationCommand: DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code" TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token" - def signup(self) -> Optional[Dict[str, Any]]: + def __init__(self): + self.token_manager = TokenManager() + + def signup(self) -> None: """Sign up to CrewAI+""" console.print("Signing Up to CrewAI+ \n", style="bold blue") - device_code_data = self._get_device_code() self._display_auth_instructions(device_code_data) @@ -29,7 +31,7 @@ class Authentication: device_code_payload = { "client_id": AUTH0_CLIENT_ID, - "scope": "openid profile email", + "scope": "openid", "audience": "https://dev-jzsr0j8zs0atl5ha.us.auth0.com/api/v2/", } response = requests.post(url=self.DEVICE_CODE_URL, data=device_code_payload) @@ -42,9 +44,7 @@ class Authentication: console.print("2. Enter the following code: ", device_code_data["user_code"]) webbrowser.open(device_code_data["verification_uri_complete"]) - def _poll_for_token( - self, device_code_data: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: + def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None: """Poll the server for the token.""" token_payload = { "grant_type": "urn:ietf:params:oauth:grant-type:device_code", @@ -59,13 +59,10 @@ class Authentication: if response.status_code == 200: validate_token(token_data["id_token"]) - # current_user = jwt.decode( - # token_data["id_token"], - # algorithms=ALGORITHMS, - # options={"verify_signature": False}, - # ) + expires_in = 360000 # Token expiration time in seconds + self.token_manager.save_tokens(token_data["access_token"], expires_in) console.print("\nWelcome to CrewAI+ !!", style="green") - return token_data + return if token_data["error"] not in ("authorization_pending", "slow_down"): raise requests.HTTPError(token_data["error_description"]) diff --git a/src/crewai/cli/authentication/utils.py b/src/crewai/cli/authentication/utils.py index f4d3420ce..0ff807046 100644 --- a/src/crewai/cli/authentication/utils.py +++ b/src/crewai/cli/authentication/utils.py @@ -1,7 +1,13 @@ +import json +import os +from datetime import datetime, timedelta +from typing import Optional + from auth0.authentication.token_verifier import ( AsymmetricSignatureVerifier, TokenVerifier, ) +from cryptography.fernet import Fernet from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN @@ -19,3 +25,46 @@ def validate_token(id_token: str) -> None: signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID ) token_verifier.verify(id_token) + + +class TokenManager: + def __init__(self, file_path="tokens.enc"): + self.file_path = file_path + self.key = self._get_or_create_key() + self.fernet = Fernet(self.key) + + def _get_or_create_key(self): + key_file = "secret.key" + if os.path.exists(key_file): + return open(key_file, "rb").read() + else: + key = Fernet.generate_key() + with open(key_file, "wb") as key_file: + key_file.write(key) + return key + + def save_tokens(self, access_token, expires_in): + expiration_time = datetime.now() + timedelta(seconds=expires_in) + data = { + "access_token": access_token, + "expiration": expiration_time.isoformat(), + } + encrypted_data = self.fernet.encrypt(json.dumps(data).encode()) + with open(self.file_path, "wb") as file: + file.write(encrypted_data) + + def get_token(self) -> Optional[str]: + if not os.path.exists(self.file_path): + return None + + with open(self.file_path, "rb") as file: + encrypted_data = file.read() + + decrypted_data = self.fernet.decrypt(encrypted_data) + data = json.loads(decrypted_data) + + expiration = datetime.fromisoformat(data["expiration"]) + if expiration <= datetime.now(): + return None + + return data["access_token"] diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index dec978b7a..54cda8dbf 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -9,7 +9,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import ( KickoffTaskOutputsSQLiteStorage, ) -from .authentication.main import Authentication +from .authentication.main import AuthenticationCommand from .deploy.main import DeployCommand from .evaluate_crew import evaluate_crew from .replay_from_task import replay_task_command @@ -17,8 +17,6 @@ from .reset_memories_command import reset_memories_command from .run_crew import run_crew from .train_crew import train_crew -deploy_cmd = DeployCommand() - @click.group() def crewai(): @@ -181,7 +179,7 @@ def run(): @crewai.command() def signup(): """Sign Up/Login to CrewAI+.""" - Authentication().signup() + AuthenticationCommand().signup() # DEPLOY CREWAI+ COMMANDS @@ -194,12 +192,14 @@ def deploy(): @deploy.command(name="create") def deploy_create(): """Create a Crew deployment.""" + deploy_cmd = DeployCommand() deploy_cmd.create_crew() @deploy.command(name="list") def deploy_list(): """List all deployments.""" + deploy_cmd = DeployCommand() deploy_cmd.list_crews() @@ -207,6 +207,7 @@ def deploy_list(): @click.option("-u", "--uuid", type=str, help="Crew UUID parameter") def deploy_push(uuid: Optional[str]): """Deploy the Crew.""" + deploy_cmd = DeployCommand() deploy_cmd.deploy(uuid=uuid) @@ -214,6 +215,7 @@ def deploy_push(uuid: Optional[str]): @click.option("-u", "--uuid", type=str, help="Crew UUID parameter") def deply_status(uuid: Optional[str]): """Get the status of a deployment.""" + deploy_cmd = DeployCommand() deploy_cmd.get_crew_status(uuid=uuid) @@ -221,6 +223,7 @@ def deply_status(uuid: Optional[str]): @click.option("-u", "--uuid", type=str, help="Crew UUID parameter") def deploy_logs(uuid: Optional[str]): """Get the logs of a deployment.""" + deploy_cmd = DeployCommand() deploy_cmd.get_crew_logs(uuid=uuid) @@ -228,6 +231,7 @@ def deploy_logs(uuid: Optional[str]): @click.option("-u", "--uuid", type=str, help="Crew UUID parameter") def deploy_remove(uuid: Optional[str]): """Remove a deployment.""" + deploy_cmd = DeployCommand() deploy_cmd.remove_crew(uuid=uuid) diff --git a/src/crewai/cli/deploy/main.py b/src/crewai/cli/deploy/main.py index ca54c3f80..d67e1cdc8 100644 --- a/src/crewai/cli/deploy/main.py +++ b/src/crewai/cli/deploy/main.py @@ -22,8 +22,18 @@ class DeployCommand: """ Initialize the DeployCommand with project name and API client. """ + try: + access_token = get_auth_token() + except Exception: + console.print( + "Please sign up/login to CrewAI+ before using the CLI.", + style="bold red", + ) + console.print("Run 'crewai signup' to sign up/login.", style="bold green") + raise SystemExit + self.project_name = get_project_name() - self.client = CrewAPI(api_key=get_auth_token()) + self.client = CrewAPI(api_key=access_token) def _handle_error(self, json_response: Dict[str, Any]) -> None: """ diff --git a/src/crewai/cli/deploy/utils.py b/src/crewai/cli/deploy/utils.py index 8dbc09bd5..53cf7dc58 100644 --- a/src/crewai/cli/deploy/utils.py +++ b/src/crewai/cli/deploy/utils.py @@ -3,6 +3,8 @@ import subprocess import tomllib +from ..authentication.utils import TokenManager + def get_git_remote_url() -> str: """Get the Git repository's remote URL.""" @@ -81,5 +83,9 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict: return {} -def get_auth_token(): - return "" +def get_auth_token() -> str: + """Get the authentication token.""" + access_token = TokenManager().get_token() + if not access_token: + raise Exception() + return access_token