feat: Add token manager to encrypt access token and get and save tokens

This commit is contained in:
Eduardo Chiarotti
2024-08-21 19:32:10 -03:00
parent d4d7712164
commit bcc050b793
5 changed files with 88 additions and 22 deletions

View File

@@ -1,24 +1,26 @@
import time
import webbrowser
from typing import Any, Dict, Optional
from typing import Any, Dict
import requests
from rich.console import Console
from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
from .utils import validate_token
from .utils import TokenManager, validate_token
console = Console()
class Authentication:
class AuthenticationCommand:
DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code"
TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token"
def signup(self) -> Optional[Dict[str, Any]]:
def __init__(self):
self.token_manager = TokenManager()
def signup(self) -> None:
"""Sign up to CrewAI+"""
console.print("Signing Up to CrewAI+ \n", style="bold blue")
device_code_data = self._get_device_code()
self._display_auth_instructions(device_code_data)
@@ -29,7 +31,7 @@ class Authentication:
device_code_payload = {
"client_id": AUTH0_CLIENT_ID,
"scope": "openid profile email",
"scope": "openid",
"audience": "https://dev-jzsr0j8zs0atl5ha.us.auth0.com/api/v2/",
}
response = requests.post(url=self.DEVICE_CODE_URL, data=device_code_payload)
@@ -42,9 +44,7 @@ class Authentication:
console.print("2. Enter the following code: ", device_code_data["user_code"])
webbrowser.open(device_code_data["verification_uri_complete"])
def _poll_for_token(
self, device_code_data: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
"""Poll the server for the token."""
token_payload = {
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
@@ -59,13 +59,10 @@ class Authentication:
if response.status_code == 200:
validate_token(token_data["id_token"])
# current_user = jwt.decode(
# token_data["id_token"],
# algorithms=ALGORITHMS,
# options={"verify_signature": False},
# )
expires_in = 360000 # Token expiration time in seconds
self.token_manager.save_tokens(token_data["access_token"], expires_in)
console.print("\nWelcome to CrewAI+ !!", style="green")
return token_data
return
if token_data["error"] not in ("authorization_pending", "slow_down"):
raise requests.HTTPError(token_data["error_description"])

View File

@@ -1,7 +1,13 @@
import json
import os
from datetime import datetime, timedelta
from typing import Optional
from auth0.authentication.token_verifier import (
AsymmetricSignatureVerifier,
TokenVerifier,
)
from cryptography.fernet import Fernet
from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
@@ -19,3 +25,46 @@ def validate_token(id_token: str) -> None:
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID
)
token_verifier.verify(id_token)
class TokenManager:
def __init__(self, file_path="tokens.enc"):
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
def _get_or_create_key(self):
key_file = "secret.key"
if os.path.exists(key_file):
return open(key_file, "rb").read()
else:
key = Fernet.generate_key()
with open(key_file, "wb") as key_file:
key_file.write(key)
return key
def save_tokens(self, access_token, expires_in):
expiration_time = datetime.now() + timedelta(seconds=expires_in)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
with open(self.file_path, "wb") as file:
file.write(encrypted_data)
def get_token(self) -> Optional[str]:
if not os.path.exists(self.file_path):
return None
with open(self.file_path, "rb") as file:
encrypted_data = file.read()
decrypted_data = self.fernet.decrypt(encrypted_data)
data = json.loads(decrypted_data)
expiration = datetime.fromisoformat(data["expiration"])
if expiration <= datetime.now():
return None
return data["access_token"]

View File

@@ -9,7 +9,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage,
)
from .authentication.main import Authentication
from .authentication.main import AuthenticationCommand
from .deploy.main import DeployCommand
from .evaluate_crew import evaluate_crew
from .replay_from_task import replay_task_command
@@ -17,8 +17,6 @@ from .reset_memories_command import reset_memories_command
from .run_crew import run_crew
from .train_crew import train_crew
deploy_cmd = DeployCommand()
@click.group()
def crewai():
@@ -181,7 +179,7 @@ def run():
@crewai.command()
def signup():
"""Sign Up/Login to CrewAI+."""
Authentication().signup()
AuthenticationCommand().signup()
# DEPLOY CREWAI+ COMMANDS
@@ -194,12 +192,14 @@ def deploy():
@deploy.command(name="create")
def deploy_create():
"""Create a Crew deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.create_crew()
@deploy.command(name="list")
def deploy_list():
"""List all deployments."""
deploy_cmd = DeployCommand()
deploy_cmd.list_crews()
@@ -207,6 +207,7 @@ def deploy_list():
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_push(uuid: Optional[str]):
"""Deploy the Crew."""
deploy_cmd = DeployCommand()
deploy_cmd.deploy(uuid=uuid)
@@ -214,6 +215,7 @@ def deploy_push(uuid: Optional[str]):
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deply_status(uuid: Optional[str]):
"""Get the status of a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.get_crew_status(uuid=uuid)
@@ -221,6 +223,7 @@ def deply_status(uuid: Optional[str]):
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_logs(uuid: Optional[str]):
"""Get the logs of a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.get_crew_logs(uuid=uuid)
@@ -228,6 +231,7 @@ def deploy_logs(uuid: Optional[str]):
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
def deploy_remove(uuid: Optional[str]):
"""Remove a deployment."""
deploy_cmd = DeployCommand()
deploy_cmd.remove_crew(uuid=uuid)

View File

@@ -22,8 +22,18 @@ class DeployCommand:
"""
Initialize the DeployCommand with project name and API client.
"""
try:
access_token = get_auth_token()
except Exception:
console.print(
"Please sign up/login to CrewAI+ before using the CLI.",
style="bold red",
)
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
raise SystemExit
self.project_name = get_project_name()
self.client = CrewAPI(api_key=get_auth_token())
self.client = CrewAPI(api_key=access_token)
def _handle_error(self, json_response: Dict[str, Any]) -> None:
"""

View File

@@ -3,6 +3,8 @@ import subprocess
import tomllib
from ..authentication.utils import TokenManager
def get_git_remote_url() -> str:
"""Get the Git repository's remote URL."""
@@ -81,5 +83,9 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
return {}
def get_auth_token():
return "<token>"
def get_auth_token() -> str:
"""Get the authentication token."""
access_token = TokenManager().get_token()
if not access_token:
raise Exception()
return access_token