From f96b779df56865e5e9ca689da5ddf52cedeb69ef Mon Sep 17 00:00:00 2001 From: Heitor Carvalho Date: Fri, 22 Aug 2025 17:16:42 -0300 Subject: [PATCH] feat: reset tokens on `crewai config reset` (#3365) --- src/crewai/cli/authentication/main.py | 44 +++++--- src/crewai/cli/authentication/token.py | 2 +- src/crewai/cli/authentication/utils.py | 122 ---------------------- src/crewai/cli/config.py | 13 ++- src/crewai/cli/shared/__init__.py | 0 src/crewai/cli/shared/token_manager.py | 139 +++++++++++++++++++++++++ tests/cli/authentication/test_utils.py | 125 +--------------------- tests/cli/test_config.py | 13 ++- tests/cli/test_token_manager.py | 138 ++++++++++++++++++++++++ tests/cli/tools/test_main.py | 2 +- 10 files changed, 332 insertions(+), 266 deletions(-) create mode 100644 src/crewai/cli/shared/__init__.py create mode 100644 src/crewai/cli/shared/token_manager.py create mode 100644 tests/cli/test_token_manager.py diff --git a/src/crewai/cli/authentication/main.py b/src/crewai/cli/authentication/main.py index 26a354aea..42dd6677b 100644 --- a/src/crewai/cli/authentication/main.py +++ b/src/crewai/cli/authentication/main.py @@ -7,7 +7,8 @@ from rich.console import Console from pydantic import BaseModel, Field -from .utils import TokenManager, validate_jwt_token +from .utils import validate_jwt_token +from crewai.cli.shared.token_manager import TokenManager from urllib.parse import quote from crewai.cli.plus_api import PlusAPI from crewai.cli.config import Settings @@ -21,10 +22,19 @@ console = Console() class Oauth2Settings(BaseModel): - provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).") - client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.") - domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.") - audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None) + provider: str = Field( + description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)." + ) + client_id: str = Field( + description="OAuth2 client ID issued by the provider, used during authentication requests." + ) + domain: str = Field( + description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens." + ) + audience: Optional[str] = Field( + description="OAuth2 audience value, typically used to identify the target API or resource.", + default=None, + ) @classmethod def from_settings(cls): @@ -44,11 +54,15 @@ class ProviderFactory: settings = settings or Oauth2Settings.from_settings() import importlib - module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}") + + module = importlib.import_module( + f"crewai.cli.authentication.providers.{settings.provider.lower()}" + ) provider = getattr(module, f"{settings.provider.capitalize()}Provider") return provider(settings) + class AuthenticationCommand: def __init__(self): self.token_manager = TokenManager() @@ -65,7 +79,7 @@ class AuthenticationCommand: provider="auth0", client_id=AUTH0_CLIENT_ID, domain=AUTH0_DOMAIN, - audience=AUTH0_AUDIENCE + audience=AUTH0_AUDIENCE, ) self.oauth2_provider = ProviderFactory.from_settings(settings) # End of temporary code. @@ -75,9 +89,7 @@ class AuthenticationCommand: return self._poll_for_token(device_code_data) - def _get_device_code( - self - ) -> Dict[str, Any]: + def _get_device_code(self) -> Dict[str, Any]: """Get the device code to authenticate the user.""" device_code_payload = { @@ -86,7 +98,9 @@ class AuthenticationCommand: "audience": self.oauth2_provider.get_audience(), } response = requests.post( - url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20 + url=self.oauth2_provider.get_authorize_url(), + data=device_code_payload, + timeout=20, ) response.raise_for_status() return response.json() @@ -97,9 +111,7 @@ class AuthenticationCommand: console.print("2. Enter the following code: ", device_code_data["user_code"]) webbrowser.open(device_code_data["verification_uri_complete"]) - def _poll_for_token( - self, device_code_data: Dict[str, Any] - ) -> None: + def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None: """Polls the server for the token until it is received, or max attempts are reached.""" token_payload = { @@ -112,7 +124,9 @@ class AuthenticationCommand: attempts = 0 while True and attempts < 10: - response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30) + response = requests.post( + self.oauth2_provider.get_token_url(), data=token_payload, timeout=30 + ) token_data = response.json() if response.status_code == 200: diff --git a/src/crewai/cli/authentication/token.py b/src/crewai/cli/authentication/token.py index e59a13b38..7a1d05c98 100644 --- a/src/crewai/cli/authentication/token.py +++ b/src/crewai/cli/authentication/token.py @@ -1,4 +1,4 @@ -from .utils import TokenManager +from crewai.cli.shared.token_manager import TokenManager class AuthError(Exception): diff --git a/src/crewai/cli/authentication/utils.py b/src/crewai/cli/authentication/utils.py index 8b632ba37..a788cccc0 100644 --- a/src/crewai/cli/authentication/utils.py +++ b/src/crewai/cli/authentication/utils.py @@ -1,12 +1,5 @@ -import json -import os -import sys -from datetime import datetime -from pathlib import Path -from typing import Optional import jwt from jwt import PyJWKClient -from cryptography.fernet import Fernet def validate_jwt_token( @@ -67,118 +60,3 @@ def validate_jwt_token( raise Exception(f"JWKS or key processing error: {str(e)}") except jwt.InvalidTokenError as e: raise Exception(f"Invalid token: {str(e)}") - - -class TokenManager: - def __init__(self, file_path: str = "tokens.enc") -> None: - """ - Initialize the TokenManager class. - - :param file_path: The file path to store the encrypted tokens. Default is "tokens.enc". - """ - self.file_path = file_path - self.key = self._get_or_create_key() - self.fernet = Fernet(self.key) - - def _get_or_create_key(self) -> bytes: - """ - Get or create the encryption key. - - :return: The encryption key. - """ - key_filename = "secret.key" - key = self.read_secure_file(key_filename) - - if key is not None: - return key - - new_key = Fernet.generate_key() - self.save_secure_file(key_filename, new_key) - return new_key - - def save_tokens(self, access_token: str, expires_at: int) -> None: - """ - Save the access token and its expiration time. - - :param access_token: The access token to save. - :param expires_at: The UNIX timestamp of the expiration time. - """ - expiration_time = datetime.fromtimestamp(expires_at) - data = { - "access_token": access_token, - "expiration": expiration_time.isoformat(), - } - encrypted_data = self.fernet.encrypt(json.dumps(data).encode()) - self.save_secure_file(self.file_path, encrypted_data) - - def get_token(self) -> Optional[str]: - """ - Get the access token if it is valid and not expired. - - :return: The access token if valid and not expired, otherwise None. - """ - encrypted_data = self.read_secure_file(self.file_path) - - decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore - data = json.loads(decrypted_data) - - expiration = datetime.fromisoformat(data["expiration"]) - if expiration <= datetime.now(): - return None - - return data["access_token"] - - def get_secure_storage_path(self) -> Path: - """ - Get the secure storage path based on the operating system. - - :return: The secure storage path. - """ - if sys.platform == "win32": - # Windows: Use %LOCALAPPDATA% - base_path = os.environ.get("LOCALAPPDATA") - elif sys.platform == "darwin": - # macOS: Use ~/Library/Application Support - base_path = os.path.expanduser("~/Library/Application Support") - else: - # Linux and other Unix-like: Use ~/.local/share - base_path = os.path.expanduser("~/.local/share") - - app_name = "crewai/credentials" - storage_path = Path(base_path) / app_name - - storage_path.mkdir(parents=True, exist_ok=True) - - return storage_path - - def save_secure_file(self, filename: str, content: bytes) -> None: - """ - Save the content to a secure file. - - :param filename: The name of the file. - :param content: The content to save. - """ - storage_path = self.get_secure_storage_path() - file_path = storage_path / filename - - with open(file_path, "wb") as f: - f.write(content) - - # Set appropriate permissions (read/write for owner only) - os.chmod(file_path, 0o600) - - def read_secure_file(self, filename: str) -> Optional[bytes]: - """ - Read the content of a secure file. - - :param filename: The name of the file. - :return: The content of the file if it exists, otherwise None. - """ - storage_path = self.get_secure_storage_path() - file_path = storage_path / filename - - if not file_path.exists(): - return None - - with open(file_path, "rb") as f: - return f.read() diff --git a/src/crewai/cli/config.py b/src/crewai/cli/config.py index a8da400a8..8eccbbb05 100644 --- a/src/crewai/cli/config.py +++ b/src/crewai/cli/config.py @@ -11,6 +11,7 @@ from crewai.cli.constants import ( CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, ) +from crewai.cli.shared.token_manager import TokenManager DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json" @@ -53,6 +54,7 @@ HIDDEN_SETTINGS_KEYS = [ "tool_repository_password", ] + class Settings(BaseModel): enterprise_base_url: Optional[str] = Field( default=DEFAULT_CLI_SETTINGS["enterprise_base_url"], @@ -74,12 +76,12 @@ class Settings(BaseModel): oauth2_provider: str = Field( description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).", - default=DEFAULT_CLI_SETTINGS["oauth2_provider"] + default=DEFAULT_CLI_SETTINGS["oauth2_provider"], ) oauth2_audience: Optional[str] = Field( description="OAuth2 audience value, typically used to identify the target API or resource.", - default=DEFAULT_CLI_SETTINGS["oauth2_audience"] + default=DEFAULT_CLI_SETTINGS["oauth2_audience"], ) oauth2_client_id: str = Field( @@ -89,7 +91,7 @@ class Settings(BaseModel): oauth2_domain: str = Field( description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.", - default=DEFAULT_CLI_SETTINGS["oauth2_domain"] + default=DEFAULT_CLI_SETTINGS["oauth2_domain"], ) def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data): @@ -116,6 +118,7 @@ class Settings(BaseModel): """Reset all settings to default values""" self._reset_user_settings() self._reset_cli_settings() + self._clear_auth_tokens() self.dump() def dump(self) -> None: @@ -139,3 +142,7 @@ class Settings(BaseModel): """Reset all CLI settings to default values""" for key in CLI_SETTINGS_KEYS: setattr(self, key, DEFAULT_CLI_SETTINGS.get(key)) + + def _clear_auth_tokens(self) -> None: + """Clear all authentication tokens""" + TokenManager().clear_tokens() diff --git a/src/crewai/cli/shared/__init__.py b/src/crewai/cli/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/crewai/cli/shared/token_manager.py b/src/crewai/cli/shared/token_manager.py new file mode 100644 index 000000000..7f00ee61f --- /dev/null +++ b/src/crewai/cli/shared/token_manager.py @@ -0,0 +1,139 @@ +import json +import os +import sys +from datetime import datetime +from pathlib import Path +from typing import Optional +from cryptography.fernet import Fernet + + +class TokenManager: + def __init__(self, file_path: str = "tokens.enc") -> None: + """ + Initialize the TokenManager class. + + :param file_path: The file path to store the encrypted tokens. Default is "tokens.enc". + """ + self.file_path = file_path + self.key = self._get_or_create_key() + self.fernet = Fernet(self.key) + + def _get_or_create_key(self) -> bytes: + """ + Get or create the encryption key. + + :return: The encryption key. + """ + key_filename = "secret.key" + key = self.read_secure_file(key_filename) + + if key is not None: + return key + + new_key = Fernet.generate_key() + self.save_secure_file(key_filename, new_key) + return new_key + + def save_tokens(self, access_token: str, expires_at: int) -> None: + """ + Save the access token and its expiration time. + + :param access_token: The access token to save. + :param expires_at: The UNIX timestamp of the expiration time. + """ + expiration_time = datetime.fromtimestamp(expires_at) + data = { + "access_token": access_token, + "expiration": expiration_time.isoformat(), + } + encrypted_data = self.fernet.encrypt(json.dumps(data).encode()) + self.save_secure_file(self.file_path, encrypted_data) + + def get_token(self) -> Optional[str]: + """ + Get the access token if it is valid and not expired. + + :return: The access token if valid and not expired, otherwise None. + """ + encrypted_data = self.read_secure_file(self.file_path) + + decrypted_data = self.fernet.decrypt(encrypted_data) # type: ignore + data = json.loads(decrypted_data) + + expiration = datetime.fromisoformat(data["expiration"]) + if expiration <= datetime.now(): + return None + + return data["access_token"] + + def clear_tokens(self) -> None: + """ + Clear the tokens. + """ + self.delete_secure_file(self.file_path) + + def get_secure_storage_path(self) -> Path: + """ + Get the secure storage path based on the operating system. + + :return: The secure storage path. + """ + if sys.platform == "win32": + # Windows: Use %LOCALAPPDATA% + base_path = os.environ.get("LOCALAPPDATA") + elif sys.platform == "darwin": + # macOS: Use ~/Library/Application Support + base_path = os.path.expanduser("~/Library/Application Support") + else: + # Linux and other Unix-like: Use ~/.local/share + base_path = os.path.expanduser("~/.local/share") + + app_name = "crewai/credentials" + storage_path = Path(base_path) / app_name + + storage_path.mkdir(parents=True, exist_ok=True) + + return storage_path + + def save_secure_file(self, filename: str, content: bytes) -> None: + """ + Save the content to a secure file. + + :param filename: The name of the file. + :param content: The content to save. + """ + storage_path = self.get_secure_storage_path() + file_path = storage_path / filename + + with open(file_path, "wb") as f: + f.write(content) + + # Set appropriate permissions (read/write for owner only) + os.chmod(file_path, 0o600) + + def read_secure_file(self, filename: str) -> Optional[bytes]: + """ + Read the content of a secure file. + + :param filename: The name of the file. + :return: The content of the file if it exists, otherwise None. + """ + storage_path = self.get_secure_storage_path() + file_path = storage_path / filename + + if not file_path.exists(): + return None + + with open(file_path, "rb") as f: + return f.read() + + def delete_secure_file(self, filename: str) -> None: + """ + Delete the secure file. + + :param filename: The name of the file. + """ + storage_path = self.get_secure_storage_path() + file_path = storage_path / filename + if file_path.exists(): + file_path.unlink(missing_ok=True) diff --git a/tests/cli/authentication/test_utils.py b/tests/cli/authentication/test_utils.py index 505b3a28f..860ec7aae 100644 --- a/tests/cli/authentication/test_utils.py +++ b/tests/cli/authentication/test_utils.py @@ -1,17 +1,14 @@ -import json import jwt import unittest -from datetime import datetime, timedelta from unittest.mock import MagicMock, patch -from cryptography.fernet import Fernet -from crewai.cli.authentication.utils import TokenManager, validate_jwt_token +from crewai.cli.authentication.utils import validate_jwt_token @patch("crewai.cli.authentication.utils.PyJWKClient", return_value=MagicMock()) @patch("crewai.cli.authentication.utils.jwt") -class TestValidateToken(unittest.TestCase): +class TestUtils(unittest.TestCase): def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient): mock_jwt.decode.return_value = {"exp": 1719859200} @@ -105,121 +102,3 @@ class TestValidateToken(unittest.TestCase): issuer="https://mock_issuer", audience="app_id_xxxx", ) - - -class TestTokenManager(unittest.TestCase): - @patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key") - def setUp(self, mock_get_key): - mock_get_key.return_value = Fernet.generate_key() - self.token_manager = TokenManager() - - @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file") - @patch("crewai.cli.authentication.utils.TokenManager.save_secure_file") - @patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key") - def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read): - mock_key = Fernet.generate_key() - mock_get_or_create.return_value = mock_key - - token_manager = TokenManager() - result = token_manager.key - - self.assertEqual(result, mock_key) - - @patch("crewai.cli.authentication.utils.Fernet.generate_key") - @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file") - @patch("crewai.cli.authentication.utils.TokenManager.save_secure_file") - def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate): - mock_key = b"new_key" - mock_read.return_value = None - mock_generate.return_value = mock_key - - result = self.token_manager._get_or_create_key() - - self.assertEqual(result, mock_key) - mock_read.assert_called_once_with("secret.key") - mock_generate.assert_called_once() - mock_save.assert_called_once_with("secret.key", mock_key) - - @patch("crewai.cli.authentication.utils.TokenManager.save_secure_file") - def test_save_tokens(self, mock_save): - access_token = "test_token" - expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp()) - - self.token_manager.save_tokens(access_token, expires_at) - - mock_save.assert_called_once() - args = mock_save.call_args[0] - self.assertEqual(args[0], "tokens.enc") - decrypted_data = self.token_manager.fernet.decrypt(args[1]) - data = json.loads(decrypted_data) - self.assertEqual(data["access_token"], access_token) - expiration = datetime.fromisoformat(data["expiration"]) - self.assertEqual(expiration, datetime.fromtimestamp(expires_at)) - - @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file") - def test_get_token_valid(self, mock_read): - access_token = "test_token" - expiration = (datetime.now() + timedelta(hours=1)).isoformat() - data = {"access_token": access_token, "expiration": expiration} - encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode()) - mock_read.return_value = encrypted_data - - result = self.token_manager.get_token() - - self.assertEqual(result, access_token) - - @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file") - def test_get_token_expired(self, mock_read): - access_token = "test_token" - expiration = (datetime.now() - timedelta(hours=1)).isoformat() - data = {"access_token": access_token, "expiration": expiration} - encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode()) - mock_read.return_value = encrypted_data - - result = self.token_manager.get_token() - - self.assertIsNone(result) - - @patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path") - @patch("builtins.open", new_callable=unittest.mock.mock_open) - @patch("crewai.cli.authentication.utils.os.chmod") - def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path): - mock_path = MagicMock() - mock_get_path.return_value = mock_path - filename = "test_file.txt" - content = b"test_content" - - self.token_manager.save_secure_file(filename, content) - - mock_path.__truediv__.assert_called_once_with(filename) - mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "wb") - mock_open().write.assert_called_once_with(content) - mock_chmod.assert_called_once_with(mock_path.__truediv__.return_value, 0o600) - - @patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path") - @patch( - "builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content" - ) - def test_read_secure_file_exists(self, mock_open, mock_get_path): - mock_path = MagicMock() - mock_get_path.return_value = mock_path - mock_path.__truediv__.return_value.exists.return_value = True - filename = "test_file.txt" - - result = self.token_manager.read_secure_file(filename) - - self.assertEqual(result, b"test_content") - mock_path.__truediv__.assert_called_once_with(filename) - mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb") - - @patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path") - def test_read_secure_file_not_exists(self, mock_get_path): - mock_path = MagicMock() - mock_get_path.return_value = mock_path - mock_path.__truediv__.return_value.exists.return_value = False - filename = "test_file.txt" - - result = self.token_manager.read_secure_file(filename) - - self.assertIsNone(result) - mock_path.__truediv__.assert_called_once_with(filename) diff --git a/tests/cli/test_config.py b/tests/cli/test_config.py index a492da54a..09690c470 100644 --- a/tests/cli/test_config.py +++ b/tests/cli/test_config.py @@ -3,6 +3,7 @@ import shutil import tempfile import unittest from pathlib import Path +from unittest.mock import patch, MagicMock from crewai.cli.config import ( Settings, @@ -10,6 +11,8 @@ from crewai.cli.config import ( CLI_SETTINGS_KEYS, DEFAULT_CLI_SETTINGS, ) +from crewai.cli.shared.token_manager import TokenManager +from datetime import datetime, timedelta class TestSettings(unittest.TestCase): @@ -66,7 +69,8 @@ class TestSettings(unittest.TestCase): for key in user_settings.keys(): self.assertEqual(getattr(settings, key), None) - def test_reset_settings(self): + @patch("crewai.cli.config.TokenManager") + def test_reset_settings(self, mock_token_manager): user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS} cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS} @@ -74,6 +78,11 @@ class TestSettings(unittest.TestCase): config_path=self.config_path, **user_settings, **cli_settings ) + mock_token_manager.return_value = MagicMock() + TokenManager().save_tokens( + "aaa.bbb.ccc", (datetime.now() + timedelta(seconds=36000)).timestamp() + ) + settings.reset() for key in user_settings.keys(): @@ -81,6 +90,8 @@ class TestSettings(unittest.TestCase): for key in cli_settings.keys(): self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key)) + mock_token_manager.return_value.clear_tokens.assert_called_once() + def test_dump_new_settings(self): settings = Settings( config_path=self.config_path, tool_repository_username="user1" diff --git a/tests/cli/test_token_manager.py b/tests/cli/test_token_manager.py new file mode 100644 index 000000000..ffee827ea --- /dev/null +++ b/tests/cli/test_token_manager.py @@ -0,0 +1,138 @@ +import json +import unittest +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +from cryptography.fernet import Fernet + +from crewai.cli.shared.token_manager import TokenManager + + +class TestTokenManager(unittest.TestCase): + @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + def setUp(self, mock_get_key): + mock_get_key.return_value = Fernet.generate_key() + self.token_manager = TokenManager() + + @patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file") + @patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file") + @patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key") + def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read): + mock_key = Fernet.generate_key() + mock_get_or_create.return_value = mock_key + + token_manager = TokenManager() + result = token_manager.key + + self.assertEqual(result, mock_key) + + @patch("crewai.cli.shared.token_manager.Fernet.generate_key") + @patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file") + @patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file") + def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate): + mock_key = b"new_key" + mock_read.return_value = None + mock_generate.return_value = mock_key + + result = self.token_manager._get_or_create_key() + + self.assertEqual(result, mock_key) + mock_read.assert_called_once_with("secret.key") + mock_generate.assert_called_once() + mock_save.assert_called_once_with("secret.key", mock_key) + + @patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file") + def test_save_tokens(self, mock_save): + access_token = "test_token" + expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp()) + + self.token_manager.save_tokens(access_token, expires_at) + + mock_save.assert_called_once() + args = mock_save.call_args[0] + self.assertEqual(args[0], "tokens.enc") + decrypted_data = self.token_manager.fernet.decrypt(args[1]) + data = json.loads(decrypted_data) + self.assertEqual(data["access_token"], access_token) + expiration = datetime.fromisoformat(data["expiration"]) + self.assertEqual(expiration, datetime.fromtimestamp(expires_at)) + + @patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file") + def test_get_token_valid(self, mock_read): + access_token = "test_token" + expiration = (datetime.now() + timedelta(hours=1)).isoformat() + data = {"access_token": access_token, "expiration": expiration} + encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode()) + mock_read.return_value = encrypted_data + + result = self.token_manager.get_token() + + self.assertEqual(result, access_token) + + @patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file") + def test_get_token_expired(self, mock_read): + access_token = "test_token" + expiration = (datetime.now() - timedelta(hours=1)).isoformat() + data = {"access_token": access_token, "expiration": expiration} + encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode()) + mock_read.return_value = encrypted_data + + result = self.token_manager.get_token() + + self.assertIsNone(result) + + @patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path") + @patch("builtins.open", new_callable=unittest.mock.mock_open) + @patch("crewai.cli.shared.token_manager.os.chmod") + def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path): + mock_path = MagicMock() + mock_get_path.return_value = mock_path + filename = "test_file.txt" + content = b"test_content" + + self.token_manager.save_secure_file(filename, content) + + mock_path.__truediv__.assert_called_once_with(filename) + mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "wb") + mock_open().write.assert_called_once_with(content) + mock_chmod.assert_called_once_with(mock_path.__truediv__.return_value, 0o600) + + @patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path") + @patch( + "builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content" + ) + def test_read_secure_file_exists(self, mock_open, mock_get_path): + mock_path = MagicMock() + mock_get_path.return_value = mock_path + mock_path.__truediv__.return_value.exists.return_value = True + filename = "test_file.txt" + + result = self.token_manager.read_secure_file(filename) + + self.assertEqual(result, b"test_content") + mock_path.__truediv__.assert_called_once_with(filename) + mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb") + + @patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path") + def test_read_secure_file_not_exists(self, mock_get_path): + mock_path = MagicMock() + mock_get_path.return_value = mock_path + mock_path.__truediv__.return_value.exists.return_value = False + filename = "test_file.txt" + + result = self.token_manager.read_secure_file(filename) + + self.assertIsNone(result) + mock_path.__truediv__.assert_called_once_with(filename) + + @patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path") + def test_clear_tokens(self, mock_get_path): + mock_path = MagicMock() + mock_get_path.return_value = mock_path + + self.token_manager.clear_tokens() + + mock_path.__truediv__.assert_called_once_with("tokens.enc") + mock_path.__truediv__.return_value.unlink.assert_called_once_with( + missing_ok=True + ) diff --git a/tests/cli/tools/test_main.py b/tests/cli/tools/test_main.py index aaa188c0a..117526487 100644 --- a/tests/cli/tools/test_main.py +++ b/tests/cli/tools/test_main.py @@ -11,7 +11,7 @@ from unittest.mock import MagicMock, patch import pytest from pytest import raises -from crewai.cli.authentication.utils import TokenManager +from crewai.cli.shared.token_manager import TokenManager from crewai.cli.tools.main import ToolCommand