feat: reset tokens on crewai config reset (#3365)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

This commit is contained in:
Heitor Carvalho
2025-08-22 17:16:42 -03:00
committed by GitHub
parent 842bed4e9c
commit f96b779df5
10 changed files with 332 additions and 266 deletions

View File

@@ -7,7 +7,8 @@ from rich.console import Console
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from .utils import TokenManager, validate_jwt_token from .utils import validate_jwt_token
from crewai.cli.shared.token_manager import TokenManager
from urllib.parse import quote from urllib.parse import quote
from crewai.cli.plus_api import PlusAPI from crewai.cli.plus_api import PlusAPI
from crewai.cli.config import Settings from crewai.cli.config import Settings
@@ -21,10 +22,19 @@ console = Console()
class Oauth2Settings(BaseModel): class Oauth2Settings(BaseModel):
provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).") provider: str = Field(
client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.") description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.") )
audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None) client_id: str = Field(
description="OAuth2 client ID issued by the provider, used during authentication requests."
)
domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
)
audience: Optional[str] = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.",
default=None,
)
@classmethod @classmethod
def from_settings(cls): def from_settings(cls):
@@ -44,11 +54,15 @@ class ProviderFactory:
settings = settings or Oauth2Settings.from_settings() settings = settings or Oauth2Settings.from_settings()
import importlib import importlib
module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}")
module = importlib.import_module(
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
)
provider = getattr(module, f"{settings.provider.capitalize()}Provider") provider = getattr(module, f"{settings.provider.capitalize()}Provider")
return provider(settings) return provider(settings)
class AuthenticationCommand: class AuthenticationCommand:
def __init__(self): def __init__(self):
self.token_manager = TokenManager() self.token_manager = TokenManager()
@@ -65,7 +79,7 @@ class AuthenticationCommand:
provider="auth0", provider="auth0",
client_id=AUTH0_CLIENT_ID, client_id=AUTH0_CLIENT_ID,
domain=AUTH0_DOMAIN, domain=AUTH0_DOMAIN,
audience=AUTH0_AUDIENCE audience=AUTH0_AUDIENCE,
) )
self.oauth2_provider = ProviderFactory.from_settings(settings) self.oauth2_provider = ProviderFactory.from_settings(settings)
# End of temporary code. # End of temporary code.
@@ -75,9 +89,7 @@ class AuthenticationCommand:
return self._poll_for_token(device_code_data) return self._poll_for_token(device_code_data)
def _get_device_code( def _get_device_code(self) -> Dict[str, Any]:
self
) -> Dict[str, Any]:
"""Get the device code to authenticate the user.""" """Get the device code to authenticate the user."""
device_code_payload = { device_code_payload = {
@@ -86,7 +98,9 @@ class AuthenticationCommand:
"audience": self.oauth2_provider.get_audience(), "audience": self.oauth2_provider.get_audience(),
} }
response = requests.post( response = requests.post(
url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20 url=self.oauth2_provider.get_authorize_url(),
data=device_code_payload,
timeout=20,
) )
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
@@ -97,9 +111,7 @@ class AuthenticationCommand:
console.print("2. Enter the following code: ", device_code_data["user_code"]) console.print("2. Enter the following code: ", device_code_data["user_code"])
webbrowser.open(device_code_data["verification_uri_complete"]) webbrowser.open(device_code_data["verification_uri_complete"])
def _poll_for_token( def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
self, device_code_data: Dict[str, Any]
) -> None:
"""Polls the server for the token until it is received, or max attempts are reached.""" """Polls the server for the token until it is received, or max attempts are reached."""
token_payload = { token_payload = {
@@ -112,7 +124,9 @@ class AuthenticationCommand:
attempts = 0 attempts = 0
while True and attempts < 10: while True and attempts < 10:
response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30) response = requests.post(
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
)
token_data = response.json() token_data = response.json()
if response.status_code == 200: if response.status_code == 200:

View File

@@ -1,4 +1,4 @@
from .utils import TokenManager from crewai.cli.shared.token_manager import TokenManager
class AuthError(Exception): class AuthError(Exception):

View File

@@ -1,12 +1,5 @@
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional
import jwt import jwt
from jwt import PyJWKClient from jwt import PyJWKClient
from cryptography.fernet import Fernet
def validate_jwt_token( def validate_jwt_token(
@@ -67,118 +60,3 @@ def validate_jwt_token(
raise Exception(f"JWKS or key processing error: {str(e)}") raise Exception(f"JWKS or key processing error: {str(e)}")
except jwt.InvalidTokenError as e: except jwt.InvalidTokenError as e:
raise Exception(f"Invalid token: {str(e)}") raise Exception(f"Invalid token: {str(e)}")
class TokenManager:
def __init__(self, file_path: str = "tokens.enc") -> None:
"""
Initialize the TokenManager class.
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
"""
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
def _get_or_create_key(self) -> bytes:
"""
Get or create the encryption key.
:return: The encryption key.
"""
key_filename = "secret.key"
key = self.read_secure_file(key_filename)
if key is not None:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
return new_key
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""
Save the access token and its expiration time.
:param access_token: The access token to save.
:param expires_at: The UNIX timestamp of the expiration time.
"""
expiration_time = datetime.fromtimestamp(expires_at)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
def get_token(self) -> Optional[str]:
"""
Get the access token if it is valid and not expired.
:return: The access token if valid and not expired, otherwise None.
"""
encrypted_data = self.read_secure_file(self.file_path)
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
data = json.loads(decrypted_data)
expiration = datetime.fromisoformat(data["expiration"])
if expiration <= datetime.now():
return None
return data["access_token"]
def get_secure_storage_path(self) -> Path:
"""
Get the secure storage path based on the operating system.
:return: The secure storage path.
"""
if sys.platform == "win32":
# Windows: Use %LOCALAPPDATA%
base_path = os.environ.get("LOCALAPPDATA")
elif sys.platform == "darwin":
# macOS: Use ~/Library/Application Support
base_path = os.path.expanduser("~/Library/Application Support")
else:
# Linux and other Unix-like: Use ~/.local/share
base_path = os.path.expanduser("~/.local/share")
app_name = "crewai/credentials"
storage_path = Path(base_path) / app_name
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path
def save_secure_file(self, filename: str, content: bytes) -> None:
"""
Save the content to a secure file.
:param filename: The name of the file.
:param content: The content to save.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
with open(file_path, "wb") as f:
f.write(content)
# Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600)
def read_secure_file(self, filename: str) -> Optional[bytes]:
"""
Read the content of a secure file.
:param filename: The name of the file.
:return: The content of the file if it exists, otherwise None.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
if not file_path.exists():
return None
with open(file_path, "rb") as f:
return f.read()

View File

@@ -11,6 +11,7 @@ from crewai.cli.constants import (
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
) )
from crewai.cli.shared.token_manager import TokenManager
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json" DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
@@ -53,6 +54,7 @@ HIDDEN_SETTINGS_KEYS = [
"tool_repository_password", "tool_repository_password",
] ]
class Settings(BaseModel): class Settings(BaseModel):
enterprise_base_url: Optional[str] = Field( enterprise_base_url: Optional[str] = Field(
default=DEFAULT_CLI_SETTINGS["enterprise_base_url"], default=DEFAULT_CLI_SETTINGS["enterprise_base_url"],
@@ -74,12 +76,12 @@ class Settings(BaseModel):
oauth2_provider: str = Field( oauth2_provider: str = Field(
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).", description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).",
default=DEFAULT_CLI_SETTINGS["oauth2_provider"] default=DEFAULT_CLI_SETTINGS["oauth2_provider"],
) )
oauth2_audience: Optional[str] = Field( oauth2_audience: Optional[str] = Field(
description="OAuth2 audience value, typically used to identify the target API or resource.", description="OAuth2 audience value, typically used to identify the target API or resource.",
default=DEFAULT_CLI_SETTINGS["oauth2_audience"] default=DEFAULT_CLI_SETTINGS["oauth2_audience"],
) )
oauth2_client_id: str = Field( oauth2_client_id: str = Field(
@@ -89,7 +91,7 @@ class Settings(BaseModel):
oauth2_domain: str = Field( oauth2_domain: str = Field(
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.", description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.",
default=DEFAULT_CLI_SETTINGS["oauth2_domain"] default=DEFAULT_CLI_SETTINGS["oauth2_domain"],
) )
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data): def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
@@ -116,6 +118,7 @@ class Settings(BaseModel):
"""Reset all settings to default values""" """Reset all settings to default values"""
self._reset_user_settings() self._reset_user_settings()
self._reset_cli_settings() self._reset_cli_settings()
self._clear_auth_tokens()
self.dump() self.dump()
def dump(self) -> None: def dump(self) -> None:
@@ -139,3 +142,7 @@ class Settings(BaseModel):
"""Reset all CLI settings to default values""" """Reset all CLI settings to default values"""
for key in CLI_SETTINGS_KEYS: for key in CLI_SETTINGS_KEYS:
setattr(self, key, DEFAULT_CLI_SETTINGS.get(key)) setattr(self, key, DEFAULT_CLI_SETTINGS.get(key))
def _clear_auth_tokens(self) -> None:
"""Clear all authentication tokens"""
TokenManager().clear_tokens()

View File

View File

@@ -0,0 +1,139 @@
import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional
from cryptography.fernet import Fernet
class TokenManager:
def __init__(self, file_path: str = "tokens.enc") -> None:
"""
Initialize the TokenManager class.
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
"""
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
def _get_or_create_key(self) -> bytes:
"""
Get or create the encryption key.
:return: The encryption key.
"""
key_filename = "secret.key"
key = self.read_secure_file(key_filename)
if key is not None:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
return new_key
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""
Save the access token and its expiration time.
:param access_token: The access token to save.
:param expires_at: The UNIX timestamp of the expiration time.
"""
expiration_time = datetime.fromtimestamp(expires_at)
data = {
"access_token": access_token,
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
def get_token(self) -> Optional[str]:
"""
Get the access token if it is valid and not expired.
:return: The access token if valid and not expired, otherwise None.
"""
encrypted_data = self.read_secure_file(self.file_path)
decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore
data = json.loads(decrypted_data)
expiration = datetime.fromisoformat(data["expiration"])
if expiration <= datetime.now():
return None
return data["access_token"]
def clear_tokens(self) -> None:
"""
Clear the tokens.
"""
self.delete_secure_file(self.file_path)
def get_secure_storage_path(self) -> Path:
"""
Get the secure storage path based on the operating system.
:return: The secure storage path.
"""
if sys.platform == "win32":
# Windows: Use %LOCALAPPDATA%
base_path = os.environ.get("LOCALAPPDATA")
elif sys.platform == "darwin":
# macOS: Use ~/Library/Application Support
base_path = os.path.expanduser("~/Library/Application Support")
else:
# Linux and other Unix-like: Use ~/.local/share
base_path = os.path.expanduser("~/.local/share")
app_name = "crewai/credentials"
storage_path = Path(base_path) / app_name
storage_path.mkdir(parents=True, exist_ok=True)
return storage_path
def save_secure_file(self, filename: str, content: bytes) -> None:
"""
Save the content to a secure file.
:param filename: The name of the file.
:param content: The content to save.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
with open(file_path, "wb") as f:
f.write(content)
# Set appropriate permissions (read/write for owner only)
os.chmod(file_path, 0o600)
def read_secure_file(self, filename: str) -> Optional[bytes]:
"""
Read the content of a secure file.
:param filename: The name of the file.
:return: The content of the file if it exists, otherwise None.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
if not file_path.exists():
return None
with open(file_path, "rb") as f:
return f.read()
def delete_secure_file(self, filename: str) -> None:
"""
Delete the secure file.
:param filename: The name of the file.
"""
storage_path = self.get_secure_storage_path()
file_path = storage_path / filename
if file_path.exists():
file_path.unlink(missing_ok=True)

View File

@@ -1,17 +1,14 @@
import json
import jwt import jwt
import unittest import unittest
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from cryptography.fernet import Fernet
from crewai.cli.authentication.utils import TokenManager, validate_jwt_token from crewai.cli.authentication.utils import validate_jwt_token
@patch("crewai.cli.authentication.utils.PyJWKClient", return_value=MagicMock()) @patch("crewai.cli.authentication.utils.PyJWKClient", return_value=MagicMock())
@patch("crewai.cli.authentication.utils.jwt") @patch("crewai.cli.authentication.utils.jwt")
class TestValidateToken(unittest.TestCase): class TestUtils(unittest.TestCase):
def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient): def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
mock_jwt.decode.return_value = {"exp": 1719859200} mock_jwt.decode.return_value = {"exp": 1719859200}
@@ -105,121 +102,3 @@ class TestValidateToken(unittest.TestCase):
issuer="https://mock_issuer", issuer="https://mock_issuer",
audience="app_id_xxxx", audience="app_id_xxxx",
) )
class TestTokenManager(unittest.TestCase):
@patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key")
def setUp(self, mock_get_key):
mock_get_key.return_value = Fernet.generate_key()
self.token_manager = TokenManager()
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
@patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
@patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key")
def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read):
mock_key = Fernet.generate_key()
mock_get_or_create.return_value = mock_key
token_manager = TokenManager()
result = token_manager.key
self.assertEqual(result, mock_key)
@patch("crewai.cli.authentication.utils.Fernet.generate_key")
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
@patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate):
mock_key = b"new_key"
mock_read.return_value = None
mock_generate.return_value = mock_key
result = self.token_manager._get_or_create_key()
self.assertEqual(result, mock_key)
mock_read.assert_called_once_with("secret.key")
mock_generate.assert_called_once()
mock_save.assert_called_once_with("secret.key", mock_key)
@patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
def test_save_tokens(self, mock_save):
access_token = "test_token"
expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())
self.token_manager.save_tokens(access_token, expires_at)
mock_save.assert_called_once()
args = mock_save.call_args[0]
self.assertEqual(args[0], "tokens.enc")
decrypted_data = self.token_manager.fernet.decrypt(args[1])
data = json.loads(decrypted_data)
self.assertEqual(data["access_token"], access_token)
expiration = datetime.fromisoformat(data["expiration"])
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
def test_get_token_valid(self, mock_read):
access_token = "test_token"
expiration = (datetime.now() + timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
mock_read.return_value = encrypted_data
result = self.token_manager.get_token()
self.assertEqual(result, access_token)
@patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
def test_get_token_expired(self, mock_read):
access_token = "test_token"
expiration = (datetime.now() - timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
mock_read.return_value = encrypted_data
result = self.token_manager.get_token()
self.assertIsNone(result)
@patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
@patch("builtins.open", new_callable=unittest.mock.mock_open)
@patch("crewai.cli.authentication.utils.os.chmod")
def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path):
mock_path = MagicMock()
mock_get_path.return_value = mock_path
filename = "test_file.txt"
content = b"test_content"
self.token_manager.save_secure_file(filename, content)
mock_path.__truediv__.assert_called_once_with(filename)
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "wb")
mock_open().write.assert_called_once_with(content)
mock_chmod.assert_called_once_with(mock_path.__truediv__.return_value, 0o600)
@patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
@patch(
"builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content"
)
def test_read_secure_file_exists(self, mock_open, mock_get_path):
mock_path = MagicMock()
mock_get_path.return_value = mock_path
mock_path.__truediv__.return_value.exists.return_value = True
filename = "test_file.txt"
result = self.token_manager.read_secure_file(filename)
self.assertEqual(result, b"test_content")
mock_path.__truediv__.assert_called_once_with(filename)
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb")
@patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
def test_read_secure_file_not_exists(self, mock_get_path):
mock_path = MagicMock()
mock_get_path.return_value = mock_path
mock_path.__truediv__.return_value.exists.return_value = False
filename = "test_file.txt"
result = self.token_manager.read_secure_file(filename)
self.assertIsNone(result)
mock_path.__truediv__.assert_called_once_with(filename)

