fix: ensure token store file ops do not deadlock

* fix: ensure token store file ops do not deadlock
* chore: update test method reference
This commit is contained in:
Greyson LaLonde
2025-12-08 19:04:21 -05:00
committed by GitHub
parent 6125b866fd
commit beef712646
3 changed files with 331 additions and 215 deletions

View File

@@ -3,103 +3,56 @@ import json
import os
from pathlib import Path
import sys
from typing import BinaryIO, cast
import tempfile
from typing import Final, Literal, cast
from cryptography.fernet import Fernet
if sys.platform == "win32":
import msvcrt
else:
import fcntl
_FERNET_KEY_LENGTH: Final[Literal[44]] = 44
class TokenManager:
def __init__(self, file_path: str = "tokens.enc") -> None:
"""
Initialize the TokenManager class.
"""Manages encrypted token storage."""
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
def __init__(self, file_path: str = "tokens.enc") -> None:
"""Initialize the TokenManager.
Args:
file_path: The file path to store encrypted tokens.
"""
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
@staticmethod
def _acquire_lock(file_handle: BinaryIO) -> None:
"""
Acquire an exclusive lock on a file handle.
Args:
file_handle: Open file handle to lock.
"""
if sys.platform == "win32":
msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
else:
fcntl.flock(file_handle.fileno(), fcntl.LOCK_EX)
@staticmethod
def _release_lock(file_handle: BinaryIO) -> None:
"""
Release the lock on a file handle.
Args:
file_handle: Open file handle to unlock.
"""
if sys.platform == "win32":
msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)
else:
fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)
def _get_or_create_key(self) -> bytes:
"""
Get or create the encryption key with file locking to prevent race conditions.
"""Get or create the encryption key.
Returns:
The encryption key.
The encryption key as bytes.
"""
key_filename = "secret.key"
storage_path = self.get_secure_storage_path()
key_filename: str = "secret.key"
key = self.read_secure_file(key_filename)
if key is not None and len(key) == 44:
key = self._read_secure_file(key_filename)
if key is not None and len(key) == _FERNET_KEY_LENGTH:
return key
lock_file_path = storage_path / f"{key_filename}.lock"
try:
lock_file_path.touch()
with open(lock_file_path, "r+b") as lock_file:
self._acquire_lock(lock_file)
try:
key = self.read_secure_file(key_filename)
if key is not None and len(key) == 44:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
return new_key
finally:
try:
self._release_lock(lock_file)
except OSError:
pass
except OSError:
key = self.read_secure_file(key_filename)
if key is not None and len(key) == 44:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
new_key = Fernet.generate_key()
if self._atomic_create_secure_file(key_filename, new_key):
return new_key
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""
Save the access token and its expiration time.
key = self._read_secure_file(key_filename)
if key is not None and len(key) == _FERNET_KEY_LENGTH:
return key
:param access_token: The access token to save.
:param expires_at: The UNIX timestamp of the expiration time.
raise RuntimeError("Failed to create or read encryption key")
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""Save the access token and its expiration time.
Args:
access_token: The access token to save.
expires_at: The UNIX timestamp of the expiration time.
"""
expiration_time = datetime.fromtimestamp(expires_at)
data = {
@@ -107,15 +60,15 @@ class TokenManager:
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
self._atomic_write_secure_file(self.file_path, encrypted_data)
def get_token(self) -> str | None:
"""
Get the access token if it is valid and not expired.
"""Get the access token if it is valid and not expired.
:return: The access token if valid and not expired, otherwise None.
Returns:
The access token if valid and not expired, otherwise None.
"""
encrypted_data = self.read_secure_file(self.file_path)
encrypted_data = self._read_secure_file(self.file_path)
if encrypted_data is None:
return None
@@ -126,20 +79,18 @@ class TokenManager:
if expiration <= datetime.now():
return None
return cast(str | None, data["access_token"])
return cast(str | None, data.get("access_token"))
def clear_tokens(self) -> None:
"""
Clear the tokens.
"""
self.delete_secure_file(self.file_path)
"""Clear the stored tokens."""
self._delete_secure_file(self.file_path)
@staticmethod
def get_secure_storage_path() -> Path:
"""
Get the secure storage path based on the operating system.
def _get_secure_storage_path() -> Path:
"""Get the secure storage path based on the operating system.
:return: The secure storage path.
Returns:
The secure storage path.
"""
if sys.platform == "win32":
base_path = os.environ.get("LOCALAPPDATA")
@@ -155,44 +106,81 @@ class TokenManager:
return storage_path
def save_secure_file(self, filename: str, content: bytes) -> None:
"""
Save the content to a secure file.
def _atomic_create_secure_file(self, filename: str, content: bytes) -> bool:
"""Create a file only if it doesn't exist.
:param filename: The name of the file.
:param content: The content to save.
Args:
filename: The name of the file.
content: The content to write.
Returns:
True if file was created, False if it already exists.
"""
storage_path = self.get_secure_storage_path()
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
with open(file_path, "wb") as f:
f.write(content)
try:
fd = os.open(file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
try:
os.write(fd, content)
finally:
os.close(fd)
return True
except FileExistsError:
return False
os.chmod(file_path, 0o600)
def _atomic_write_secure_file(self, filename: str, content: bytes) -> None:
"""Write content to a secure file.
def read_secure_file(self, filename: str) -> bytes | None:
Args:
filename: The name of the file.
content: The content to write.
"""
Read the content of a secure file.
:param filename: The name of the file.
:return: The content of the file if it exists, otherwise None.
"""
storage_path = self.get_secure_storage_path()
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
if not file_path.exists():
fd, temp_path = tempfile.mkstemp(dir=storage_path, prefix=f".{filename}.")
fd_closed = False
try:
os.write(fd, content)
os.close(fd)
fd_closed = True
os.chmod(temp_path, 0o600)
os.replace(temp_path, file_path)
except Exception:
if not fd_closed:
os.close(fd)
if os.path.exists(temp_path):
os.unlink(temp_path)
raise
def _read_secure_file(self, filename: str) -> bytes | None:
"""Read the content of a secure file.
Args:
filename: The name of the file.
Returns:
The content of the file if it exists, otherwise None.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
try:
with open(file_path, "rb") as f:
return f.read()
except FileNotFoundError:
return None
with open(file_path, "rb") as f:
return f.read()
def _delete_secure_file(self, filename: str) -> None:
"""Delete a secure file.
def delete_secure_file(self, filename: str) -> None:
Args:
filename: The name of the file.
"""
Delete the secure file.
:param filename: The name of the file.
"""
storage_path = self.get_secure_storage_path()
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
if file_path.exists():
file_path.unlink(missing_ok=True)
try:
file_path.unlink()
except FileNotFoundError:
pass

View File

@@ -1,7 +1,12 @@
"""Tests for TokenManager with atomic file operations."""
import json
import os
import tempfile
import unittest
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
from pathlib import Path
from unittest.mock import patch
from cryptography.fernet import Fernet
@@ -9,15 +14,22 @@ from crewai.cli.shared.token_manager import TokenManager
class TestTokenManager(unittest.TestCase):
"""Test cases for TokenManager."""
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
def setUp(self, mock_get_key):
def setUp(self, mock_get_key: unittest.mock.MagicMock) -> None:
"""Set up test fixtures."""
mock_get_key.return_value = Fernet.generate_key()
self.token_manager = TokenManager()
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
@patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
@patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file")
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
def test_get_or_create_key_existing(self, mock_get_or_create, mock_save, mock_read):
def test_get_or_create_key_existing(
self,
mock_get_or_create: unittest.mock.MagicMock,
mock_read: unittest.mock.MagicMock,
) -> None:
"""Test that existing key is returned when present."""
mock_key = Fernet.generate_key()
mock_get_or_create.return_value = mock_key
@@ -26,40 +38,49 @@ class TestTokenManager(unittest.TestCase):
self.assertEqual(result, mock_key)
@patch("crewai.cli.shared.token_manager.Fernet.generate_key")
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
@patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
@patch("crewai.cli.shared.token_manager.TokenManager._acquire_lock")
@patch("crewai.cli.shared.token_manager.TokenManager._release_lock")
@patch("builtins.open", new_callable=unittest.mock.mock_open)
def test_get_or_create_key_new(
self, mock_open, mock_release_lock, mock_acquire_lock, mock_save, mock_read, mock_generate
):
mock_key = b"new_key"
mock_read.return_value = None
mock_generate.return_value = mock_key
def test_get_or_create_key_new(self) -> None:
"""Test that new key is created when none exists."""
mock_key = Fernet.generate_key()
result = self.token_manager._get_or_create_key()
with (
patch.object(self.token_manager, "_read_secure_file", return_value=None) as mock_read,
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=True) as mock_atomic_create,
patch("crewai.cli.shared.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate,
):
result = self.token_manager._get_or_create_key()
self.assertEqual(result, mock_key)
# read_secure_file is called twice: once for fast path, once inside lock
self.assertEqual(mock_read.call_count, 2)
mock_read.assert_called_with("secret.key")
mock_generate.assert_called_once()
mock_save.assert_called_once_with("secret.key", mock_key)
# Verify lock was acquired and released
mock_acquire_lock.assert_called_once()
mock_release_lock.assert_called_once()
self.assertEqual(result, mock_key)
mock_read.assert_called_with("secret.key")
mock_generate.assert_called_once()
mock_atomic_create.assert_called_once_with("secret.key", mock_key)
@patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
def test_save_tokens(self, mock_save):
def test_get_or_create_key_race_condition(self) -> None:
"""Test that another process's key is used when atomic create fails."""
our_key = Fernet.generate_key()
their_key = Fernet.generate_key()
with (
patch.object(self.token_manager, "_read_secure_file", side_effect=[None, their_key]) as mock_read,
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=False) as mock_atomic_create,
patch("crewai.cli.shared.token_manager.Fernet.generate_key", return_value=our_key),
):
result = self.token_manager._get_or_create_key()
self.assertEqual(result, their_key)
self.assertEqual(mock_read.call_count, 2)
@patch("crewai.cli.shared.token_manager.TokenManager._atomic_write_secure_file")
def test_save_tokens(
self, mock_write: unittest.mock.MagicMock
) -> None:
"""Test saving tokens encrypts and writes atomically."""
access_token = "test_token"
expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())
self.token_manager.save_tokens(access_token, expires_at)
mock_save.assert_called_once()
args = mock_save.call_args[0]
mock_write.assert_called_once()
args = mock_write.call_args[0]
self.assertEqual(args[0], "tokens.enc")
decrypted_data = self.token_manager.fernet.decrypt(args[1])
data = json.loads(decrypted_data)
@@ -67,8 +88,11 @@ class TestTokenManager(unittest.TestCase):
expiration = datetime.fromisoformat(data["expiration"])
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
def test_get_token_valid(self, mock_read):
@patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file")
def test_get_token_valid(
self, mock_read: unittest.mock.MagicMock
) -> None:
"""Test getting a valid non-expired token."""
access_token = "test_token"
expiration = (datetime.now() + timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
@@ -79,8 +103,11 @@ class TestTokenManager(unittest.TestCase):
self.assertEqual(result, access_token)
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
def test_get_token_expired(self, mock_read):
@patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file")
def test_get_token_expired(
self, mock_read: unittest.mock.MagicMock
) -> None:
"""Test that expired token returns None."""
access_token = "test_token"
expiration = (datetime.now() - timedelta(hours=1)).isoformat()
data = {"access_token": access_token, "expiration": expiration}
@@ -91,76 +118,177 @@ class TestTokenManager(unittest.TestCase):
self.assertIsNone(result)
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
@patch("builtins.open", new_callable=unittest.mock.mock_open)
@patch("crewai.cli.shared.token_manager.os.chmod")
def test_save_secure_file(self, mock_chmod, mock_open, mock_get_path):
mock_path = MagicMock()
mock_get_path.return_value = mock_path
filename = "test_file.txt"
content = b"test_content"
@patch("crewai.cli.shared.token_manager.TokenManager._read_secure_file")
def test_get_token_not_found(
self, mock_read: unittest.mock.MagicMock
) -> None:
"""Test that missing token file returns None."""
mock_read.return_value = None
self.token_manager.save_secure_file(filename, content)
mock_path.__truediv__.assert_called_once_with(filename)
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "wb")
mock_open().write.assert_called_once_with(content)
mock_chmod.assert_called_once_with(mock_path.__truediv__.return_value, 0o600)
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
@patch(
"builtins.open", new_callable=unittest.mock.mock_open, read_data=b"test_content"
)
def test_read_secure_file_exists(self, mock_open, mock_get_path):
mock_path = MagicMock()
mock_get_path.return_value = mock_path
mock_path.__truediv__.return_value.exists.return_value = True
filename = "test_file.txt"
result = self.token_manager.read_secure_file(filename)
self.assertEqual(result, b"test_content")
mock_path.__truediv__.assert_called_once_with(filename)
mock_open.assert_called_once_with(mock_path.__truediv__.return_value, "rb")
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
def test_read_secure_file_not_exists(self, mock_get_path):
mock_path = MagicMock()
mock_get_path.return_value = mock_path
mock_path.__truediv__.return_value.exists.return_value = False
filename = "test_file.txt"
result = self.token_manager.read_secure_file(filename)
result = self.token_manager.get_token()
self.assertIsNone(result)
mock_path.__truediv__.assert_called_once_with(filename)
@patch("crewai.cli.shared.token_manager.TokenManager.get_secure_storage_path")
def test_clear_tokens(self, mock_get_path):
mock_path = MagicMock()
mock_get_path.return_value = mock_path
@patch("crewai.cli.shared.token_manager.TokenManager._delete_secure_file")
def test_clear_tokens(
self, mock_delete: unittest.mock.MagicMock
) -> None:
"""Test clearing tokens deletes the token file."""
self.token_manager.clear_tokens()
mock_path.__truediv__.assert_called_once_with("tokens.enc")
mock_path.__truediv__.return_value.unlink.assert_called_once_with(
missing_ok=True
)
mock_delete.assert_called_once_with("tokens.enc")
@patch("crewai.cli.shared.token_manager.Fernet.generate_key")
@patch("crewai.cli.shared.token_manager.TokenManager.read_secure_file")
@patch("crewai.cli.shared.token_manager.TokenManager.save_secure_file")
@patch("builtins.open", side_effect=OSError(9, "Bad file descriptor"))
def test_get_or_create_key_oserror_fallback(
self, mock_open, mock_save, mock_read, mock_generate
):
"""Test that OSError during file locking falls back to lock-free creation."""
mock_key = Fernet.generate_key()
mock_read.return_value = None
mock_generate.return_value = mock_key
result = self.token_manager._get_or_create_key()
class TestAtomicFileOperations(unittest.TestCase):
"""Test atomic file operations directly."""
self.assertEqual(result, mock_key)
self.assertGreaterEqual(mock_generate.call_count, 1)
self.assertGreaterEqual(mock_save.call_count, 1)
def setUp(self) -> None:
"""Set up test fixtures with temp directory."""
self.temp_dir = tempfile.mkdtemp()
self.original_get_path = TokenManager._get_secure_storage_path
# Patch to use temp directory
def mock_get_path() -> Path:
return Path(self.temp_dir)
TokenManager._get_secure_storage_path = staticmethod(mock_get_path)
def tearDown(self) -> None:
"""Clean up temp directory."""
TokenManager._get_secure_storage_path = staticmethod(self.original_get_path)
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_create_new_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic create succeeds for new file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
result = tm._atomic_create_secure_file("test.txt", b"content")
self.assertTrue(result)
file_path = Path(self.temp_dir) / "test.txt"
self.assertTrue(file_path.exists())
self.assertEqual(file_path.read_bytes(), b"content")
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_create_existing_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic create fails for existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
# Create file first
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"original")
result = tm._atomic_create_secure_file("test.txt", b"new content")
self.assertFalse(result)
self.assertEqual(file_path.read_bytes(), b"original")
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_new_file(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic write creates new file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
tm._atomic_write_secure_file("test.txt", b"content")
file_path = Path(self.temp_dir) / "test.txt"
self.assertTrue(file_path.exists())
self.assertEqual(file_path.read_bytes(), b"content")
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_overwrites(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test atomic write overwrites existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"original")
tm._atomic_write_secure_file("test.txt", b"new content")
self.assertEqual(file_path.read_bytes(), b"new content")
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
def test_atomic_write_no_temp_file_on_success(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test that temp file is cleaned up after successful write."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
tm._atomic_write_secure_file("test.txt", b"content")
# Check no temp files remain
temp_files = list(Path(self.temp_dir).glob(".test.txt.*"))
self.assertEqual(len(temp_files), 0)
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
def test_read_secure_file_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test reading existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"content")
result = tm._read_secure_file("test.txt")
self.assertEqual(result, b"content")
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
def test_read_secure_file_not_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test reading non-existent file returns None."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
result = tm._read_secure_file("nonexistent.txt")
self.assertIsNone(result)
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
def test_delete_secure_file_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test deleting existing file."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
file_path = Path(self.temp_dir) / "test.txt"
file_path.write_bytes(b"content")
tm._delete_secure_file("test.txt")
self.assertFalse(file_path.exists())
@patch("crewai.cli.shared.token_manager.TokenManager._get_or_create_key")
def test_delete_secure_file_not_exists(
self, mock_get_key: unittest.mock.MagicMock
) -> None:
"""Test deleting non-existent file doesn't raise."""
mock_get_key.return_value = Fernet.generate_key()
tm = TokenManager()
# Should not raise
tm._delete_secure_file("nonexistent.txt")
if __name__ == "__main__":
unittest.main()

View File

@@ -31,7 +31,7 @@ def tool_command():
with tempfile.TemporaryDirectory() as temp_dir:
# Mock the secure storage path to use the temp directory
with patch.object(
TokenManager, "get_secure_storage_path", return_value=Path(temp_dir)
TokenManager, "_get_secure_storage_path", return_value=Path(temp_dir)
):
TokenManager().save_tokens(
"test-token", (datetime.now() + timedelta(seconds=36000)).timestamp()