feat: reset tokens on crewai config reset (#3365)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

This commit is contained in:
Heitor Carvalho
2025-08-22 17:16:42 -03:00
committed by GitHub
parent 842bed4e9c
commit f96b779df5
10 changed files with 332 additions and 266 deletions

View File

@@ -7,7 +7,8 @@ from rich.console import Console
from pydantic import BaseModel, Field
from .utils import TokenManager, validate_jwt_token
from .utils import validate_jwt_token
from crewai.cli.shared.token_manager import TokenManager
from urllib.parse import quote
from crewai.cli.plus_api import PlusAPI
from crewai.cli.config import Settings
@@ -21,10 +22,19 @@ console = Console()
class Oauth2Settings(BaseModel):
provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).")
client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.")
domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.")
audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None)
provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
)
client_id: str = Field(
description="OAuth2 client ID issued by the provider, used during authentication requests."
)
domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
)
audience: Optional[str] = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=None,
)
@classmethod
def from_settings(cls):
@@ -44,11 +54,15 @@ class ProviderFactory:
settings = settings or Oauth2Settings.from_settings()
import importlib
module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}")
module = importlib.import_module(
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
)
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
return provider(settings)
class AuthenticationCommand:
def __init__(self):
self.token_manager = TokenManager()
@@ -65,7 +79,7 @@ class AuthenticationCommand:
provider="auth0",
client_id=AUTH0_CLIENT_ID,
domain=AUTH0_DOMAIN,
audience=AUTH0_AUDIENCE
audience=AUTH0_AUDIENCE,
)
self.oauth2_provider = ProviderFactory.from_settings(settings)
# End of temporary code.
@@ -75,9 +89,7 @@ class AuthenticationCommand:
return self._poll_for_token(device_code_data)
def _get_device_code(
self
) -> Dict[str, Any]:
def _get_device_code(self) -> Dict[str, Any]:
"""Get the device code to authenticate the user."""
device_code_payload = {
@@ -86,7 +98,9 @@ class AuthenticationCommand:
"audience": self.oauth2_provider.get_audience(),
}
response = requests.post(
url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20
url=self.oauth2_provider.get_authorize_url(),
data=device_code_payload,
timeout=20,
)
response.raise_for_status()
return response.json()
@@ -97,9 +111,7 @@ class AuthenticationCommand:
console.print("2. Enter the following code: ", device_code_data["user_code"])
webbrowser.open(device_code_data["verification_uri_complete"])
def _poll_for_token(
self, device_code_data: Dict[str, Any]
) -> None:
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
"""Polls the server for the token until it is received, or max attempts are reached."""
token_payload = {
@@ -112,7 +124,9 @@ class AuthenticationCommand:
attempts = 0
while True and attempts < 10:
response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30)
response = requests.post(
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
)
token_data = response.json()
if response.status_code == 200:

View File

@@ -1,4 +1,4 @@
from .utils import TokenManager
from crewai.cli.shared.token_manager import TokenManager
class AuthError(Exception):

View File

