Files
crewAI/tests/test_oauth2_llm.py
Devin AI 4379ad26d1 Implement comprehensive OAuth2 improvements based on code review feedback
- Add custom OAuth2 error classes (OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError)
- Implement URL and scope validation in OAuth2Config using pydantic field_validator
- Add retry mechanism with exponential backoff (3 attempts, 1s/2s/4s delays) to OAuth2TokenManager
- Implement thread safety using threading.Lock for concurrent token access
- Add provider validation in LLM integration with _validate_oauth2_provider method
- Enhance testing with validation, retry, thread safety, and error class tests
- Improve documentation with comprehensive error handling examples and feature descriptions
- Maintain backward compatibility while significantly improving robustness and security

Co-Authored-By: João <joao@crewai.com>
2025-07-07 18:10:30 +00:00

421 lines
15 KiB
Python

import json
import pytest
import requests
import tempfile
import time
import threading
from pathlib import Path
from unittest.mock import Mock, patch
from crewai.llm import LLM
from crewai.llms.oauth2_config import OAuth2Config, OAuth2ConfigLoader
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
from crewai.llms.oauth2_errors import OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError
class TestOAuth2Config:
    """Tests for ``OAuth2Config`` construction and ``OAuth2ConfigLoader`` file parsing."""

    def test_oauth2_config_creation(self):
        """Fields passed to the constructor are exposed unchanged as attributes."""
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
            scope="read write",
        )
        assert config.client_id == "test_client"
        assert config.provider_name == "test_provider"
        assert config.scope == "read write"

    def test_oauth2_config_loader_missing_file(self, tmp_path):
        """A config path that does not exist yields an empty provider mapping."""
        # Use a path inside the per-test temporary directory so the file is
        # guaranteed to be absent; a bare relative "nonexistent.json" would
        # depend on the current working directory.
        missing_path = tmp_path / "nonexistent.json"
        loader = OAuth2ConfigLoader(str(missing_path))
        configs = loader.load_config()
        assert configs == {}

    def test_oauth2_config_loader_valid_config(self, tmp_path):
        """Each entry under ``oauth2_providers`` becomes an OAuth2Config keyed by its name."""
        config_data = {
            "oauth2_providers": {
                "custom_provider": {
                    "client_id": "test_client",
                    "client_secret": "test_secret",
                    "token_url": "https://example.com/token",
                    "scope": "read",
                }
            }
        }
        # tmp_path is cleaned up by pytest, so no try/finally unlink is needed.
        config_path = tmp_path / "oauth2_config.json"
        config_path.write_text(json.dumps(config_data), encoding="utf-8")

        loader = OAuth2ConfigLoader(str(config_path))
        configs = loader.load_config()

        assert "custom_provider" in configs
        config = configs["custom_provider"]
        assert config.client_id == "test_client"
        # The entry body has no provider_name; it is expected to come from the key.
        assert config.provider_name == "custom_provider"

    def test_oauth2_config_loader_invalid_json(self, tmp_path):
        """Malformed JSON is reported as an OAuth2ConfigurationError."""
        config_path = tmp_path / "oauth2_config.json"
        config_path.write_text("invalid json", encoding="utf-8")

        loader = OAuth2ConfigLoader(str(config_path))
        with pytest.raises(OAuth2ConfigurationError, match="Invalid OAuth2 configuration"):
            loader.load_config()
class TestOAuth2TokenManager:
    """Tests for token acquisition, caching and expiry handling in ``OAuth2TokenManager``."""

    @staticmethod
    def _make_config():
        """Build the provider config shared by every test in this class."""
        return OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
        )

    @staticmethod
    def _bearer_payload(token):
        """Return a token-endpoint JSON body carrying *token* with a one-hour lifetime."""
        return {"access_token": token, "token_type": "Bearer", "expires_in": 3600}

    @staticmethod
    def _make_response(payload):
        """Return a fake HTTP response whose ``json()`` yields *payload* and never raises."""
        fake = Mock()
        fake.json.return_value = payload
        fake.raise_for_status.return_value = None
        return fake

    def test_token_acquisition_success(self):
        """A well-formed response yields its access_token after a single POST."""
        cfg = self._make_config()
        fake = self._make_response(self._bearer_payload("test_token_123"))
        with patch('requests.post', return_value=fake) as post:
            fetched = OAuth2TokenManager().get_access_token(cfg)
            assert fetched == "test_token_123"
            post.assert_called_once()

    def test_token_caching(self):
        """Repeated requests for an unexpired token reuse the first fetch."""
        cfg = self._make_config()
        fake = self._make_response(self._bearer_payload("test_token_123"))
        with patch('requests.post', return_value=fake) as post:
            manager = OAuth2TokenManager()
            # Both rounds must see exactly one POST: the second is served from cache.
            for _ in range(2):
                assert manager.get_access_token(cfg) == "test_token_123"
                assert post.call_count == 1

    def test_token_refresh_on_expiry(self):
        """A cached token whose expires_at is in the past is replaced by a fresh fetch."""
        cfg = self._make_config()
        fake = self._make_response(self._bearer_payload("new_token_456"))
        with patch('requests.post', return_value=fake):
            manager = OAuth2TokenManager()
            # Seed the private cache with a token that expired 100 seconds ago.
            manager._tokens["test_provider"] = {
                "access_token": "old_token",
                "expires_at": time.time() - 100,
            }
            assert manager.get_access_token(cfg) == "new_token_456"

    def test_token_acquisition_failure(self):
        """Persistent network errors surface as OAuth2AuthenticationError."""
        cfg = self._make_config()
        # time.sleep is patched so retry back-off does not slow the test down.
        with patch('requests.post', side_effect=requests.RequestException("Network error")), \
                patch('time.sleep'):
            manager = OAuth2TokenManager()
            with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token"):
                manager.get_access_token(cfg)

    def test_invalid_token_response(self):
        """A response body lacking access_token is rejected as an invalid token response."""
        cfg = self._make_config()
        fake = self._make_response({"invalid": "response"})
        with patch('requests.post', return_value=fake):
            manager = OAuth2TokenManager()
            with pytest.raises(OAuth2AuthenticationError, match="Invalid token response"):
                manager.get_access_token(cfg)
