mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
feat: reset tokens on crewai config reset (#3365)
This commit is contained in:
@@ -7,7 +7,8 @@ from rich.console import Console
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
from .utils import TokenManager, validate_jwt_token
|
from .utils import validate_jwt_token
|
||||||
|
from crewai.cli.shared.token_manager import TokenManager
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
from crewai.cli.plus_api import PlusAPI
|
from crewai.cli.plus_api import PlusAPI
|
||||||
from crewai.cli.config import Settings
|
from crewai.cli.config import Settings
|
||||||
@@ -21,10 +22,19 @@ console = Console()
|
|||||||
|
|
||||||
|
|
||||||
class Oauth2Settings(BaseModel):
|
class Oauth2Settings(BaseModel):
|
||||||
provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).")
|
provider: str = Field(
|
||||||
client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.")
|
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
|
||||||
domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.")
|
)
|
||||||
audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None)
|
client_id: str = Field(
|
||||||
|
description="OAuth2 client ID issued by the provider, used during authentication requests."
|
||||||
|
)
|
||||||
|
domain: str = Field(
|
||||||
|
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
|
||||||
|
)
|
||||||
|
audience: Optional[str] = Field(
|
||||||
|
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_settings(cls):
|
def from_settings(cls):
|
||||||
@@ -44,11 +54,15 @@ class ProviderFactory:
|
|||||||
settings = settings or Oauth2Settings.from_settings()
|
settings = settings or Oauth2Settings.from_settings()
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}")
|
|
||||||
|
module = importlib.import_module(
|
||||||
|
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
|
||||||
|
)
|
||||||
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
|
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
|
||||||
|
|
||||||
return provider(settings)
|
return provider(settings)
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationCommand:
|
class AuthenticationCommand:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.token_manager = TokenManager()
|
self.token_manager = TokenManager()
|
||||||
@@ -65,7 +79,7 @@ class AuthenticationCommand:
|
|||||||
provider="auth0",
|
provider="auth0",
|
||||||
client_id=AUTH0_CLIENT_ID,
|
client_id=AUTH0_CLIENT_ID,
|
||||||
domain=AUTH0_DOMAIN,
|
domain=AUTH0_DOMAIN,
|
||||||
audience=AUTH0_AUDIENCE
|
audience=AUTH0_AUDIENCE,
|
||||||
)
|
)
|
||||||
self.oauth2_provider = ProviderFactory.from_settings(settings)
|
self.oauth2_provider = ProviderFactory.from_settings(settings)
|
||||||
# End of temporary code.
|
# End of temporary code.
|
||||||
@@ -75,9 +89,7 @@ class AuthenticationCommand:
|
|||||||
|
|
||||||
return self._poll_for_token(device_code_data)
|
return self._poll_for_token(device_code_data)
|
||||||
|
|
||||||
def _get_device_code(
|
def _get_device_code(self) -> Dict[str, Any]:
|
||||||
self
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Get the device code to authenticate the user."""
|
"""Get the device code to authenticate the user."""
|
||||||
|
|
||||||
device_code_payload = {
|
device_code_payload = {
|
||||||
@@ -86,7 +98,9 @@ class AuthenticationCommand:
|
|||||||
"audience": self.oauth2_provider.get_audience(),
|
"audience": self.oauth2_provider.get_audience(),
|
||||||
}
|
}
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20
|
url=self.oauth2_provider.get_authorize_url(),
|
||||||
|
data=device_code_payload,
|
||||||
|
timeout=20,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response.json()
|
return response.json()
|
||||||
@@ -97,9 +111,7 @@ class AuthenticationCommand:
|
|||||||
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
||||||
webbrowser.open(device_code_data["verification_uri_complete"])
|
webbrowser.open(device_code_data["verification_uri_complete"])
|
||||||
|
|
||||||
def _poll_for_token(
|
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
|
||||||
self, device_code_data: Dict[str, Any]
|
|
||||||
) -> None:
|
|
||||||
"""Polls the server for the token until it is received, or max attempts are reached."""
|
"""Polls the server for the token until it is received, or max attempts are reached."""
|
||||||
|
|
||||||
token_payload = {
|
token_payload = {
|
||||||
@@ -112,7 +124,9 @@ class AuthenticationCommand:
|
|||||||
|
|
||||||
attempts = 0
|
attempts = 0
|
||||||
while True and attempts < 10:
|
while True and attempts < 10:
|
||||||
response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30)
|
response = requests.post(
|
||||||
|
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
|
||||||
|
)
|
||||||
token_data = response.json()
|
token_data = response.json()
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from .utils import TokenManager
|
from crewai.cli.shared.token_manager import TokenManager
|
||||||
|
|
||||||
|
|
||||||
class AuthError(Exception):
|
class AuthError(Exception):
|
||||||
|
|||||||
@@ -1,12 +1,5 @@
|
|||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
import jwt
|
import jwt
|
||||||
from jwt import PyJWKClient
|
from jwt import PyJWKClient
|
||||||
from cryptography.fernet import Fernet
|
|
||||||
|
|
||||||
|
|
||||||
def validate_jwt_token(
|
def validate_jwt_token(
|
||||||
@@ -67,118 +60,3 @@ def validate_jwt_token(
|
|||||||
raise Exception(f"JWKS or key processing error: {str(e)}")
|
raise Exception(f"JWKS or key processing error: {str(e)}")
|
||||||
except jwt.InvalidTokenError as e:
|
except jwt.InvalidTokenError as e:
|
||||||
raise Exception(f"Invalid token: {str(e)}")
|
raise Exception(f"Invalid token: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
class TokenManager:
|
|
||||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
|
||||||
"""
|
|
||||||
Initialize the TokenManager class.
|
|
||||||
|
|
||||||
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
|
|
||||||
"""
|
|
||||||
self.file_path = file_path
|
|
||||||
self.key = self._get_or_create_key()
|
|
||||||
self.fernet = Fernet(self.key)
|
|
||||||
|
|
||||||
def _get_or_create_key(self) -> bytes:
|
|
||||||
"""
|
|
||||||
Get or create the encryption key.
|
|
||||||
|
|
||||||
:return: The encryption key.
|
|
||||||
"""
|
|
||||||
key_filename = "secret.key"
|
|
||||||
key = self.read_secure_file(key_filename)
|
|
||||||
|
|
||||||
if key is not None:
|
|
||||||
return key
|
|
||||||
|
|
||||||
new_key = Fernet.generate_key()
|
|
||||||
self.save_secure_file(key_filename, new_key)
|
|
||||||
return new_key
|
|
||||||
|
|
||||||
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
|
||||||
"""
|
|
||||||
Save the access token and its expiration time.
|
|
||||||
|
|
||||||
:param access_token: The access token to save.
|
|
||||||
:param expires_at: The UNIX timestamp of the expiration time.
|
|
||||||
"""
|
|
||||||
expiration_time = datetime.fromtimestamp(expires_at)
|
|
||||||
data = {
|
|
||||||
"access_token": access_token,
|
|
||||||
"expiration": expiration_time.isoformat(),
|
|
||||||
}
|
|
||||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
|
||||||
self.save_secure_file(self.file_path, encrypted_data)
|
|
||||||
|
|
||||||
def get_token(self) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get the access token if it is valid and not expired.
|
|
||||||
|
|
||||||
:return: The access token if valid and not expired, otherwise None.
|
|
||||||
"""
|
|
||||||
encrypted_data = self.read_secure_file(self.file_path)
|
|
||||||
|
|
||||||
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
|
|
||||||
data = json.loads(decrypted_data)
|
|
||||||
|
|
||||||
expiration = datetime.fromisoformat(data["expiration"])
|
|
||||||
if expiration <= datetime.now():
|
|
||||||
return None
|
|
||||||
|
|
||||||
return data["access_token"]
|
|
||||||
|
|
||||||
def get_secure_storage_path(self) -> Path:
|
|
||||||
"""
|
|
||||||
Get the secure storage path based on the operating system.
|
|
||||||
|
|
||||||
:return: The secure storage path.
|
|
||||||
"""
|
|
||||||
if sys.platform == "win32":
|
|
||||||
# Windows: Use %LOCALAPPDATA%
|
|
||||||
base_path = os.environ.get("LOCALAPPDATA")
|
|
||||||
elif sys.platform == "darwin":
|
|
||||||
# macOS: Use ~/Library/Application Support
|
|
||||||
base_path = os.path.expanduser("~/Library/Application Support")
|
|
||||||
else:
|
|
||||||
# Linux and other Unix-like: Use ~/.local/share
|
|
||||||
base_path = os.path.expanduser("~/.local/share")
|
|
||||||
|
|
||||||
app_name = "crewai/credentials"
|
|
||||||
storage_path = Path(base_path) / app_name
|
|
||||||
|
|
||||||
storage_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
return storage_path
|
|
||||||
|
|
||||||
def save_secure_file(self, filename: str, content: bytes) -> None:
|
|
||||||
"""
|
|
||||||
Save the content to a secure file.
|
|
||||||
|
|
||||||
:param filename: The name of the file.
|
|
||||||
:param content: The content to save.
|
|
||||||
"""
|
|
||||||
storage_path = self.get_secure_storage_path()
|
|
||||||
file_path = storage_path / filename
|
|
||||||
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(content)
|
|
||||||
|
|
||||||
# Set appropriate permissions (read/write for owner only)
|
|
||||||
os.chmod(file_path, 0o600)
|
|
||||||
|
|
||||||
def read_secure_file(self, filename: str) -> Optional[bytes]:
|
|
||||||
"""
|
|
||||||
Read the content of a secure file.
|
|
||||||
|
|
||||||
:param filename: The name of the file.
|
|
||||||
:return: The content of the file if it exists, otherwise None.
|
|
||||||
"""
|
|
||||||
storage_path = self.get_secure_storage_path()
|
|
||||||
file_path = storage_path / filename
|
|
||||||
|
|
||||||
if not file_path.exists():
|
|
||||||
return None
|
|
||||||
|
|
||||||
with open(file_path, "rb") as f:
|
|
||||||
return f.read()
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from crewai.cli.constants import (
|
|||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||||
)
|
)
|
||||||
|
from crewai.cli.shared.token_manager import TokenManager
|
||||||
|
|
||||||
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
||||||
|
|
||||||
@@ -53,6 +54,7 @@ HIDDEN_SETTINGS_KEYS = [
|
|||||||
"tool_repository_password",
|
"tool_repository_password",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseModel):
|
class Settings(BaseModel):
|
||||||
enterprise_base_url: Optional[str] = Field(
|
enterprise_base_url: Optional[str] = Field(
|
||||||
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
|
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
|
||||||
@@ -74,12 +76,12 @@ class Settings(BaseModel):
|
|||||||
|
|
||||||
oauth2_provider: str = Field(
|
oauth2_provider: str = Field(
|
||||||
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
|
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
|
||||||
default=DEFAULT_CLI_SETTINGS["oauth2_provider"]
|
default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
|
||||||
)
|
)
|
||||||
|
|
||||||
oauth2_audience: Optional[str] = Field(
|
oauth2_audience: Optional[str] = Field(
|
||||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||||
default=DEFAULT_CLI_SETTINGS["oauth2_audience"]
|
default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
|
||||||
)
|
)
|
||||||
|
|
||||||
oauth2_client_id: str = Field(
|
oauth2_client_id: str = Field(
|
||||||
@@ -89,7 +91,7 @@ class Settings(BaseModel):
|
|||||||
|
|
||||||
oauth2_domain: str = Field(
|
oauth2_domain: str = Field(
|
||||||
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
|
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
|
||||||
default=DEFAULT_CLI_SETTINGS["oauth2_domain"]
|
default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
|
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
|
||||||
@@ -116,6 +118,7 @@ class Settings(BaseModel):
|
|||||||
"""Reset all settings to default values"""
|
"""Reset all settings to default values"""
|
||||||
self._reset_user_settings()
|
self._reset_user_settings()
|
||||||
self._reset_cli_settings()
|
self._reset_cli_settings()
|
||||||
|
self._clear_auth_tokens()
|
||||||
self.dump()
|
self.dump()
|
||||||
|
|
||||||
def dump(self) -> None:
|
def dump(self) -> None:
|
||||||
@@ -139,3 +142,7 @@ class Settings(BaseModel):
|
|||||||
"""Reset all CLI settings to default values"""
|
"""Reset all CLI settings to default values"""
|
||||||
for key in CLI_SETTINGS_KEYS:
|
for key in CLI_SETTINGS_KEYS:
|
||||||
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))
|
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))
|
||||||
|
|
||||||
|
def _clear_auth_tokens(self) -> None:
|
||||||
|
"""Clear all authentication tokens"""
|
||||||
|
TokenManager().clear_tokens()
|
||||||
|
|||||||
0
src/crewai/cli/shared/__init__.py
Normal file
0
src/crewai/cli/shared/__init__.py
Normal file
139
src/crewai/cli/shared/token_manager.py
Normal file
139
src/crewai/cli/shared/token_manager.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
|
|
||||||
|
class TokenManager:
|
||||||
|
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||||
|
"""
|
||||||
|
Initialize the TokenManager class.
|
||||||
|
|
||||||
|
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
|
||||||
|
"""
|
||||||
|
self.file_path = file_path
|
||||||
|
self.key = self._get_or_create_key()
|
||||||
|
self.fernet = Fernet(self.key)
|
||||||
|
|
||||||
|
def _get_or_create_key(self) -> bytes:
|
||||||
|
"""
|
||||||
|
Get or create the encryption key.
|
||||||
|
|
||||||
|
:return: The encryption key.
|
||||||
|
"""
|
||||||
|
key_filename = "secret.key"
|
||||||
|
key = self.read_secure_file(key_filename)
|
||||||
|
|
||||||
|
if key is not None:
|
||||||
|
return key
|
||||||
|
|
||||||
|
new_key = Fernet.generate_key()
|
||||||
|
self.save_secure_file(key_filename, new_key)
|
||||||
|
return new_key
|
||||||
|
|
||||||
|
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
||||||
|
"""
|
||||||
|
Save the access token and its expiration time.
|
||||||
|
|
||||||
|
:param access_token: The access token to save.
|
||||||
|
:param expires_at: The UNIX timestamp of the expiration time.
|
||||||
|
"""
|
||||||
|
expiration_time = datetime.fromtimestamp(expires_at)
|
||||||
|
data = {
|
||||||
|
"access_token": access_token,
|
||||||
|
"expiration": expiration_time.isoformat(),
|
||||||
|
}
|
||||||
|
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||||
|
self.save_secure_file(self.file_path, encrypted_data)
|
||||||
|
|
||||||
|
def get_token(self) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get the access token if it is valid and not expired.
|
||||||
|
|
||||||
|
:return: The access token if valid and not expired, otherwise None.
|
||||||
|
"""
|
||||||
|
encrypted_data = self.read_secure_file(self.file_path)
|
||||||
|
|
||||||
|
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
|
||||||
|
data = json.loads(decrypted_data)
|
||||||
|
|
||||||
|
expiration = datetime.fromisoformat(data["expiration"])
|
||||||
|
if expiration <= datetime.now():
|
||||||
|
return None
|
||||||
|
|
||||||
|
return data["access_token"]
|
||||||
|
|
||||||
|
def clear_tokens(self) -> None:
|
||||||
|
"""
|
||||||
|
Clear the tokens.
|
||||||
|
"""
|
||||||
|
self.delete_secure_file(self.file_path)
|
||||||
|
|
||||||
|
def get_secure_storage_path(self) -> Path:
|
||||||
|
"""
|
||||||
|
Get the secure storage path based on the operating system.
|
||||||
|
|
||||||
|
:return: The secure storage path.
|
||||||
|
"""
|
||||||
|
if sys.platform == "win32":
|
||||||
|
# Windows: Use %LOCALAPPDATA%
|
||||||
|
base_path = os.environ.get("LOCALAPPDATA")
|
||||||
|
elif sys.platform == "darwin":
|
||||||
|
# macOS: Use ~/Library/Application Support
|
||||||
|
base_path = os.path.expanduser("~/Library/Application Support")
|
||||||
|
else:
|
||||||
|
# Linux and other Unix-like: Use ~/.local/share
|
||||||
|
base_path = os.path.expanduser("~/.local/share")
|
||||||
|
|
||||||
|
app_name = "crewai/credentials"
|
||||||
|
storage_path = Path(base_path) / app_name
|
||||||
|
|
||||||
|
storage_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
return storage_path
|
||||||
|
|
||||||
|
def save_secure_file(self, filename: str, content: bytes) -> None:
|
||||||
|
"""
|
||||||
|
Save the content to a secure file.
|
||||||
|
|
||||||
|
:param filename: The name of the file.
|
||||||
|
:param content: The content to save.
|
||||||
|
"""
|
||||||
|
storage_path = self.get_secure_storage_path()
|
||||||
|
file_path = storage_path / filename
|
||||||
|
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
# Set appropriate permissions (read/write for owner only)
|
||||||
|
os.chmod(file_path, 0o600)
|
||||||
|
|
||||||
|
def read_secure_file(self, filename: str) -> Optional[bytes]:
|
||||||
|
"""
|
||||||
|
Read the content of a secure file.
|
||||||
|
|
||||||
|
:param filename: The name of the file.
|
||||||
|
:return: The content of the file if it exists, otherwise None.
|
||||||
|
"""
|
||||||
|
storage_path = self.get_secure_storage_path()
|
||||||
|
file_path = storage_path / filename
|
||||||
|
|
||||||
|
if not file_path.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
def delete_secure_file(self, filename: str) -> None:
|
||||||
|
"""
|
||||||
|
Delete the secure file.
|
||||||
|
|
||||||
|
:param filename: The name of the file.
|
||||||
|
"""
|
||||||
|
storage_path = self.get_secure_storage_path()
|
||||||
|
file_path = storage_path / filename
|
||||||
|
if file_path.exists():
|
||||||
|
file_path.unlink(missing_ok=True)
|
||||||
@@ -1,17 +1,14 @@
|
|||||||
import json
|
|
||||||
import jwt
|
import jwt
|
||||||
import unittest
|
import unittest
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from cryptography.fernet import Fernet
|
|
||||||
|
|
||||||
from crewai.cli.authentication.utils import TokenManager, validate_jwt_token
|
from crewai.cli.authentication.utils import validate_jwt_token
|
||||||
|
|
||||||
|
|
||||||
@patch("crewai.cli.authentication.utils.PyJWKClient", return_value=MagicMock())
|
@patch("crewai.cli.authentication.utils.PyJWKClient", return_value=MagicMock())
|
||||||
@patch("crewai.cli.authentication.utils.jwt")
|
@patch("crewai.cli.authentication.utils.jwt")
|
||||||
class TestValidateToken(unittest.TestCase):
|
class TestUtils(unittest.TestCase):
|
||||||
def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
|
def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
|
||||||
mock_jwt.decode.return_value = {"exp": 1719859200}
|
mock_jwt.decode.return_value = {"exp": 1719859200}
|
||||||
|
|
||||||
@@ -105,121 +102,3 @@ class TestValidateToken(unittest.TestCase):
|
|||||||
issuer="https://mock_issuer",
|
issuer="https://mock_issuer",
|
||||||
audience="app_id_xxxx",
|
audience="app_id_xxxx",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestTokenManager(unittest.TestCase):
|
|
||||||
@patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key")
|
|
||||||
def setUp(self, mock_get_key):
|
|
||||||
mock_get_key.return_value = Fernet.generate_key()
|
|
||||||
self.token_manager = TokenManager()
|
|
||||||
|
|
||||||
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
|
|
||||||
@patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
|
|
||||||
@patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key")
|
|
||||||
def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read):
|
|
||||||
mock_key = Fernet.generate_key()
|
|
||||||
mock_get_or_create.return_value = mock_key
|
|
||||||
|
|
||||||
token_manager = TokenManager()
|
|
||||||
result = token_manager.key
|
|
||||||
|
|
||||||
self.assertEqual(result, mock_key)
|
|
||||||
|
|
||||||
@patch("crewai.cli.authentication.utils.Fernet.generate_key")
|
|
||||||
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
|
|
||||||
@patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
|
|
||||||
def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate):
|
|
||||||
mock_key = b"new_key"
|
|
||||||
mock_read.return_value = None
|
|
||||||
mock_generate.return_value = mock_key
|
|
||||||
|
|
||||||
result = self.token_manager._get_or_create_key()
|
|
||||||
|
|
||||||
self.assertEqual(result, mock_key)
|
|
||||||
mock_read.assert_called_once_with("secret.key")
|
|
||||||
mock_generate.assert_called_once()
|
|
||||||
mock_save.assert_called_once_with("secret.key", mock_key)
|
|
||||||
|
|
||||||
@patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
|
|
||||||
def test_save_tokens(self, mock_save):
|
|
||||||
access_token = "test_token"
|
|
||||||
expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())
|
|
||||||
|
|
||||||
self.token_manager.save_tokens(access_token, expires_at)
|
|
||||||
|
|
||||||
mock_save.assert_called_once()
|
|
||||||
args = mock_save.call_args[0]
|
|
||||||
self.assertEqual(args[0], "tokens.enc")
|
|
||||||
decrypted_data = self.token_manager.fernet.decrypt(args[1])
|
|
||||||
data = json.loads(decrypted_data)
|
|
||||||
self.assertEqual(data["access_token"], access_token)
|
|
||||||
expiration = datetime.fromisoformat(data["expiration"])
|
|
||||||
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
|
|
||||||
|
|
||||||
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
|
|
||||||
def test_get_token_valid(self, mock_read):
|
|
||||||
access_token = "test_token"
|
|
||||||
expiration = (datetime.now() + timedelta(hours=1)).isoformat()
|
|
||||||
data = {"access_token": access_token, "expiration": expiration}
|
|
||||||
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
|
|
||||||
mock_read.return_value = encrypted_data
|
|
||||||
|
|
||||||
result = self.token_manager.get_token()
|
|
||||||
|
|
||||||
self.assertEqual(result, access_token)
|
|
||||||
|
|
||||||
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
|
|
||||||
def test_get_token_expired(self, mock_read):
|
|
||||||
access_token = "test_token"
|
|
||||||
expiration = (datetime.now() - timedelta(hours=1)).isoformat()
|
|
||||||
data = {"access_token": access_token, "expiration": expiration}
|
|
||||||
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
|
|
||||||
mock_read.return_value = encrypted_data
|
|
||||||
|
|
||||||
result = self.token_manager.get_token()
|
|
||||||
|
|
||||||
self.assertIsNone(result)
|
|
||||||
|
|
||||||
@patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
|
|
||||||
@patch("builtins.open", new_callable=unittest.mock.mock_open)
|
|
||||||
@patch("crewai.cli.authentication.utils.os.chmod")
|
|
||||||
def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path):
|
|
||||||
mock_path = MagicMock()
|
|
||||||
mock_get_path.return_value = mock_path
|
|
||||||
filename = "test_file.txt"
|
|
||||||
content = b"test_content"
|
|
||||||
|
|
||||||
self.token_manager.save_secure_file(filename, content)
|
|
||||||
|
|
||||||
mock_path.__truediv__.assert_called_once_with(filename)
|
|
||||||
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "wb")
|
|
||||||
mock_open().write.assert_called_once_with(content)
|
|
||||||
mock_chmod.assert_called_once_with(mock_path.__truediv__.return_value, 0o600)
|
|
||||||
|
|
||||||
@patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
|
|
||||||
@patch(
|
|
||||||
"builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content"
|
|
||||||
)
|
|
||||||
def test_read_secure_file_exists(self, mock_open, mock_get_path):
|
|
||||||
mock_path = MagicMock()
|
|
||||||
mock_get_path.return_value = mock_path
|
|
||||||
mock_path.__truediv__.return_value.exists.return_value = True
|
|
||||||
filename = "test_file.txt"
|
|
||||||
|
|
||||||
result = self.token_manager.read_secure_file(filename)
|
|
||||||
|
|
||||||
self.assertEqual(result, b"test_content")
|
|
||||||
mock_path.__truediv__.assert_called_once_with(filename)
|
|
||||||
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb")
|
|
||||||
|
|
||||||
@patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
|
|
||||||
def test_read_secure_file_not_exists(self, mock_get_path):
|
|
||||||
mock_path = MagicMock()
|
|
||||||
mock_get_path.return_value = mock_path
|
|
||||||
mock_path.__truediv__.return_value.exists.return_value = False
|
|
||||||
filename = "test_file.txt"
|
|
||||||
|
|
||||||
result = self.token_manager.read_secure_file(filename)
|
|
||||||
|
|
||||||
self.assertIsNone(result)
|
|
||||||
mock_path.__truediv__.assert_called_once_with(filename)
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
from crewai.cli.config import (
|
from crewai.cli.config import (
|
||||||
Settings,
|
Settings,
|
||||||
@@ -10,6 +11,8 @@ from crewai.cli.config import (
|
|||||||
CLI_SETTINGS_KEYS,
|
CLI_SETTINGS_KEYS,
|
||||||
DEFAULT_CLI_SETTINGS,
|
DEFAULT_CLI_SETTINGS,
|
||||||
)
|
)
|
||||||
|
from crewai.cli.shared.token_manager import TokenManager
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
class TestSettings(unittest.TestCase):
|
class TestSettings(unittest.TestCase):
|
||||||
@@ -66,7 +69,8 @@ class TestSettings(unittest.TestCase):
|
|||||||
for key in user_settings.keys():
|
for key in user_settings.keys():
|
||||||
self.assertEqual(getattr(settings, key), None)
|
self.assertEqual(getattr(settings, key), None)
|
||||||
|
|
||||||
def test_reset_settings(self):
|
@patch("crewai.cli.config.TokenManager")
|
||||||
|
def test_reset_settings(self, mock_token_manager):
|
||||||
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||||
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS}
|
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS}
|
||||||
|
|
||||||
@@ -74,6 +78,11 @@ class TestSettings(unittest.TestCase):
|
|||||||
config_path=self.config_path, **user_settings, **cli_settings
|
config_path=self.config_path, **user_settings, **cli_settings
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mock_token_manager.return_value = MagicMock()
|
||||||
|
TokenManager().save_tokens(
|
||||||
|
"aaa.bbb.ccc", (datetime.now() + timedelta(seconds=36000)).timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
settings.reset()
|
settings.reset()
|
||||||
|
|
||||||
for key in user_settings.keys():
|
for key in user_settings.keys():
|
||||||
@@ -81,6 +90,8 @@ class TestSettings(unittest.TestCase):
|
|||||||
for key in cli_settings.keys():
|
for key in cli_settings.keys():
|
||||||
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key))
|
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key))
|
||||||
|
|
||||||
|
mock_token_manager.return_value.clear_tokens.assert_called_once()
|
||||||
|
|
||||||
def test_dump_new_settings(self):
|
def test_dump_new_settings(self):
|
||||||
settings = Settings(
|
settings = Settings(
|
||||||
config_path=self.config_path, tool_repository_username="user1"
|
config_path=self.config_path, tool_repository_username="user1"
|
||||||
|
|||||||
138
tests/cli/test_token_manager.py
Normal file
138
tests/cli/test_token_manager.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
|
||||||
|
from crewai.cli.shared.token_manager import TokenManager
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenManager(unittest.TestCase):
|
||||||
|
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||||
|
def setUp(self, mock_get_key):
|
||||||
|
mock_get_key.return_value = Fernet.generate_key()
|
||||||
|
self.token_manager = TokenManager()
|
||||||
|
|
||||||
|
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
|
||||||
|
@patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
|
||||||
|
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||||
|
def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read):
|
||||||
|
mock_key = Fernet.generate_key()
|
||||||
|
mock_get_or_create.return_value = mock_key
|
||||||
|
|
||||||
|
token_manager = TokenManager()
|
||||||
|
result = token_manager.key
|
||||||
|
|
||||||
|
self.assertEqual(result, mock_key)
|
||||||
|
|
||||||
|
@patch("crewai.cli.shared.token_manager.Fernet.generate_key")
|
||||||
|
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
|
||||||
|
@patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
|
||||||
|
def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate):
|
||||||
|
mock_key = b"new_key"
|
||||||
|
mock_read.return_value = None
|
||||||
|
mock_generate.return_value = mock_key
|
||||||
|
|
||||||
|
result = self.token_manager._get_or_create_key()
|
||||||
|
|
||||||
|
self.assertEqual(result, mock_key)
|
||||||
|
mock_read.assert_called_once_with("secret.key")
|
||||||
|
mock_generate.assert_called_once()
|
||||||
|
mock_save.assert_called_once_with("secret.key", mock_key)
|
||||||
|
|
||||||
|
@patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
|
||||||
|
def test_save_tokens(self, mock_save):
|
||||||
|
access_token = "test_token"
|
||||||
|
expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())
|
||||||
|
|
||||||
|
self.token_manager.save_tokens(access_token, expires_at)
|
||||||
|
|
||||||
|
mock_save.assert_called_once()
|
||||||
|
args = mock_save.call_args[0]
|
||||||
|
self.assertEqual(args[0], "tokens.enc")
|
||||||
|
decrypted_data = self.token_manager.fernet.decrypt(args[1])
|
||||||
|
data = json.loads(decrypted_data)
|
||||||
|
self.assertEqual(data["access_token"], access_token)
|
||||||
|
expiration = datetime.fromisoformat(data["expiration"])
|
||||||
|
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
|
||||||
|
|
||||||
|
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
|
||||||
|
def test_get_token_valid(self, mock_read):
|
||||||
|
access_token = "test_token"
|
||||||
|
expiration = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||||
|
data = {"access_token": access_token, "expiration": expiration}
|
||||||
|
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
|
||||||
|
mock_read.return_value = encrypted_data
|
||||||
|
|
||||||
|
result = self.token_manager.get_token()
|
||||||
|
|
||||||
|
self.assertEqual(result, access_token)
|
||||||
|
|
||||||
|
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
|
||||||
|
def test_get_token_expired(self, mock_read):
|
||||||
|
access_token = "test_token"
|
||||||
|
expiration = (datetime.now() - timedelta(hours=1)).isoformat()
|
||||||
|
data = {"access_token": access_token, "expiration": expiration}
|
||||||
|
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
|
||||||
|
mock_read.return_value = encrypted_data
|
||||||
|
|
||||||
|
result = self.token_manager.get_token()
|
||||||
|
|
||||||
|
self.assertIsNone(result)
|
||||||
|
|
||||||
|
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
|
||||||
|
@patch("builtins.open", new_callable=unittest.mock.mock_open)
|
||||||
|
@patch("crewai.cli.shared.token_manager.os.chmod")
|
||||||
|
def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path):
|
||||||
|
mock_path = MagicMock()
|
||||||
|
mock_get_path.return_value = mock_path
|
||||||
|
filename = "test_file.txt"
|
||||||
|
content = b"test_content"
|
||||||
|
|
||||||
|
self.token_manager.save_secure_file(filename, content)
|
||||||
|
|
||||||
|
mock_path.__truediv__.assert_called_once_with(filename)
|
||||||
|
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "wb")
|
||||||
|
mock_open().write.assert_called_once_with(content)
|
||||||
|
mock_chmod.assert_called_once_with(mock_path.__truediv__.return_value, 0o600)
|
||||||
|
|
||||||
|
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
|
||||||
|
@patch(
|
||||||
|
"builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content"
|
||||||
|
)
|
||||||
|
def test_read_secure_file_exists(self, mock_open, mock_get_path):
|
||||||
|
mock_path = MagicMock()
|
||||||
|
mock_get_path.return_value = mock_path
|
||||||
|
mock_path.__truediv__.return_value.exists.return_value = True
|
||||||
|
filename = "test_file.txt"
|
||||||
|
|
||||||
|
result = self.token_manager.read_secure_file(filename)
|
||||||
|
|
||||||
|
self.assertEqual(result, b"test_content")
|
||||||
|
mock_path.__truediv__.assert_called_once_with(filename)
|
||||||
|
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb")
|
||||||
|
|
||||||
|
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
|
||||||
|
def test_read_secure_file_not_exists(self, mock_get_path):
|
||||||
|
mock_path = MagicMock()
|
||||||
|
mock_get_path.return_value = mock_path
|
||||||
|
mock_path.__truediv__.return_value.exists.return_value = False
|
||||||
|
filename = "test_file.txt"
|
||||||
|
|
||||||
|
result = self.token_manager.read_secure_file(filename)
|
||||||
|
|
||||||
|
self.assertIsNone(result)
|
||||||
|
mock_path.__truediv__.assert_called_once_with(filename)
|
||||||
|
|
||||||
|
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
|
||||||
|
def test_clear_tokens(self, mock_get_path):
|
||||||
|
mock_path = MagicMock()
|
||||||
|
mock_get_path.return_value = mock_path
|
||||||
|
|
||||||
|
self.token_manager.clear_tokens()
|
||||||
|
|
||||||
|
mock_path.__truediv__.assert_called_once_with("tokens.enc")
|
||||||
|
mock_path.__truediv__.return_value.unlink.assert_called_once_with(
|
||||||
|
missing_ok=True
|
||||||
|
)
|
||||||
@@ -11,7 +11,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
from pytest import raises
|
from pytest import raises
|
||||||
|
|
||||||
from crewai.cli.authentication.utils import TokenManager
|
from crewai.cli.shared.token_manager import TokenManager
|
||||||
from crewai.cli.tools.main import ToolCommand
|
from crewai.cli.tools.main import ToolCommand
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user