Implement comprehensive OAuth2 improvements based on code review feedback

- Add custom OAuth2 error classes (OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError)
- Implement URL and scope validation in OAuth2Config using pydantic field_validator
- Add retry mechanism with exponential backoff (3 attempts, 1s/2s/4s delays) to OAuth2TokenManager
- Implement thread safety using threading.Lock for concurrent token access
- Add provider validation in LLM integration with _validate_oauth2_provider method
- Enhance testing with validation, retry, thread safety, and error class tests
- Improve documentation with comprehensive error handling examples and feature descriptions
- Maintain backward compatibility while significantly improving robustness and security

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-07-07 18:10:30 +00:00
parent 03ee4c59eb
commit 4379ad26d1
7 changed files with 322 additions and 22 deletions

View File

@@ -3,11 +3,13 @@ import pytest
import requests
import tempfile
import time
import threading
from pathlib import Path
from unittest.mock import Mock, patch
from crewai.llm import LLM
from crewai.llms.oauth2_config import OAuth2Config, OAuth2ConfigLoader
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
from crewai.llms.oauth2_errors import OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError
class TestOAuth2Config:
@@ -62,7 +64,7 @@ class TestOAuth2Config:
try:
loader = OAuth2ConfigLoader(config_path)
with pytest.raises(ValueError, match="Invalid OAuth2 configuration"):
with pytest.raises(OAuth2ConfigurationError, match="Invalid OAuth2 configuration"):
loader.load_config()
finally:
Path(config_path).unlink()
@@ -155,10 +157,11 @@ class TestOAuth2TokenManager:
)
with patch('requests.post', side_effect=requests.RequestException("Network error")):
manager = OAuth2TokenManager()
with pytest.raises(RuntimeError, match="Failed to acquire OAuth2 token"):
manager.get_access_token(config)
with patch('time.sleep'):
manager = OAuth2TokenManager()
with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token"):
manager.get_access_token(config)
def test_invalid_token_response(self):
config = OAuth2Config(
@@ -175,7 +178,7 @@ class TestOAuth2TokenManager:
with patch('requests.post', return_value=mock_response):
manager = OAuth2TokenManager()
with pytest.raises(RuntimeError, match="Invalid token response"):
with pytest.raises(OAuth2AuthenticationError, match="Invalid token response"):
manager.get_access_token(config)
@@ -259,13 +262,13 @@ class TestLLMOAuth2Integration:
config_path = f.name
try:
with patch.object(OAuth2TokenManager, 'get_access_token', side_effect=RuntimeError("Auth failed")):
with patch.object(OAuth2TokenManager, 'get_access_token', side_effect=OAuth2AuthenticationError("Auth failed")):
llm = LLM(
model="custom/test-model",
oauth2_config_path=config_path
)
with pytest.raises(RuntimeError, match="Auth failed"):
with pytest.raises(OAuth2AuthenticationError, match="Auth failed"):
llm.call("test message")
finally:
Path(config_path).unlink()
@@ -283,3 +286,135 @@ class TestLLMOAuth2Integration:
call_args = mock_completion.call_args
assert call_args[1]['api_key'] == "original_key"
class TestOAuth2Validation:
def test_valid_url_validation(self):
config = OAuth2Config(
client_id="test_client",
client_secret="test_secret",
token_url="https://example.com/token",
provider_name="test_provider"
)
assert config.token_url == "https://example.com/token"
def test_invalid_url_validation(self):
with pytest.raises(OAuth2ValidationError, match="Invalid token URL format"):
OAuth2Config(
client_id="test_client",
client_secret="test_secret",
token_url="not-a-url",
provider_name="test_provider"
)
def test_valid_scope_validation(self):
config = OAuth2Config(
client_id="test_client",
client_secret="test_secret",
token_url="https://example.com/token",
provider_name="test_provider",
scope="read write"
)
assert config.scope == "read write"
def test_invalid_scope_validation(self):
with pytest.raises(OAuth2ValidationError, match="Invalid scope format"):
OAuth2Config(
client_id="test_client",
client_secret="test_secret",
token_url="https://example.com/token",
provider_name="test_provider",
scope="read write"
)
class TestOAuth2RetryMechanism:
def test_retry_mechanism_success_on_second_attempt(self):
config = OAuth2Config(
client_id="test_client",
client_secret="test_secret",
token_url="https://example.com/token",
provider_name="test_provider"
)
responses = [
requests.RequestException("Network error"),
Mock(json=lambda: {"access_token": "token_123", "expires_in": 3600}, raise_for_status=lambda: None)
]
with patch('requests.post', side_effect=responses):
with patch('time.sleep'):
manager = OAuth2TokenManager()
token = manager.get_access_token(config)
assert token == "token_123"
def test_retry_mechanism_exhausted(self):
config = OAuth2Config(
client_id="test_client",
client_secret="test_secret",
token_url="https://example.com/token",
provider_name="test_provider"
)
with patch('requests.post', side_effect=requests.RequestException("Network error")):
with patch('time.sleep'):
manager = OAuth2TokenManager()
with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token.*after 3 attempts"):
manager.get_access_token(config)
class TestOAuth2ThreadSafety:
def test_concurrent_token_requests(self):
config = OAuth2Config(
client_id="test_client",
client_secret="test_secret",
token_url="https://example.com/token",
provider_name="test_provider"
)
mock_response = Mock()
mock_response.json.return_value = {"access_token": "token_123", "expires_in": 3600}
mock_response.raise_for_status.return_value = None
call_count = 0
def mock_post(*args, **kwargs):
nonlocal call_count
call_count += 1
time.sleep(0.1)
return mock_response
with patch('requests.post', side_effect=mock_post):
manager = OAuth2TokenManager()
results = []
def get_token():
token = manager.get_access_token(config)
results.append(token)
threads = [threading.Thread(target=get_token) for _ in range(5)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
assert all(token == "token_123" for token in results)
assert call_count == 1
class TestOAuth2ErrorClasses:
def test_oauth2_configuration_error(self):
error = OAuth2ConfigurationError("Config error")
assert str(error) == "Config error"
assert isinstance(error, OAuth2Error)
def test_oauth2_authentication_error_with_original(self):
original = requests.RequestException("Network error")
error = OAuth2AuthenticationError("Auth failed", original_error=original)
assert str(error) == "Auth failed"
assert error.original_error == original
def test_oauth2_validation_error(self):
error = OAuth2ValidationError("Validation failed")
assert str(error) == "Validation failed"
assert isinstance(error, OAuth2Error)