View File

@@ -3,6 +3,7 @@ import shutil
import tempfile import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
from unittest.mock import patch, MagicMock
from crewai.cli.config import ( from crewai.cli.config import (
Settings, Settings,
@@ -10,6 +11,8 @@ from crewai.cli.config import (
CLI_SETTINGS_KEYS, CLI_SETTINGS_KEYS,
DEFAULT_CLI_SETTINGS, DEFAULT_CLI_SETTINGS,
) )
from crewai.cli.shared.token_manager import TokenManager
from datetime import datetime, timedelta
class TestSettings(unittest.TestCase): class TestSettings(unittest.TestCase):
@@ -66,7 +69,8 @@ class TestSettings(unittest.TestCase):
for key in user_settings.keys(): for key in user_settings.keys():
self.assertEqual(getattr(settings, key), None) self.assertEqual(getattr(settings, key), None)
def test_reset_settings(self): @patch("crewai.cli.config.TokenManager")
def test_reset_settings(self, mock_token_manager):
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS} user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS} cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS}
@@ -74,6 +78,11 @@ class TestSettings(unittest.TestCase):
config_path=self.config_path, **user_settings, **cli_settings config_path=self.config_path, **user_settings, **cli_settings
) )
mock_token_manager.return_value = MagicMock()
TokenManager().save_tokens(
"aaa.bbb.ccc", (datetime.now() + timedelta(seconds=36000)).timestamp()
)
settings.reset() settings.reset()
for key in user_settings.keys(): for key in user_settings.keys():
@@ -81,6 +90,8 @@ class TestSettings(unittest.TestCase):
for key in cli_settings.keys(): for key in cli_settings.keys():
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key)) self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key))
mock_token_manager.return_value.clear_tokens.assert_called_once()
def test_dump_new_settings(self): def test_dump_new_settings(self):
settings = Settings( settings = Settings(
config_path=self.config_path, tool_repository_username="user1" config_path=self.config_path, tool_repository_username="user1"

View File

@@ -0,0 +1,138 @@
import json
import unittest
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
from cryptography.fernet import Fernet
from crewai.cli.shared.token_manager import TokenManager
class TestTokenManager(unittest.TestCase):
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
def setUp(self, mock_get_key):
mock_get_key.return_value = Fernet.generate_key()
self.token_manager = TokenManager()
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
@patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read):
mock_key = Fernet.generate_key()
mock_get_or_create.return_value = mock_key
token_manager = TokenManager()
result = token_manager.key
self.assertEqual(result, mock_key)
@patch("crewai.cli.shared.token_manager.Fernet.generate_key")
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
@patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate):
mock_key = b"new_key"
mock_read.return_value = None
mock_generate.return_value = mock_key
result = self.token_manager._get_or_create_key()
self.assertEqual(result, mock_key)
mock_read.assert_called_once_with("secret.key")
mock_generate.assert_called_once()
mock_save.assert_called_once_with("secret.key", mock_key)
@patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
def test_save_tokens(self, mock_save):
access_token = "test_token"
expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())
self.token_manager.save_tokens(access_token, expires_at)
mock_save.assert_called_once()
args = mock_save.call_args[0]
self.assertEqual(args[0], "tokens.enc")
decrypted_data = self.token_manager.fernet.decrypt(args[1])
data = json.loads(decrypted_data)
self.assertEqual(data["access_token"], access_token)
expiration = datetime.fromisoformat(data["expiration"])
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
def test_get_token_valid(self, mock_read):
access_token = "test_token"
expiration = (datetime.now() + timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
mock_read.return_value = encrypted_data
result = self.token_manager.get_token()
self.assertEqual(result, access_token)
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
def test_get_token_expired(self, mock_read):
access_token = "test_token"
expiration = (datetime.now() - timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
mock_read.return_value = encrypted_data
result = self.token_manager.get_token()
self.assertIsNone(result)
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
@patch("builtins.open", new_callable=unittest.mock.mock_open)
@patch("crewai.cli.shared.token_manager.os.chmod")
def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path):
mock_path = MagicMock()
mock_get_path.return_value = mock_path
filename = "test_file.txt"
content = b"test_content"
self.token_manager.save_secure_file(filename, content)
mock_path.__truediv__.assert_called_once_with(filename)
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "wb")
mock_open().write.assert_called_once_with(content)
mock_chmod.assert_called_once_with(mock_path.__truediv__.return_value, 0o600)
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
@patch(
"builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content"
)
def test_read_secure_file_exists(self, mock_open, mock_get_path):
mock_path = MagicMock()
mock_get_path.return_value = mock_path
mock_path.__truediv__.return_value.exists.return_value = True
filename = "test_file.txt"
result = self.token_manager.read_secure_file(filename)
self.assertEqual(result, b"test_content")
mock_path.__truediv__.assert_called_once_with(filename)
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb")
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
def test_read_secure_file_not_exists(self, mock_get_path):
mock_path = MagicMock()
mock_get_path.return_value = mock_path
mock_path.__truediv__.return_value.exists.return_value = False
filename = "test_file.txt"
result = self.token_manager.read_secure_file(filename)
self.assertIsNone(result)
mock_path.__truediv__.assert_called_once_with(filename)
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
def test_clear_tokens(self, mock_get_path):
mock_path = MagicMock()
mock_get_path.return_value = mock_path
self.token_manager.clear_tokens()
mock_path.__truediv__.assert_called_once_with("tokens.enc")
mock_path.__truediv__.return_value.unlink.assert_called_once_with(
missing_ok=True
)

View File

@@ -11,7 +11,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from pytest import raises from pytest import raises
from crewai.cli.authentication.utils import TokenManager from crewai.cli.shared.token_manager import TokenManager
from crewai.cli.tools.main import ToolCommand from crewai.cli.tools.main import ToolCommand