class TestLLMOAuth2Integration:
    """Tests for how ``LLM`` loads OAuth2 provider configs and passes tokens to completions."""

    # Minimal provider config shared by the file-based tests below. The "custom"
    # key pairs with the "custom/test-model" model name those tests use.
    # Tests only read this mapping; it is never mutated.
    _CONFIG_DATA = {
        "oauth2_providers": {
            "custom": {
                "client_id": "test_client",
                "client_secret": "test_secret",
                "token_url": "https://example.com/token",
            }
        }
    }

    @classmethod
    def _write_config(cls, directory):
        """Write ``_CONFIG_DATA`` as JSON inside *directory* and return the file path as str.

        *directory* is a pytest ``tmp_path``, so pytest removes the file afterwards.
        """
        config_path = directory / "oauth2_config.json"
        config_path.write_text(json.dumps(cls._CONFIG_DATA), encoding="utf-8")
        return str(config_path)

    def test_llm_with_oauth2_config(self, tmp_path):
        """Providers from ``oauth2_config_path`` are exposed on ``llm.oauth2_configs``."""
        llm = LLM(
            model="custom/test-model",
            oauth2_config_path=self._write_config(tmp_path),
        )
        assert "custom" in llm.oauth2_configs
        assert llm.oauth2_configs["custom"].client_id == "test_client"

    def test_llm_without_oauth2_config(self):
        """Without a config path no OAuth2 providers are loaded."""
        llm = LLM(model="openai/gpt-3.5-turbo")
        assert llm.oauth2_configs == {}

    @patch('crewai.llm.litellm.completion')
    def test_llm_oauth2_token_injection(self, mock_completion, tmp_path):
        """The OAuth2 access token is passed to litellm as ``api_key``."""
        # Mock-injected arguments must precede pytest fixtures in the signature.
        mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))])
        config_path = self._write_config(tmp_path)

        with patch.object(OAuth2TokenManager, 'get_access_token', return_value="oauth_token_123"):
            llm = LLM(
                model="custom/test-model",
                oauth2_config_path=config_path,
            )
            llm.call("test message")

        assert mock_completion.call_args.kwargs['api_key'] == "oauth_token_123"

    @patch('crewai.llm.litellm.completion')
    def test_llm_oauth2_authentication_failure(self, mock_completion, tmp_path):
        """A token acquisition failure propagates out of ``llm.call``."""
        config_path = self._write_config(tmp_path)

        with patch.object(
            OAuth2TokenManager,
            'get_access_token',
            side_effect=OAuth2AuthenticationError("Auth failed"),
        ):
            llm = LLM(
                model="custom/test-model",
                oauth2_config_path=config_path,
            )
            with pytest.raises(OAuth2AuthenticationError, match="Auth failed"):
                llm.call("test message")

        # The token is sent as a completion argument, so without one the
        # completion backend should never be reached.
        mock_completion.assert_not_called()

    @patch('crewai.llm.litellm.completion')
    def test_llm_non_oauth2_provider_unchanged(self, mock_completion):
        """Providers without an OAuth2 config keep their explicitly supplied api_key."""
        mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))])
        llm = LLM(
            model="openai/gpt-3.5-turbo",
            api_key="original_key",
        )
        llm.call("test message")
        assert mock_completion.call_args.kwargs['api_key'] == "original_key"
