mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
feat: Add token manager to encrypt access token and get and save tokens
This commit is contained in:
@@ -1,24 +1,26 @@
|
||||
import time
|
||||
import webbrowser
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
import requests
|
||||
from rich.console import Console
|
||||
|
||||
from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
|
||||
from .utils import validate_token
|
||||
from .utils import TokenManager, validate_token
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class Authentication:
|
||||
class AuthenticationCommand:
|
||||
DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code"
|
||||
TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token"
|
||||
|
||||
def signup(self) -> Optional[Dict[str, Any]]:
|
||||
def __init__(self):
|
||||
self.token_manager = TokenManager()
|
||||
|
||||
def signup(self) -> None:
|
||||
"""Sign up to CrewAI+"""
|
||||
console.print("Signing Up to CrewAI+ \n", style="bold blue")
|
||||
|
||||
device_code_data = self._get_device_code()
|
||||
self._display_auth_instructions(device_code_data)
|
||||
|
||||
@@ -29,7 +31,7 @@ class Authentication:
|
||||
|
||||
device_code_payload = {
|
||||
"client_id": AUTH0_CLIENT_ID,
|
||||
"scope": "openid profile email",
|
||||
"scope": "openid",
|
||||
"audience": "https://dev-jzsr0j8zs0atl5ha.us.auth0.com/api/v2/",
|
||||
}
|
||||
response = requests.post(url=self.DEVICE_CODE_URL, data=device_code_payload)
|
||||
@@ -42,9 +44,7 @@ class Authentication:
|
||||
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
||||
webbrowser.open(device_code_data["verification_uri_complete"])
|
||||
|
||||
def _poll_for_token(
|
||||
self, device_code_data: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
|
||||
"""Poll the server for the token."""
|
||||
token_payload = {
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
@@ -59,13 +59,10 @@ class Authentication:
|
||||
|
||||
if response.status_code == 200:
|
||||
validate_token(token_data["id_token"])
|
||||
# current_user = jwt.decode(
|
||||
# token_data["id_token"],
|
||||
# algorithms=ALGORITHMS,
|
||||
# options={"verify_signature": False},
|
||||
# )
|
||||
expires_in = 360000 # Token expiration time in seconds
|
||||
self.token_manager.save_tokens(token_data["access_token"], expires_in)
|
||||
console.print("\nWelcome to CrewAI+ !!", style="green")
|
||||
return token_data
|
||||
return
|
||||
|
||||
if token_data["error"] not in ("authorization_pending", "slow_down"):
|
||||
raise requests.HTTPError(token_data["error_description"])
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from auth0.authentication.token_verifier import (
|
||||
AsymmetricSignatureVerifier,
|
||||
TokenVerifier,
|
||||
)
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
|
||||
|
||||
@@ -19,3 +25,46 @@ def validate_token(id_token: str) -> None:
|
||||
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID
|
||||
)
|
||||
token_verifier.verify(id_token)
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, file_path="tokens.enc"):
|
||||
self.file_path = file_path
|
||||
self.key = self._get_or_create_key()
|
||||
self.fernet = Fernet(self.key)
|
||||
|
||||
def _get_or_create_key(self):
|
||||
key_file = "secret.key"
|
||||
if os.path.exists(key_file):
|
||||
return open(key_file, "rb").read()
|
||||
else:
|
||||
key = Fernet.generate_key()
|
||||
with open(key_file, "wb") as key_file:
|
||||
key_file.write(key)
|
||||
return key
|
||||
|
||||
def save_tokens(self, access_token, expires_in):
|
||||
expiration_time = datetime.now() + timedelta(seconds=expires_in)
|
||||
data = {
|
||||
"access_token": access_token,
|
||||
"expiration": expiration_time.isoformat(),
|
||||
}
|
||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||
with open(self.file_path, "wb") as file:
|
||||
file.write(encrypted_data)
|
||||
|
||||
def get_token(self) -> Optional[str]:
|
||||
if not os.path.exists(self.file_path):
|
||||
return None
|
||||
|
||||
with open(self.file_path, "rb") as file:
|
||||
encrypted_data = file.read()
|
||||
|
||||
decrypted_data = self.fernet.decrypt(encrypted_data)
|
||||
data = json.loads(decrypted_data)
|
||||
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
if expiration <= datetime.now():
|
||||
return None
|
||||
|
||||
return data["access_token"]
|
||||
|
||||
@@ -9,7 +9,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
KickoffTaskOutputsSQLiteStorage,
|
||||
)
|
||||
|
||||
from .authentication.main import Authentication
|
||||
from .authentication.main import AuthenticationCommand
|
||||
from .deploy.main import DeployCommand
|
||||
from .evaluate_crew import evaluate_crew
|
||||
from .replay_from_task import replay_task_command
|
||||
@@ -17,8 +17,6 @@ from .reset_memories_command import reset_memories_command
|
||||
from .run_crew import run_crew
|
||||
from .train_crew import train_crew
|
||||
|
||||
deploy_cmd = DeployCommand()
|
||||
|
||||
|
||||
@click.group()
|
||||
def crewai():
|
||||
@@ -181,7 +179,7 @@ def run():
|
||||
@crewai.command()
|
||||
def signup():
|
||||
"""Sign Up/Login to CrewAI+."""
|
||||
Authentication().signup()
|
||||
AuthenticationCommand().signup()
|
||||
|
||||
|
||||
# DEPLOY CREWAI+ COMMANDS
|
||||
@@ -194,12 +192,14 @@ def deploy():
|
||||
@deploy.command(name="create")
|
||||
def deploy_create():
|
||||
"""Create a Crew deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.create_crew()
|
||||
|
||||
|
||||
@deploy.command(name="list")
|
||||
def deploy_list():
|
||||
"""List all deployments."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.list_crews()
|
||||
|
||||
|
||||
@@ -207,6 +207,7 @@ def deploy_list():
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_push(uuid: Optional[str]):
|
||||
"""Deploy the Crew."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.deploy(uuid=uuid)
|
||||
|
||||
|
||||
@@ -214,6 +215,7 @@ def deploy_push(uuid: Optional[str]):
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deply_status(uuid: Optional[str]):
|
||||
"""Get the status of a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.get_crew_status(uuid=uuid)
|
||||
|
||||
|
||||
@@ -221,6 +223,7 @@ def deply_status(uuid: Optional[str]):
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_logs(uuid: Optional[str]):
|
||||
"""Get the logs of a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.get_crew_logs(uuid=uuid)
|
||||
|
||||
|
||||
@@ -228,6 +231,7 @@ def deploy_logs(uuid: Optional[str]):
|
||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||
def deploy_remove(uuid: Optional[str]):
|
||||
"""Remove a deployment."""
|
||||
deploy_cmd = DeployCommand()
|
||||
deploy_cmd.remove_crew(uuid=uuid)
|
||||
|
||||
|
||||
|
||||
@@ -22,8 +22,18 @@ class DeployCommand:
|
||||
"""
|
||||
Initialize the DeployCommand with project name and API client.
|
||||
"""
|
||||
try:
|
||||
access_token = get_auth_token()
|
||||
except Exception:
|
||||
console.print(
|
||||
"Please sign up/login to CrewAI+ before using the CLI.",
|
||||
style="bold red",
|
||||
)
|
||||
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
|
||||
raise SystemExit
|
||||
|
||||
self.project_name = get_project_name()
|
||||
self.client = CrewAPI(api_key=get_auth_token())
|
||||
self.client = CrewAPI(api_key=access_token)
|
||||
|
||||
def _handle_error(self, json_response: Dict[str, Any]) -> None:
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,8 @@ import subprocess
|
||||
|
||||
import tomllib
|
||||
|
||||
from ..authentication.utils import TokenManager
|
||||
|
||||
|
||||
def get_git_remote_url() -> str:
|
||||
"""Get the Git repository's remote URL."""
|
||||
@@ -81,5 +83,9 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
def get_auth_token():
|
||||
return "<token>"
|
||||
def get_auth_token() -> str:
|
||||
"""Get the authentication token."""
|
||||
access_token = TokenManager().get_token()
|
||||
if not access_token:
|
||||
raise Exception()
|
||||
return access_token
|
||||
|
||||
Reference in New Issue
Block a user