mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
feat: Add token manager to encrypt access token and get and save tokens
This commit is contained in:
@@ -1,24 +1,26 @@
|
|||||||
import time
|
import time
|
||||||
import webbrowser
|
import webbrowser
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
|
from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
|
||||||
from .utils import validate_token
|
from .utils import TokenManager, validate_token
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
class Authentication:
|
class AuthenticationCommand:
|
||||||
DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code"
|
DEVICE_CODE_URL = f"https://{AUTH0_DOMAIN}/oauth/device/code"
|
||||||
TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token"
|
TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token"
|
||||||
|
|
||||||
def signup(self) -> Optional[Dict[str, Any]]:
|
def __init__(self):
|
||||||
|
self.token_manager = TokenManager()
|
||||||
|
|
||||||
|
def signup(self) -> None:
|
||||||
"""Sign up to CrewAI+"""
|
"""Sign up to CrewAI+"""
|
||||||
console.print("Signing Up to CrewAI+ \n", style="bold blue")
|
console.print("Signing Up to CrewAI+ \n", style="bold blue")
|
||||||
|
|
||||||
device_code_data = self._get_device_code()
|
device_code_data = self._get_device_code()
|
||||||
self._display_auth_instructions(device_code_data)
|
self._display_auth_instructions(device_code_data)
|
||||||
|
|
||||||
@@ -29,7 +31,7 @@ class Authentication:
|
|||||||
|
|
||||||
device_code_payload = {
|
device_code_payload = {
|
||||||
"client_id": AUTH0_CLIENT_ID,
|
"client_id": AUTH0_CLIENT_ID,
|
||||||
"scope": "openid profile email",
|
"scope": "openid",
|
||||||
"audience": "https://dev-jzsr0j8zs0atl5ha.us.auth0.com/api/v2/",
|
"audience": "https://dev-jzsr0j8zs0atl5ha.us.auth0.com/api/v2/",
|
||||||
}
|
}
|
||||||
response = requests.post(url=self.DEVICE_CODE_URL, data=device_code_payload)
|
response = requests.post(url=self.DEVICE_CODE_URL, data=device_code_payload)
|
||||||
@@ -42,9 +44,7 @@ class Authentication:
|
|||||||
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
||||||
webbrowser.open(device_code_data["verification_uri_complete"])
|
webbrowser.open(device_code_data["verification_uri_complete"])
|
||||||
|
|
||||||
def _poll_for_token(
|
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
|
||||||
self, device_code_data: Dict[str, Any]
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
|
||||||
"""Poll the server for the token."""
|
"""Poll the server for the token."""
|
||||||
token_payload = {
|
token_payload = {
|
||||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||||
@@ -59,13 +59,10 @@ class Authentication:
|
|||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
validate_token(token_data["id_token"])
|
validate_token(token_data["id_token"])
|
||||||
# current_user = jwt.decode(
|
expires_in = 360000 # Token expiration time in seconds
|
||||||
# token_data["id_token"],
|
self.token_manager.save_tokens(token_data["access_token"], expires_in)
|
||||||
# algorithms=ALGORITHMS,
|
|
||||||
# options={"verify_signature": False},
|
|
||||||
# )
|
|
||||||
console.print("\nWelcome to CrewAI+ !!", style="green")
|
console.print("\nWelcome to CrewAI+ !!", style="green")
|
||||||
return token_data
|
return
|
||||||
|
|
||||||
if token_data["error"] not in ("authorization_pending", "slow_down"):
|
if token_data["error"] not in ("authorization_pending", "slow_down"):
|
||||||
raise requests.HTTPError(token_data["error_description"])
|
raise requests.HTTPError(token_data["error_description"])
|
||||||
|
|||||||
@@ -1,7 +1,13 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from auth0.authentication.token_verifier import (
|
from auth0.authentication.token_verifier import (
|
||||||
AsymmetricSignatureVerifier,
|
AsymmetricSignatureVerifier,
|
||||||
TokenVerifier,
|
TokenVerifier,
|
||||||
)
|
)
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
|
from .constants import AUTH0_CLIENT_ID, AUTH0_DOMAIN
|
||||||
|
|
||||||
@@ -19,3 +25,46 @@ def validate_token(id_token: str) -> None:
|
|||||||
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID
|
signature_verifier=signature_verifier, issuer=issuer, audience=AUTH0_CLIENT_ID
|
||||||
)
|
)
|
||||||
token_verifier.verify(id_token)
|
token_verifier.verify(id_token)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenManager:
|
||||||
|
def __init__(self, file_path="tokens.enc"):
|
||||||
|
self.file_path = file_path
|
||||||
|
self.key = self._get_or_create_key()
|
||||||
|
self.fernet = Fernet(self.key)
|
||||||
|
|
||||||
|
def _get_or_create_key(self):
|
||||||
|
key_file = "secret.key"
|
||||||
|
if os.path.exists(key_file):
|
||||||
|
return open(key_file, "rb").read()
|
||||||
|
else:
|
||||||
|
key = Fernet.generate_key()
|
||||||
|
with open(key_file, "wb") as key_file:
|
||||||
|
key_file.write(key)
|
||||||
|
return key
|
||||||
|
|
||||||
|
def save_tokens(self, access_token, expires_in):
|
||||||
|
expiration_time = datetime.now() + timedelta(seconds=expires_in)
|
||||||
|
data = {
|
||||||
|
"access_token": access_token,
|
||||||
|
"expiration": expiration_time.isoformat(),
|
||||||
|
}
|
||||||
|
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||||
|
with open(self.file_path, "wb") as file:
|
||||||
|
file.write(encrypted_data)
|
||||||
|
|
||||||
|
def get_token(self) -> Optional[str]:
|
||||||
|
if not os.path.exists(self.file_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
with open(self.file_path, "rb") as file:
|
||||||
|
encrypted_data = file.read()
|
||||||
|
|
||||||
|
decrypted_data = self.fernet.decrypt(encrypted_data)
|
||||||
|
data = json.loads(decrypted_data)
|
||||||
|
|
||||||
|
expiration = datetime.fromisoformat(data["expiration"])
|
||||||
|
if expiration <= datetime.now():
|
||||||
|
return None
|
||||||
|
|
||||||
|
return data["access_token"]
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
|
|||||||
KickoffTaskOutputsSQLiteStorage,
|
KickoffTaskOutputsSQLiteStorage,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .authentication.main import Authentication
|
from .authentication.main import AuthenticationCommand
|
||||||
from .deploy.main import DeployCommand
|
from .deploy.main import DeployCommand
|
||||||
from .evaluate_crew import evaluate_crew
|
from .evaluate_crew import evaluate_crew
|
||||||
from .replay_from_task import replay_task_command
|
from .replay_from_task import replay_task_command
|
||||||
@@ -17,8 +17,6 @@ from .reset_memories_command import reset_memories_command
|
|||||||
from .run_crew import run_crew
|
from .run_crew import run_crew
|
||||||
from .train_crew import train_crew
|
from .train_crew import train_crew
|
||||||
|
|
||||||
deploy_cmd = DeployCommand()
|
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
def crewai():
|
def crewai():
|
||||||
@@ -181,7 +179,7 @@ def run():
|
|||||||
@crewai.command()
|
@crewai.command()
|
||||||
def signup():
|
def signup():
|
||||||
"""Sign Up/Login to CrewAI+."""
|
"""Sign Up/Login to CrewAI+."""
|
||||||
Authentication().signup()
|
AuthenticationCommand().signup()
|
||||||
|
|
||||||
|
|
||||||
# DEPLOY CREWAI+ COMMANDS
|
# DEPLOY CREWAI+ COMMANDS
|
||||||
@@ -194,12 +192,14 @@ def deploy():
|
|||||||
@deploy.command(name="create")
|
@deploy.command(name="create")
|
||||||
def deploy_create():
|
def deploy_create():
|
||||||
"""Create a Crew deployment."""
|
"""Create a Crew deployment."""
|
||||||
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.create_crew()
|
deploy_cmd.create_crew()
|
||||||
|
|
||||||
|
|
||||||
@deploy.command(name="list")
|
@deploy.command(name="list")
|
||||||
def deploy_list():
|
def deploy_list():
|
||||||
"""List all deployments."""
|
"""List all deployments."""
|
||||||
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.list_crews()
|
deploy_cmd.list_crews()
|
||||||
|
|
||||||
|
|
||||||
@@ -207,6 +207,7 @@ def deploy_list():
|
|||||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||||
def deploy_push(uuid: Optional[str]):
|
def deploy_push(uuid: Optional[str]):
|
||||||
"""Deploy the Crew."""
|
"""Deploy the Crew."""
|
||||||
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.deploy(uuid=uuid)
|
deploy_cmd.deploy(uuid=uuid)
|
||||||
|
|
||||||
|
|
||||||
@@ -214,6 +215,7 @@ def deploy_push(uuid: Optional[str]):
|
|||||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||||
def deply_status(uuid: Optional[str]):
|
def deply_status(uuid: Optional[str]):
|
||||||
"""Get the status of a deployment."""
|
"""Get the status of a deployment."""
|
||||||
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.get_crew_status(uuid=uuid)
|
deploy_cmd.get_crew_status(uuid=uuid)
|
||||||
|
|
||||||
|
|
||||||
@@ -221,6 +223,7 @@ def deply_status(uuid: Optional[str]):
|
|||||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||||
def deploy_logs(uuid: Optional[str]):
|
def deploy_logs(uuid: Optional[str]):
|
||||||
"""Get the logs of a deployment."""
|
"""Get the logs of a deployment."""
|
||||||
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.get_crew_logs(uuid=uuid)
|
deploy_cmd.get_crew_logs(uuid=uuid)
|
||||||
|
|
||||||
|
|
||||||
@@ -228,6 +231,7 @@ def deploy_logs(uuid: Optional[str]):
|
|||||||
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
@click.option("-u", "--uuid", type=str, help="Crew UUID parameter")
|
||||||
def deploy_remove(uuid: Optional[str]):
|
def deploy_remove(uuid: Optional[str]):
|
||||||
"""Remove a deployment."""
|
"""Remove a deployment."""
|
||||||
|
deploy_cmd = DeployCommand()
|
||||||
deploy_cmd.remove_crew(uuid=uuid)
|
deploy_cmd.remove_crew(uuid=uuid)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -22,8 +22,18 @@ class DeployCommand:
|
|||||||
"""
|
"""
|
||||||
Initialize the DeployCommand with project name and API client.
|
Initialize the DeployCommand with project name and API client.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
access_token = get_auth_token()
|
||||||
|
except Exception:
|
||||||
|
console.print(
|
||||||
|
"Please sign up/login to CrewAI+ before using the CLI.",
|
||||||
|
style="bold red",
|
||||||
|
)
|
||||||
|
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
|
||||||
|
raise SystemExit
|
||||||
|
|
||||||
self.project_name = get_project_name()
|
self.project_name = get_project_name()
|
||||||
self.client = CrewAPI(api_key=get_auth_token())
|
self.client = CrewAPI(api_key=access_token)
|
||||||
|
|
||||||
def _handle_error(self, json_response: Dict[str, Any]) -> None:
|
def _handle_error(self, json_response: Dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ import subprocess
|
|||||||
|
|
||||||
import tomllib
|
import tomllib
|
||||||
|
|
||||||
|
from ..authentication.utils import TokenManager
|
||||||
|
|
||||||
|
|
||||||
def get_git_remote_url() -> str:
|
def get_git_remote_url() -> str:
|
||||||
"""Get the Git repository's remote URL."""
|
"""Get the Git repository's remote URL."""
|
||||||
@@ -81,5 +83,9 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def get_auth_token():
|
def get_auth_token() -> str:
|
||||||
return "<token>"
|
"""Get the authentication token."""
|
||||||
|
access_token = TokenManager().get_token()
|
||||||
|
if not access_token:
|
||||||
|
raise Exception()
|
||||||
|
return access_token
|
||||||
|
|||||||
Reference in New Issue
Block a user