@@ -1,12 +1,5 @@
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional
import jwt
from jwt import PyJWKClient
from cryptography.fernet import Fernet
def validate_jwt_token(
@@ -67,118 +60,3 @@ def validate_jwt_token(
raise Exception(f"JWKS or key processing error: {str(e)}")
except jwt.InvalidTokenError as e:
raise Exception(f"Invalid token: {str(e)}")
class TokenManager:
def __init__(self, file_path: str = "tokens.enc") -> None:
"""
Initialize the TokenManager class.
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
"""
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
def _get_or_create_key(self) -> bytes:
"""
Get or create the encryption key.
:return: The encryption key.
"""
key_filename = "secret.key"
key = self.read_secure_file(key_filename)
if key is not None:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
return new_key
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""
Save the access token and its expiration time.
:param access_token: The access token to save.
:param expires_at: The UNIX timestamp of the expiration time.
"""
expiration_time = datetime.fromtimestamp(expires_at)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
def get_token(self) -> Optional[str]:
"""
Get the access token if it is valid and not expired.
:return: The access token if valid and not expired, otherwise None.
"""
encrypted_data = self.read_secure_file(self.file_path)
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
data = json.loads(decrypted_data)
expiration = datetime.fromisoformat(data["expiration"])
if expiration <= datetime.now():
return None
return data["access_token"]
def get_secure_storage_path(self) -> Path:
"""
Get the secure storage path based on the operating system.
:return: The secure storage path.
"""
if sys.platform == "win32":
# Windows: Use %LOCALAPPDATA%
base_path = os.environ.get("LOCALAPPDATA")
elif sys.platform == "darwin":
# macOS: Use ~/Library/Application Support
base_path = os.path.expanduser("~/Library/Application Support")
else:
# Linux and other Unix-like: Use ~/.local/share
base_path = os.path.expanduser("~/.local/share")
app_name = "crewai/credentials"
storage_path = Path(base_path) / app_name
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path
def save_secure_file(self, filename: str, content: bytes) -> None:
"""
Save the content to a secure file.
:param filename: The name of the file.
:param content: The content to save.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
with open(file_path, "wb") as f:
f.write(content)
# Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600)
def read_secure_file(self, filename: str) -> Optional[bytes]:
"""
Read the content of a secure file.
:param filename: The name of the file.
:return: The content of the file if it exists, otherwise None.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
if not file_path.exists():
return None
with open(file_path, "rb") as f:
return f.read()

View File

@@ -11,6 +11,7 @@ from crewai.cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
)
from crewai.cli.shared.token_manager import TokenManager
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
@@ -53,6 +54,7 @@ HIDDEN_SETTINGS_KEYS = [
"tool_repository_password",
]
class Settings(BaseModel):
enterprise_base_url: Optional[str] = Field(
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
@@ -74,12 +76,12 @@ class Settings(BaseModel):
oauth2_provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
default=DEFAULT_CLI_SETTINGS["oauth2_provider"]
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
)
oauth2_audience: Optional[str] = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=DEFAULT_CLI_SETTINGS["oauth2_audience"]
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
)
oauth2_client_id: str = Field(
@@ -89,7 +91,7 @@ class Settings(BaseModel):
oauth2_domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
default=DEFAULT_CLI_SETTINGS["oauth2_domain"]
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
)
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
@@ -116,6 +118,7 @@ class Settings(BaseModel):
"""Reset all settings to default values"""
self._reset_user_settings()
self._reset_cli_settings()
self._clear_auth_tokens()
self.dump()
def dump(self) -> None:
@@ -139,3 +142,7 @@ class Settings(BaseModel):
"""Reset all CLI settings to default values"""
for key in CLI_SETTINGS_KEYS:
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))
def _clear_auth_tokens(self) -> None:
"""Clear all authentication tokens"""
TokenManager().clear_tokens()

View File

View File

@@ -0,0 +1,139 @@
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional
from cryptography.fernet import Fernet
class TokenManager:
def __init__(self, file_path: str = "tokens.enc") -> None:
"""
Initialize the TokenManager class.
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
"""
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
def _get_or_create_key(self) -> bytes:
"""
Get or create the encryption key.
:return: The encryption key.
"""
key_filename = "secret.key"
key = self.read_secure_file(key_filename)
if key is not None:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
return new_key
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""
Save the access token and its expiration time.
:param access_token: The access token to save.
:param expires_at: The UNIX timestamp of the expiration time.
"""
expiration_time = datetime.fromtimestamp(expires_at)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
def get_token(self) -> Optional[str]:
"""
Get the access token if it is valid and not expired.
:return: The access token if valid and not expired, otherwise None.
"""
encrypted_data = self.read_secure_file(self.file_path)
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
data = json.loads(decrypted_data)
expiration = datetime.fromisoformat(data["expiration"])
if expiration <= datetime.now():
return None
return data["access_token"]
def clear_tokens(self) -> None:
"""
Clear the tokens.
"""
self.delete_secure_file(self.file_path)
def get_secure_storage_path(self) -> Path:
"""
Get the secure storage path based on the operating system.
:return: The secure storage path.
"""
if sys.platform == "win32":
# Windows: Use %LOCALAPPDATA%
base_path = os.environ.get("LOCALAPPDATA")
elif sys.platform == "darwin":
# macOS: Use ~/Library/Application Support
base_path = os.path.expanduser("~/Library/Application Support")
else:
# Linux and other Unix-like: Use ~/.local/share
base_path = os.path.expanduser("~/.local/share")
app_name = "crewai/credentials"
storage_path = Path(base_path) / app_name
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path
def save_secure_file(self, filename: str, content: bytes) -> None:
"""
Save the content to a secure file.
:param filename: The name of the file.
:param content: The content to save.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
with open(file_path, "wb") as f:
f.write(content)
# Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600)
def read_secure_file(self, filename: str) -> Optional[bytes]:
"""
Read the content of a secure file.
:param filename: The name of the file.
:return: The content of the file if it exists, otherwise None.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
if not file_path.exists():
return None
with open(file_path, "rb") as f:
return f.read()
def delete_secure_file(self, filename: str) -> None:
"""
Delete the secure file.
:param filename: The name of the file.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
if file_path.exists():
file_path.unlink(missing_ok=True)