feat: Add token manager to encrypt access token and get and save tokens

This commit is contained in:
Eduardo Chiarotti
2024-08-21 19:32:10 -03:00
parent d4d7712164
commit bcc050b793
5 changed files with 88 additions and 22 deletions

View File

@@ -1,24 +1,26 @@
import time import time
import webbrowser import webbrowser
from typing import Any, Dict, Optional from typing import Any, Dict
import requests import requests
from rich.console import Console from rich.console import Console
from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
from .utils import validate_token from .utils import TokenManager, validate_token
console = Console() console = Console()
class Authentication: class AuthenticationCommand:
DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code" DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code"
TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token" TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token"
def signup(self) -> Optional[Dict[str, Any]]: def __init__(self):
self.token_manager = TokenManager()
def signup(self) -> None:
"""Sign up to CrewAI+""" """Sign up to CrewAI+"""
console.print("Signing Up to CrewAI+ \n", style="bold blue") console.print("Signing Up to CrewAI+ \n", style="bold blue")
device_code_data = self._get_device_code() device_code_data = self._get_device_code()
self._display_auth_instructions(device_code_data) self._display_auth_instructions(device_code_data)
@@ -29,7 +31,7 @@ class Authentication:
device_code_payload = { device_code_payload = {
"client_id": AUTH0_CLIENT_ID, "client_id": AUTH0_CLIENT_ID,
"scope": "openid profile email", "scope": "openid",
"audience": "https://dev-jzsr0j8zs0atl5ha.us.auth0.com/api/v2/", "audience": "https://dev-jzsr0j8zs0atl5ha.us.auth0.com/api/v2/",
} }
response = requests.post(url=self.DEVICE_CODE_URL, data=device_code_payload) response = requests.post(url=self.DEVICE_CODE_URL, data=device_code_payload)
@@ -42,9 +44,7 @@ class Authentication:
console.print("2. Enter the following code: ", device_code_data["user_code"]) console.print("2. Enter the following code: ", device_code_data["user_code"])
webbrowser.open(device_code_data["verification_uri_complete"]) webbrowser.open(device_code_data["verification_uri_complete"])
def _poll_for_token( def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
self, device_code_data: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""Poll the server for the token.""" """Poll the server for the token."""
token_payload = { token_payload = {
"grant_type": "urn:ietf:params:oauth:grant-type:device_code", "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
@@ -59,13 +59,10 @@ class Authentication:
if response.status_code == 200: if response.status_code == 200:
validate_token(token_data["id_token"]) validate_token(token_data["id_token"])
# current_user = jwt.decode( expires_in = 360000 # Token expiration time in seconds
# token_data["id_token"], self.token_manager.save_tokens(token_data["access_token"], expires_in)
# algorithms=ALGORITHMS,
# options={"verify_signature": False},
# )
console.print("\nWelcome to CrewAI+ !!", style="green") console.print("\nWelcome to CrewAI+ !!", style="green")
return token_data return
if token_data["error"] not in ("authorization_pending", "slow_down"): if token_data["error"] not in ("authorization_pending", "slow_down"):
raise requests.HTTPError(token_data["error_description"]) raise requests.HTTPError(token_data["error_description"])

View File

@@ -1,7 +1,13 @@
import json
import os
from datetime import datetime, timedelta
from typing import Optional
from auth0.authentication.token_verifier import ( from auth0.authentication.token_verifier import (
AsymmetricSignatureVerifier, AsymmetricSignatureVerifier,
TokenVerifier, TokenVerifier,
) )
from cryptography.fernet import Fernet
from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
@@ -19,3 +25,46 @@ def validate_token(id_token: str) -> None:
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID
) )
token_verifier.verify(id_token) token_verifier.verify(id_token)
class TokenManager:
def __init__(self, file_path="tokens.enc"):
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
def _get_or_create_key(self):
key_file = "secret.key"
if os.path.exists(key_file):
return open(key_file, "rb").read()
else:
key = Fernet.generate_key()
with open(key_file, "wb") as key_file:
key_file.write(key)
return key
def save_tokens(self, access_token, expires_in):
expiration_time = datetime.now() + timedelta(seconds=expires_in)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
with open(self.file_path, "wb") as file:
file.write(encrypted_data)
def get_token(self) -> Optional[str]:
if not os.path.exists(self.file_path):
return None
with open(self.file_path, "rb") as file:
encrypted_data = file.read()
decrypted_data = self.fernet.decrypt(encrypted_data)
data = json.loads(decrypted_data)
expiration = datetime.fromisoformat(data["expiration"])
if expiration <= datetime.now():
return None
return data["access_token"]

View File

@@ -9,7 +9,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage, KickoffTaskOutputsSQLiteStorage,
) )
from .authentication.main import Authentication from .authentication.main import AuthenticationCommand
from .deploy.main import DeployCommand from .deploy.main import DeployCommand
from .evaluate_crew import evaluate_crew from .evaluate_crew import evaluate_crew
from .replay_from_task import replay_task_command from .replay_from_task import replay_task_command
@@ -17,8 +17,6 @@ from .reset_memories_command import reset_memories_command
from .run_crew import run_crew from .run_crew import run_crew
from .train_crew import train_crew from .train_crew import train_crew
deploy_cmd = DeployCommand()
@click.group() @click.group()
def crewai(): def crewai():
@@ -181,7 +179,7 @@ def run():
@crewai.command() @crewai.command()
def signup(): def signup():
"""Sign Up/Login to CrewAI+.""" """Sign Up/Login to CrewAI+."""
Authentication().signup() AuthenticationCommand().signup()
# DEPLOY CREWAI+ COMMANDS # DEPLOY CREWAI+ COMMANDS
@@ -194,12 +192,14 @@ def deploy():
@deploy.command(name="create") @deploy.command(name="create")
def deploy_create(): def deploy_create():
"""Create a Crew deployment.""" """Create a Crew deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.create_crew() deploy_cmd.create_crew()
@deploy.command(name="list") @deploy.command(name="list")
def deploy_list(): def deploy_list():
"""List all deployments.""" """List all deployments."""
deploy_cmd = DeployCommand()
deploy_cmd.list_crews() deploy_cmd.list_crews()
@@ -207,6 +207,7 @@ def deploy_list():
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_push(uuid: Optional[str]): def deploy_push(uuid: Optional[str]):
"""Deploy the Crew.""" """Deploy the Crew."""
deploy_cmd = DeployCommand()
deploy_cmd.deploy(uuid=uuid) deploy_cmd.deploy(uuid=uuid)
@@ -214,6 +215,7 @@ def deploy_push(uuid: Optional[str]):
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deply_status(uuid: Optional[str]): def deply_status(uuid: Optional[str]):
"""Get the status of a deployment.""" """Get the status of a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.get_crew_status(uuid=uuid) deploy_cmd.get_crew_status(uuid=uuid)
@@ -221,6 +223,7 @@ def deply_status(uuid: Optional[str]):
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_logs(uuid: Optional[str]): def deploy_logs(uuid: Optional[str]):
"""Get the logs of a deployment.""" """Get the logs of a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.get_crew_logs(uuid=uuid) deploy_cmd.get_crew_logs(uuid=uuid)
@@ -228,6 +231,7 @@ def deploy_logs(uuid: Optional[str]):
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter") @click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_remove(uuid: Optional[str]): def deploy_remove(uuid: Optional[str]):
"""Remove a deployment.""" """Remove a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.remove_crew(uuid=uuid) deploy_cmd.remove_crew(uuid=uuid)

View File

@@ -22,8 +22,18 @@ class DeployCommand:
""" """
Initialize the DeployCommand with project name and API client. Initialize the DeployCommand with project name and API client.
""" """
try:
access_token = get_auth_token()
except Exception:
console.print(
"Please sign up/login to CrewAI+ before using the CLI.",
style="bold red",
)
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
raise SystemExit
self.project_name = get_project_name() self.project_name = get_project_name()
self.client = CrewAPI(api_key=get_auth_token()) self.client = CrewAPI(api_key=access_token)
def _handle_error(self, json_response: Dict[str, Any]) -> None: def _handle_error(self, json_response: Dict[str, Any]) -> None:
""" """

View File

@@ -3,6 +3,8 @@ import subprocess
import tomllib import tomllib
from ..authentication.utils import TokenManager
def get_git_remote_url() -> str: def get_git_remote_url() -> str:
"""Get the Git repository's remote URL.""" """Get the Git repository's remote URL."""
@@ -81,5 +83,9 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
return {} return {}
def get_auth_token(): def get_auth_token() -> str:
return "<token>" """Get the authentication token."""
access_token = TokenManager().get_token()
if not access_token:
raise Exception()
return access_token