mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Implement comprehensive OAuth2 improvements based on code review feedback
- Add custom OAuth2 error classes (OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError) - Implement URL and scope validation in OAuth2Config using pydantic field_validator - Add retry mechanism with exponential backoff (3 attempts, 1s/2s/4s delays) to OAuth2TokenManager - Implement thread safety using threading.Lock for concurrent token access - Add provider validation in LLM integration with _validate_oauth2_provider method - Enhance testing with validation, retry, thread safety, and error class tests - Improve documentation with comprehensive error handling examples and feature descriptions - Maintain backward compatibility while significantly improving robustness and security Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -3,11 +3,13 @@ import pytest
|
||||
import requests
|
||||
import tempfile
|
||||
import time
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.oauth2_config import OAuth2Config, OAuth2ConfigLoader
|
||||
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
|
||||
from crewai.llms.oauth2_errors import OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError
|
||||
|
||||
|
||||
class TestOAuth2Config:
|
||||
@@ -62,7 +64,7 @@ class TestOAuth2Config:
|
||||
|
||||
try:
|
||||
loader = OAuth2ConfigLoader(config_path)
|
||||
with pytest.raises(ValueError, match="Invalid OAuth2 configuration"):
|
||||
with pytest.raises(OAuth2ConfigurationError, match="Invalid OAuth2 configuration"):
|
||||
loader.load_config()
|
||||
finally:
|
||||
Path(config_path).unlink()
|
||||
@@ -155,10 +157,11 @@ class TestOAuth2TokenManager:
|
||||
)
|
||||
|
||||
with patch('requests.post', side_effect=requests.RequestException("Network error")):
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to acquire OAuth2 token"):
|
||||
manager.get_access_token(config)
|
||||
with patch('time.sleep'):
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token"):
|
||||
manager.get_access_token(config)
|
||||
|
||||
def test_invalid_token_response(self):
|
||||
config = OAuth2Config(
|
||||
@@ -175,7 +178,7 @@ class TestOAuth2TokenManager:
|
||||
with patch('requests.post', return_value=mock_response):
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid token response"):
|
||||
with pytest.raises(OAuth2AuthenticationError, match="Invalid token response"):
|
||||
manager.get_access_token(config)
|
||||
|
||||
|
||||
@@ -259,13 +262,13 @@ class TestLLMOAuth2Integration:
|
||||
config_path = f.name
|
||||
|
||||
try:
|
||||
with patch.object(OAuth2TokenManager, 'get_access_token', side_effect=RuntimeError("Auth failed")):
|
||||
with patch.object(OAuth2TokenManager, 'get_access_token', side_effect=OAuth2AuthenticationError("Auth failed")):
|
||||
llm = LLM(
|
||||
model="custom/test-model",
|
||||
oauth2_config_path=config_path
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Auth failed"):
|
||||
with pytest.raises(OAuth2AuthenticationError, match="Auth failed"):
|
||||
llm.call("test message")
|
||||
finally:
|
||||
Path(config_path).unlink()
|
||||
@@ -283,3 +286,135 @@ class TestLLMOAuth2Integration:
|
||||
|
||||
call_args = mock_completion.call_args
|
||||
assert call_args[1]['api_key'] == "original_key"
|
||||
|
||||
|
||||
class TestOAuth2Validation:
|
||||
def test_valid_url_validation(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
assert config.token_url == "https://example.com/token"
|
||||
|
||||
def test_invalid_url_validation(self):
|
||||
with pytest.raises(OAuth2ValidationError, match="Invalid token URL format"):
|
||||
OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="not-a-url",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
|
||||
def test_valid_scope_validation(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider",
|
||||
scope="read write"
|
||||
)
|
||||
assert config.scope == "read write"
|
||||
|
||||
def test_invalid_scope_validation(self):
|
||||
with pytest.raises(OAuth2ValidationError, match="Invalid scope format"):
|
||||
OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider",
|
||||
scope="read write"
|
||||
)
|
||||
|
||||
|
||||
class TestOAuth2RetryMechanism:
|
||||
def test_retry_mechanism_success_on_second_attempt(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
|
||||
responses = [
|
||||
requests.RequestException("Network error"),
|
||||
Mock(json=lambda: {"access_token": "token_123", "expires_in": 3600}, raise_for_status=lambda: None)
|
||||
]
|
||||
|
||||
with patch('requests.post', side_effect=responses):
|
||||
with patch('time.sleep'):
|
||||
manager = OAuth2TokenManager()
|
||||
token = manager.get_access_token(config)
|
||||
assert token == "token_123"
|
||||
|
||||
def test_retry_mechanism_exhausted(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
|
||||
with patch('requests.post', side_effect=requests.RequestException("Network error")):
|
||||
with patch('time.sleep'):
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token.*after 3 attempts"):
|
||||
manager.get_access_token(config)
|
||||
|
||||
|
||||
class TestOAuth2ThreadSafety:
|
||||
def test_concurrent_token_requests(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"access_token": "token_123", "expires_in": 3600}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
call_count = 0
|
||||
def mock_post(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
time.sleep(0.1)
|
||||
return mock_response
|
||||
|
||||
with patch('requests.post', side_effect=mock_post):
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
results = []
|
||||
def get_token():
|
||||
token = manager.get_access_token(config)
|
||||
results.append(token)
|
||||
|
||||
threads = [threading.Thread(target=get_token) for _ in range(5)]
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert all(token == "token_123" for token in results)
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
class TestOAuth2ErrorClasses:
|
||||
def test_oauth2_configuration_error(self):
|
||||
error = OAuth2ConfigurationError("Config error")
|
||||
assert str(error) == "Config error"
|
||||
assert isinstance(error, OAuth2Error)
|
||||
|
||||
def test_oauth2_authentication_error_with_original(self):
|
||||
original = requests.RequestException("Network error")
|
||||
error = OAuth2AuthenticationError("Auth failed", original_error=original)
|
||||
assert str(error) == "Auth failed"
|
||||
assert error.original_error == original
|
||||
|
||||
def test_oauth2_validation_error(self):
|
||||
error = OAuth2ValidationError("Validation failed")
|
||||
assert str(error) == "Validation failed"
|
||||
assert isinstance(error, OAuth2Error)
|
||||
|
||||
Reference in New Issue
Block a user