mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-15 20:08:29 +00:00
* Remove `crewai signup` command, update docs * Add `Settings.clear()` and clear settings before each login * Add pyjwt * Remove print statement from ToolCommand.login() * Remove auth0 dependency * Update docs
224 lines
9.0 KiB
Python
224 lines
9.0 KiB
Python
import json
|
|
import jwt
|
|
import unittest
|
|
from datetime import datetime, timedelta
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from cryptography.fernet import Fernet
|
|
|
|
from crewai.cli.authentication.utils import TokenManager, validate_jwt_token
|
|
|
|
|
|
@patch("crewai.cli.authentication.utils.PyJWKClient", return_value=MagicMock())
@patch("crewai.cli.authentication.utils.jwt")
class TestValidateToken(unittest.TestCase):
    """Tests for ``validate_jwt_token`` with ``jwt`` and ``PyJWKClient`` mocked.

    The two class-level patches are applied to every ``test_*`` method, which
    therefore receives ``mock_jwt`` (innermost decorator) followed by
    ``mock_pyjwkclient``. Helper methods whose names do not start with
    ``test`` are not wrapped and receive no mocks.

    The ``jwt`` name used inside this test module is the real library; it is
    only used to supply exception classes for ``mock_jwt.decode.side_effect``.

    NOTE(review): every failure test asserts the broad ``Exception``, so any
    error raised on the way through ``validate_jwt_token`` satisfies it.
    Because the module-level ``jwt`` is replaced wholesale by a mock, confirm
    the function fails for the intended reason before narrowing the type.
    """

    # Arguments shared by every call to the function under test.
    _TOKEN = "aaaaa.bbbbbb.cccccc"
    _JWKS_URL = "https://mock_jwks_url"
    _ISSUER = "https://mock_issuer"
    _AUDIENCE = "app_id_xxxx"

    def _validate(self):
        """Call ``validate_jwt_token`` with the shared fixture arguments."""
        return validate_jwt_token(
            jwt_token=self._TOKEN,
            jwks_url=self._JWKS_URL,
            issuer=self._ISSUER,
            audience=self._AUDIENCE,
        )

    def _assert_decode_failure_raises(self, mock_jwt, error):
        """Make ``mock_jwt.decode`` raise ``error`` and expect validation to raise."""
        mock_jwt.decode.side_effect = error
        with self.assertRaises(Exception):
            self._validate()

    def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
        """A decodable token is returned after decode is called as expected."""
        expected_claims = {"exp": 1719859200}
        mock_jwt.decode.return_value = dict(expected_claims)

        # The signing key object only needs to expose a ``.key`` attribute.
        signing_key = MagicMock(key="mock_signing_key")
        jwk_client = mock_pyjwkclient.return_value
        jwk_client.get_signing_key_from_jwt.return_value = signing_key

        decoded_token = self._validate()

        expected_options = {
            "verify_signature": True,
            "verify_exp": True,
            "verify_nbf": True,
            "verify_iat": True,
            "require": ["exp", "iat", "iss", "aud", "sub"],
        }
        mock_jwt.decode.assert_called_once_with(
            "aaaaa.bbbbbb.cccccc",
            "mock_signing_key",
            algorithms=["RS256"],
            audience="app_id_xxxx",
            issuer="https://mock_issuer",
            options=expected_options,
        )
        mock_pyjwkclient.assert_called_once_with("https://mock_jwks_url")
        self.assertEqual(decoded_token, expected_claims)

    def test_validate_jwt_token_expired(self, mock_jwt, mock_pyjwkclient):
        """An expired-signature error from decode makes validation raise."""
        self._assert_decode_failure_raises(mock_jwt, jwt.ExpiredSignatureError)

    def test_validate_jwt_token_invalid_audience(self, mock_jwt, mock_pyjwkclient):
        """An invalid-audience error from decode makes validation raise."""
        self._assert_decode_failure_raises(mock_jwt, jwt.InvalidAudienceError)

    def test_validate_jwt_token_invalid_issuer(self, mock_jwt, mock_pyjwkclient):
        """An invalid-issuer error from decode makes validation raise."""
        self._assert_decode_failure_raises(mock_jwt, jwt.InvalidIssuerError)

    def test_validate_jwt_token_missing_required_claims(
        self, mock_jwt, mock_pyjwkclient
    ):
        """A missing-required-claim error from decode makes validation raise."""
        self._assert_decode_failure_raises(mock_jwt, jwt.MissingRequiredClaimError)

    def test_validate_jwt_token_jwks_error(self, mock_jwt, mock_pyjwkclient):
        """A ``PyJWKClientError`` raised by decode makes validation raise.

        NOTE(review): the error is injected through ``mock_jwt.decode`` rather
        than through the mocked ``PyJWKClient`` - confirm this matches where
        the JWKS failure is expected to originate.
        """
        self._assert_decode_failure_raises(
            mock_jwt, jwt.exceptions.PyJWKClientError
        )

    def test_validate_jwt_token_invalid_token(self, mock_jwt, mock_pyjwkclient):
        """A generic invalid-token error from decode makes validation raise."""
        self._assert_decode_failure_raises(mock_jwt, jwt.InvalidTokenError)