class TestOAuth2Validation:
    """Tests for the token_url and scope validators on ``OAuth2Config``."""

    def test_valid_url_validation(self):
        """An absolute https URL is accepted and stored unchanged."""
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
        )
        assert config.token_url == "https://example.com/token"

    def test_invalid_url_validation(self):
        """A string with no scheme or host is rejected as an invalid token URL."""
        with pytest.raises(OAuth2ValidationError, match="Invalid token URL format"):
            OAuth2Config(
                client_id="test_client",
                client_secret="test_secret",
                token_url="not-a-url",
                provider_name="test_provider",
            )

    def test_valid_scope_validation(self):
        """Space-separated scope names are accepted and stored unchanged."""
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
            scope="read write",
        )
        assert config.scope == "read write"

    def test_invalid_scope_validation(self):
        """A malformed scope string is rejected with OAuth2ValidationError."""
        # The scope here must differ from the "read write" that
        # test_valid_scope_validation expects to be accepted; with identical
        # inputs the two tests contradict each other. A doubled space yields an
        # empty scope token when the string is split on single spaces.
        # NOTE(review): assumes OAuth2Config's scope validator rejects empty
        # scope tokens — confirm against the validator implementation.
        with pytest.raises(OAuth2ValidationError, match="Invalid scope format"):
            OAuth2Config(
                client_id="test_client",
                client_secret="test_secret",
                token_url="https://example.com/token",
                provider_name="test_provider",
                scope="read  write",
            )
class TestOAuth2RetryMechanism:
    """Tests for the retry-with-back-off behaviour of ``OAuth2TokenManager``."""

    def test_retry_mechanism_success_on_second_attempt(self):
        """A transient network error is retried once, after a pause, and then succeeds."""
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
        )
        # First POST raises, second POST returns a valid token response.
        responses = [
            requests.RequestException("Network error"),
            Mock(
                json=lambda: {"access_token": "token_123", "expires_in": 3600},
                raise_for_status=lambda: None,
            ),
        ]
        # time.sleep is patched so the back-off delay does not slow the test down.
        with patch('requests.post', side_effect=responses) as mock_post, \
                patch('time.sleep') as mock_sleep:
            manager = OAuth2TokenManager()
            token = manager.get_access_token(config)

        assert token == "token_123"
        # Exactly one failed attempt plus one successful attempt, with a single
        # back-off pause between them.
        assert mock_post.call_count == 2
        mock_sleep.assert_called_once()

    def test_retry_mechanism_exhausted(self):
        """Persistent failures give up after three attempts with a descriptive error."""
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
        )
        with patch('requests.post', side_effect=requests.RequestException("Network error")) as mock_post, \
                patch('time.sleep'):
            manager = OAuth2TokenManager()
            with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token.*after 3 attempts"):
                manager.get_access_token(config)

        # The error message claims three attempts; check that exactly three POSTs happened.
        assert mock_post.call_count == 3
class TestOAuth2ThreadSafety:
    """Tests that concurrent callers of ``OAuth2TokenManager`` share a single token fetch."""

    def test_concurrent_token_requests(self):
        """Five threads requesting the same token trigger exactly one POST and all get it."""
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
        )
        mock_response = Mock()
        mock_response.json.return_value = {"access_token": "token_123", "expires_in": 3600}
        mock_response.raise_for_status.return_value = None
        call_count = 0

        def mock_post(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            # Keep the fake network call in flight long enough for the other
            # threads to arrive while the first fetch is still running.
            time.sleep(0.1)
            return mock_response

        with patch('requests.post', side_effect=mock_post):
            manager = OAuth2TokenManager()
            results = []
            errors = []

            def get_token():
                try:
                    results.append(manager.get_access_token(config))
                except Exception as exc:
                    # Exceptions in worker threads are not propagated to the
                    # main thread; record them so the assertions below see them.
                    errors.append(exc)

            threads = [threading.Thread(target=get_token) for _ in range(5)]
            for thread in threads:
                thread.start()
            for thread in threads:
                # Bounded wait: a deadlock in the manager fails the test
                # instead of hanging the whole suite.
                thread.join(timeout=10)
            assert not any(thread.is_alive() for thread in threads), (
                "token requests did not finish (possible deadlock)"
            )

        assert not errors, errors
        # Compare against the full expected list: all() over a short or empty
        # results list would pass vacuously if some threads failed.
        assert results == ["token_123"] * 5
        assert call_count == 1
class TestOAuth2ErrorClasses:
    """Tests for the messages, attributes and base class of the OAuth2 exception types."""

    def test_oauth2_configuration_error(self):
        """OAuth2ConfigurationError keeps its message and derives from OAuth2Error."""
        text = "Config error"
        err = OAuth2ConfigurationError(text)
        assert isinstance(err, OAuth2Error)
        assert str(err) == text

    def test_oauth2_authentication_error_with_original(self):
        """OAuth2AuthenticationError keeps its message and the wrapped original_error."""
        cause = requests.RequestException("Network error")
        err = OAuth2AuthenticationError("Auth failed", original_error=cause)
        assert err.original_error == cause
        assert str(err) == "Auth failed"

    def test_oauth2_validation_error(self):
        """OAuth2ValidationError keeps its message and derives from OAuth2Error."""
        text = "Validation failed"
        err = OAuth2ValidationError(text)
        assert isinstance(err, OAuth2Error)
        assert str(err) == text