feat: reset tokens on crewai config reset (#3365)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

This commit is contained in:
Heitor Carvalho
2025-08-22 17:16:42 -03:00
committed by GitHub
parent 842bed4e9c
commit f96b779df5
10 changed files with 332 additions and 266 deletions

View File

@@ -7,7 +7,8 @@ from rich.console import Console
from pydantic import BaseModel, Field
from .utils import TokenManager, validate_jwt_token
from .utils import validate_jwt_token
from crewai.cli.shared.token_manager import TokenManager
from urllib.parse import quote
from crewai.cli.plus_api import PlusAPI
from crewai.cli.config import Settings
@@ -21,10 +22,19 @@ console = Console()
class Oauth2Settings(BaseModel):
provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).")
client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.")
domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.")
audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None)
provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
)
client_id: str = Field(
description="OAuth2 client ID issued by the provider, used during authentication requests."
)
domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
)
audience: Optional[str] = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=None,
)
@classmethod
def from_settings(cls):
@@ -44,11 +54,15 @@ class ProviderFactory:
settings = settings or Oauth2Settings.from_settings()
import importlib
module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}")
module = importlib.import_module(
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
)
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
return provider(settings)
class AuthenticationCommand:
def __init__(self):
self.token_manager = TokenManager()
@@ -65,7 +79,7 @@ class AuthenticationCommand:
provider="auth0",
client_id=AUTH0_CLIENT_ID,
domain=AUTH0_DOMAIN,
audience=AUTH0_AUDIENCE
audience=AUTH0_AUDIENCE,
)
self.oauth2_provider = ProviderFactory.from_settings(settings)
# End of temporary code.
@@ -75,9 +89,7 @@ class AuthenticationCommand:
return self._poll_for_token(device_code_data)
def _get_device_code(
self
) -> Dict[str, Any]:
def _get_device_code(self) -> Dict[str, Any]:
"""Get the device code to authenticate the user."""
device_code_payload = {
@@ -86,7 +98,9 @@ class AuthenticationCommand:
"audience": self.oauth2_provider.get_audience(),
}
response = requests.post(
url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20
url=self.oauth2_provider.get_authorize_url(),
data=device_code_payload,
timeout=20,
)
response.raise_for_status()
return response.json()
@@ -97,9 +111,7 @@ class AuthenticationCommand:
console.print("2. Enter the following code: ", device_code_data["user_code"])
webbrowser.open(device_code_data["verification_uri_complete"])
def _poll_for_token(
self, device_code_data: Dict[str, Any]
) -> None:
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
"""Polls the server for the token until it is received, or max attempts are reached."""
token_payload = {
@@ -112,7 +124,9 @@ class AuthenticationCommand:
attempts = 0
while True and attempts < 10:
response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30)
response = requests.post(
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
)
token_data = response.json()
if response.status_code == 200:

View File

@@ -1,4 +1,4 @@
from .utils import TokenManager
from crewai.cli.shared.token_manager import TokenManager
class AuthError(Exception):

View File

@@ -1,12 +1,5 @@
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional
import jwt
from jwt import PyJWKClient
from cryptography.fernet import Fernet
def validate_jwt_token(
@@ -67,118 +60,3 @@ def validate_jwt_token(
raise Exception(f"JWKS or key processing error: {str(e)}")
except jwt.InvalidTokenError as e:
raise Exception(f"Invalid token: {str(e)}")
class TokenManager:
def __init__(self, file_path: str = "tokens.enc") -> None:
"""
Initialize the TokenManager class.
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
"""
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
def _get_or_create_key(self) -> bytes:
"""
Get or create the encryption key.
:return: The encryption key.
"""
key_filename = "secret.key"
key = self.read_secure_file(key_filename)
if key is not None:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
return new_key
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""
Save the access token and its expiration time.
:param access_token: The access token to save.
:param expires_at: The UNIX timestamp of the expiration time.
"""
expiration_time = datetime.fromtimestamp(expires_at)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
def get_token(self) -> Optional[str]:
"""
Get the access token if it is valid and not expired.
:return: The access token if valid and not expired, otherwise None.
"""
encrypted_data = self.read_secure_file(self.file_path)
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
data = json.loads(decrypted_data)
expiration = datetime.fromisoformat(data["expiration"])
if expiration <= datetime.now():
return None
return data["access_token"]
def get_secure_storage_path(self) -> Path:
"""
Get the secure storage path based on the operating system.
:return: The secure storage path.
"""
if sys.platform == "win32":
# Windows: Use %LOCALAPPDATA%
base_path = os.environ.get("LOCALAPPDATA")
elif sys.platform == "darwin":
# macOS: Use ~/Library/Application Support
base_path = os.path.expanduser("~/Library/Application Support")
else:
# Linux and other Unix-like: Use ~/.local/share
base_path = os.path.expanduser("~/.local/share")
app_name = "crewai/credentials"
storage_path = Path(base_path) / app_name
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path
def save_secure_file(self, filename: str, content: bytes) -> None:
"""
Save the content to a secure file.
:param filename: The name of the file.
:param content: The content to save.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
with open(file_path, "wb") as f:
f.write(content)
# Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600)
def read_secure_file(self, filename: str) -> Optional[bytes]:
"""
Read the content of a secure file.
:param filename: The name of the file.
:return: The content of the file if it exists, otherwise None.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
if not file_path.exists():
return None
with open(file_path, "rb") as f:
return f.read()