|
|
|
|
|
|
class TestTokenManager(unittest.TestCase):
    """Tests for ``TokenManager``: key lookup, token storage and file helpers."""

    def setUp(self):
        """Create the ``TokenManager`` instance shared by the tests.

        NOTE(review): the instance is built with no patches active; if the
        constructor reads or creates key material, this touches real storage
        on the machine running the tests - confirm and isolate if so.
        """
        self.token_manager = TokenManager()

    def _encrypt_token_record(self, access_token, expires):
        """Return a token record encrypted with the shared manager's Fernet."""
        record = {"access_token": access_token, "expiration": expires.isoformat()}
        return self.token_manager.fernet.encrypt(json.dumps(record).encode())

    @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
    @patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
    @patch("crewai.cli.authentication.utils.TokenManager._get_or_create_key")
    def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read):
        """A new manager exposes the value of ``_get_or_create_key`` as ``key``.

        NOTE(review): ``_get_or_create_key`` itself is patched here, so the
        "existing key" branch of that method is not exercised, and
        ``mock_save`` / ``mock_read`` are never asserted on.
        """
        existing_key = Fernet.generate_key()
        mock_get_or_create.return_value = existing_key

        manager = TokenManager()

        self.assertEqual(manager.key, existing_key)

    @patch("crewai.cli.authentication.utils.Fernet.generate_key")
    @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
    @patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
    def test_get_or_create_key_new(self, mock_save, mock_read, mock_generate):
        """With no stored key, a key is generated, saved and returned."""
        generated_key = b"new_key"
        mock_generate.return_value = generated_key
        mock_read.return_value = None  # nothing stored under "secret.key"

        returned_key = self.token_manager._get_or_create_key()

        self.assertEqual(returned_key, generated_key)
        mock_read.assert_called_once_with("secret.key")
        mock_generate.assert_called_once()
        mock_save.assert_called_once_with("secret.key", generated_key)

    @patch("crewai.cli.authentication.utils.TokenManager.save_secure_file")
    def test_save_tokens(self, mock_save):
        """``save_tokens`` writes an encrypted record to ``tokens.enc``."""
        access_token = "test_token"
        one_hour_from_now = datetime.now() + timedelta(seconds=3600)
        expires_at = int(one_hour_from_now.timestamp())

        self.token_manager.save_tokens(access_token, expires_at)

        mock_save.assert_called_once()
        positional = mock_save.call_args[0]
        self.assertEqual(positional[0], "tokens.enc")

        # Decrypt what was handed to save_secure_file and inspect the record.
        record = json.loads(self.token_manager.fernet.decrypt(positional[1]))
        self.assertEqual(record["access_token"], access_token)
        self.assertEqual(
            datetime.fromisoformat(record["expiration"]),
            datetime.fromtimestamp(expires_at),
        )

    @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
    def test_get_token_valid(self, mock_read):
        """A stored token expiring in the future is returned."""
        access_token = "test_token"
        mock_read.return_value = self._encrypt_token_record(
            access_token, datetime.now() + timedelta(hours=1)
        )

        self.assertEqual(self.token_manager.get_token(), access_token)

    @patch("crewai.cli.authentication.utils.TokenManager.read_secure_file")
    def test_get_token_expired(self, mock_read):
        """A stored token whose expiration is in the past yields ``None``."""
        access_token = "test_token"
        mock_read.return_value = self._encrypt_token_record(
            access_token, datetime.now() - timedelta(hours=1)
        )

        self.assertIsNone(self.token_manager.get_token())

    @patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
    @patch("builtins.open", new_callable=unittest.mock.mock_open)
    @patch("crewai.cli.authentication.utils.os.chmod")
    def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path):
        """``save_secure_file`` writes the bytes and chmods the file to 0o600."""
        storage_dir = MagicMock()
        mock_get_path.return_value = storage_dir
        filename = "test_file.txt"
        content = b"test_content"

        self.token_manager.save_secure_file(filename, content)

        # ``storage_dir / filename`` resolves to the mock's __truediv__ result.
        storage_dir.__truediv__.assert_called_once_with(filename)
        target = storage_dir.__truediv__.return_value
        # The open-call count must be checked before ``mock_open()`` below,
        # because that expression is itself another call on the mock.
        mock_open.assert_called_once_with(target, "wb")
        mock_open().write.assert_called_once_with(content)
        mock_chmod.assert_called_once_with(target, 0o600)

    @patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
    @patch(
        "builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content"
    )
    def test_read_secure_file_exists(self, mock_open, mock_get_path):
        """An existing file is opened in binary mode and its bytes returned."""
        storage_dir = MagicMock()
        mock_get_path.return_value = storage_dir
        storage_dir.__truediv__.return_value.exists.return_value = True
        filename = "test_file.txt"

        file_bytes = self.token_manager.read_secure_file(filename)

        self.assertEqual(file_bytes, b"test_content")
        storage_dir.__truediv__.assert_called_once_with(filename)
        mock_open.assert_called_once_with(storage_dir.__truediv__.return_value, "rb")

    @patch("crewai.cli.authentication.utils.TokenManager.get_secure_storage_path")
    def test_read_secure_file_not_exists(self, mock_get_path):
        """A file whose ``exists()`` is false yields ``None``."""
        storage_dir = MagicMock()
        mock_get_path.return_value = storage_dir
        storage_dir.__truediv__.return_value.exists.return_value = False
        filename = "test_file.txt"

        file_bytes = self.token_manager.read_secure_file(filename)

        self.assertIsNone(file_bytes)
        storage_dir.__truediv__.assert_called_once_with(filename)
|