mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 23:28:30 +00:00
- Add custom OAuth2 error classes (OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError) - Implement URL and scope validation in OAuth2Config using pydantic field_validator - Add retry mechanism with exponential backoff (3 attempts, 1s/2s/4s delays) to OAuth2TokenManager - Implement thread safety using threading.Lock for concurrent token access - Add provider validation in LLM integration with _validate_oauth2_provider method - Enhance testing with validation, retry, thread safety, and error class tests - Improve documentation with comprehensive error handling examples and feature descriptions - Maintain backward compatibility while significantly improving robustness and security Co-Authored-By: João <joao@crewai.com>
421 lines
15 KiB
Python
421 lines
15 KiB
Python
import json
|
|
import pytest
|
|
import requests
|
|
import tempfile
|
|
import time
|
|
import threading
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch
|
|
from crewai.llm import LLM
|
|
from crewai.llms.oauth2_config import OAuth2Config, OAuth2ConfigLoader
|
|
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
|
|
from crewai.llms.oauth2_errors import OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError
|
|
|
|
|
|
class TestOAuth2Config:
    """Exercise ``OAuth2Config`` construction and ``OAuth2ConfigLoader`` loading."""

    def test_oauth2_config_creation(self):
        """Constructor keyword arguments are exposed as attributes."""
        settings = {
            "client_id": "test_client",
            "client_secret": "test_secret",
            "token_url": "https://example.com/token",
            "provider_name": "test_provider",
            "scope": "read write",
        }

        created = OAuth2Config(**settings)

        assert created.client_id == "test_client"
        assert created.provider_name == "test_provider"
        assert created.scope == "read write"

    def test_oauth2_config_loader_missing_file(self):
        """Loading from ``nonexistent.json`` yields an empty provider mapping."""
        assert OAuth2ConfigLoader("nonexistent.json").load_config() == {}

    def test_oauth2_config_loader_valid_config(self):
        """A well-formed JSON file produces one config per provider key."""
        provider_entry = {
            "client_id": "test_client",
            "client_secret": "test_secret",
            "token_url": "https://example.com/token",
            "scope": "read",
        }
        document = {"oauth2_providers": {"custom_provider": provider_entry}}

        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
            config_path = handle.name
            json.dump(document, handle)

        try:
            loaded = OAuth2ConfigLoader(config_path).load_config()

            assert "custom_provider" in loaded
            provider = loaded["custom_provider"]
            assert provider.client_id == "test_client"
            # provider_name is absent from the JSON entry, so it is expected
            # to come from the mapping key.
            assert provider.provider_name == "custom_provider"
        finally:
            # delete=False above means the file has to be removed by hand.
            Path(config_path).unlink()

    def test_oauth2_config_loader_invalid_json(self):
        """A file that is not valid JSON raises ``OAuth2ConfigurationError``."""
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
            config_path = handle.name
            handle.write("invalid json")

        try:
            broken_loader = OAuth2ConfigLoader(config_path)
            with pytest.raises(OAuth2ConfigurationError, match="Invalid OAuth2 configuration"):
                broken_loader.load_config()
        finally:
            # delete=False above means the file has to be removed by hand.
            Path(config_path).unlink()
|
|
|
|
|
|
class TestOAuth2TokenManager:
    """Exercise token acquisition, caching, refresh and failure handling."""

    @staticmethod
    def _provider_config():
        """Return the provider config shared by every test in this class."""
        return OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
        )

    @staticmethod
    def _http_response(payload):
        """Return a mock HTTP response whose ``json()`` yields *payload*."""
        return Mock(**{
            "json.return_value": payload,
            "raise_for_status.return_value": None,
        })

    def test_token_acquisition_success(self):
        """A successful token response returns its ``access_token``."""
        provider = self._provider_config()
        reply = self._http_response({
            "access_token": "test_token_123",
            "token_type": "Bearer",
            "expires_in": 3600,
        })

        with patch('requests.post', return_value=reply) as mock_post:
            token = OAuth2TokenManager().get_access_token(provider)

            assert token == "test_token_123"
            mock_post.assert_called_once()

    def test_token_caching(self):
        """A second request for the same provider does not hit the endpoint again."""
        provider = self._provider_config()
        reply = self._http_response({
            "access_token": "test_token_123",
            "token_type": "Bearer",
            "expires_in": 3600,
        })

        with patch('requests.post', return_value=reply) as mock_post:
            manager = OAuth2TokenManager()

            first = manager.get_access_token(provider)
            assert first == "test_token_123"
            assert mock_post.call_count == 1

            # Same manager, same provider: the call count must stay at one.
            second = manager.get_access_token(provider)
            assert second == "test_token_123"
            assert mock_post.call_count == 1

    def test_token_refresh_on_expiry(self):
        """An entry whose ``expires_at`` is in the past is replaced by a new token."""
        provider = self._provider_config()
        reply = self._http_response({
            "access_token": "new_token_456",
            "token_type": "Bearer",
            "expires_in": 3600,
        })

        with patch('requests.post', return_value=reply):
            manager = OAuth2TokenManager()

            # Seed the manager's internal store with a token that expired
            # 100 seconds ago.
            stale_entry = {
                "access_token": "old_token",
                "expires_at": time.time() - 100,
            }
            manager._tokens["test_provider"] = stale_entry

            assert manager.get_access_token(provider) == "new_token_456"

    def test_token_acquisition_failure(self):
        """A persistent request error surfaces as ``OAuth2AuthenticationError``."""
        provider = self._provider_config()
        network_failure = requests.RequestException("Network error")

        # time.sleep is patched so the test does not actually wait.
        with patch('requests.post', side_effect=network_failure), patch('time.sleep'):
            manager = OAuth2TokenManager()

            with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token"):
                manager.get_access_token(provider)

    def test_invalid_token_response(self):
        """A response without ``access_token`` raises ``OAuth2AuthenticationError``."""
        provider = self._provider_config()
        reply = self._http_response({"invalid": "response"})

        with patch('requests.post', return_value=reply):
            manager = OAuth2TokenManager()

            with pytest.raises(OAuth2AuthenticationError, match="Invalid token response"):
                manager.get_access_token(provider)
|
|
|
|
|
|
class TestLLMOAuth2Integration:
    """Exercise how ``LLM`` loads OAuth2 configs and passes tokens to litellm."""

    @staticmethod
    def _write_provider_file():
        """Write a one-provider OAuth2 config to a temp JSON file; return its path.

        The file is created with ``delete=False``, so the caller must unlink it.
        """
        document = {
            "oauth2_providers": {
                "custom": {
                    "client_id": "test_client",
                    "client_secret": "test_secret",
                    "token_url": "https://example.com/token",
                }
            }
        }
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
            json.dump(document, handle)
        return handle.name

    def test_llm_with_oauth2_config(self):
        """``oauth2_config_path`` populates ``llm.oauth2_configs`` by provider key."""
        config_path = self._write_provider_file()

        try:
            llm = LLM(model="custom/test-model", oauth2_config_path=config_path)

            assert "custom" in llm.oauth2_configs
            assert llm.oauth2_configs["custom"].client_id == "test_client"
        finally:
            Path(config_path).unlink()

    def test_llm_without_oauth2_config(self):
        """Without a config path, ``llm.oauth2_configs`` is empty."""
        assert LLM(model="openai/gpt-3.5-turbo").oauth2_configs == {}

    @patch('crewai.llm.litellm.completion')
    def test_llm_oauth2_token_injection(self, mock_completion):
        """The token from the token manager is passed to litellm as ``api_key``."""
        reply_message = Mock(content="test response")
        mock_completion.return_value = Mock(choices=[Mock(message=reply_message)])
        config_path = self._write_provider_file()

        try:
            with patch.object(OAuth2TokenManager, 'get_access_token', return_value="oauth_token_123"):
                llm = LLM(model="custom/test-model", oauth2_config_path=config_path)

                llm.call("test message")

                # call_args unpacks to (positional args, keyword args).
                _, completion_kwargs = mock_completion.call_args
                assert completion_kwargs['api_key'] == "oauth_token_123"
        finally:
            Path(config_path).unlink()

    @patch('crewai.llm.litellm.completion')
    def test_llm_oauth2_authentication_failure(self, mock_completion):
        """An authentication error from the token manager propagates out of ``call``."""
        config_path = self._write_provider_file()
        auth_failure = OAuth2AuthenticationError("Auth failed")

        try:
            with patch.object(OAuth2TokenManager, 'get_access_token', side_effect=auth_failure):
                llm = LLM(model="custom/test-model", oauth2_config_path=config_path)

                with pytest.raises(OAuth2AuthenticationError, match="Auth failed"):
                    llm.call("test message")
        finally:
            Path(config_path).unlink()

    @patch('crewai.llm.litellm.completion')
    def test_llm_non_oauth2_provider_unchanged(self, mock_completion):
        """With no OAuth2 config, the explicitly supplied ``api_key`` is used."""
        reply_message = Mock(content="test response")
        mock_completion.return_value = Mock(choices=[Mock(message=reply_message)])

        llm = LLM(model="openai/gpt-3.5-turbo", api_key="original_key")
        llm.call("test message")

        # call_args unpacks to (positional args, keyword args).
        _, completion_kwargs = mock_completion.call_args
        assert completion_kwargs['api_key'] == "original_key"
|
|
|
|
|
|
class TestOAuth2Validation:
    """Exercise the ``token_url`` and ``scope`` validation on ``OAuth2Config``."""

    def test_valid_url_validation(self):
        """An ``https://`` token URL is accepted and stored unchanged."""
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
        )
        assert config.token_url == "https://example.com/token"

    def test_invalid_url_validation(self):
        """A token URL without a scheme or host raises ``OAuth2ValidationError``."""
        with pytest.raises(OAuth2ValidationError, match="Invalid token URL format"):
            OAuth2Config(
                client_id="test_client",
                client_secret="test_secret",
                token_url="not-a-url",
                provider_name="test_provider",
            )

    def test_valid_scope_validation(self):
        """A space-separated scope string is accepted and stored unchanged."""
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
            scope="read write",
        )
        assert config.scope == "read write"

    def test_invalid_scope_validation(self):
        """A malformed scope string raises ``OAuth2ValidationError``."""
        # The scope here must differ from the "read write" value that
        # test_valid_scope_validation expects to be accepted; with identical
        # input the two tests contradict each other and one can never pass.
        # NOTE(review): assumes the scope validator rejects consecutive
        # whitespace between tokens - confirm against OAuth2Config.
        with pytest.raises(OAuth2ValidationError, match="Invalid scope format"):
            OAuth2Config(
                client_id="test_client",
                client_secret="test_secret",
                token_url="https://example.com/token",
                provider_name="test_provider",
                scope="read  write",
            )
|
|
|
|
|
|
class TestOAuth2RetryMechanism:
    """Exercise the retry behaviour of ``OAuth2TokenManager.get_access_token``."""

    @staticmethod
    def _provider_config():
        """Return the provider config shared by both retry tests."""
        return OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
        )

    def test_retry_mechanism_success_on_second_attempt(self):
        """One failed request followed by a good response still yields the token."""
        provider = self._provider_config()

        first_attempt = requests.RequestException("Network error")
        second_attempt = Mock(
            json=lambda: {"access_token": "token_123", "expires_in": 3600},
            raise_for_status=lambda: None,
        )
        # With a list side_effect, an exception item is raised and any other
        # item is returned, one per call in order.
        outcomes = [first_attempt, second_attempt]

        # time.sleep is patched so the test does not actually wait.
        with patch('requests.post', side_effect=outcomes), patch('time.sleep'):
            token = OAuth2TokenManager().get_access_token(provider)
            assert token == "token_123"

    def test_retry_mechanism_exhausted(self):
        """When every request fails, the error message reports three attempts."""
        provider = self._provider_config()
        network_failure = requests.RequestException("Network error")

        # time.sleep is patched so the test does not actually wait.
        with patch('requests.post', side_effect=network_failure), patch('time.sleep'):
            manager = OAuth2TokenManager()

            with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token.*after 3 attempts"):
                manager.get_access_token(provider)
|
|
|
|
|
|
class TestOAuth2ThreadSafety:
    """Exercise ``OAuth2TokenManager`` under concurrent token requests."""

    def test_concurrent_token_requests(self):
        """Five concurrent requests share a single call to the token endpoint.

        Each worker thread asks the same manager for a token. The test
        requires that every worker finished without raising, that every
        worker got the expected token, and that ``requests.post`` was invoked
        exactly once.
        """
        worker_count = 5
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
        )

        mock_response = Mock()
        mock_response.json.return_value = {"access_token": "token_123", "expires_in": 3600}
        mock_response.raise_for_status.return_value = None

        call_count = 0

        def mock_post(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            # Hold the request open briefly so the other workers overlap it.
            time.sleep(0.1)
            return mock_response

        with patch('requests.post', side_effect=mock_post):
            manager = OAuth2TokenManager()

            results = []
            errors = []

            def get_token():
                # An exception raised inside a thread never reaches the test
                # on its own, so it is recorded here and asserted on below.
                try:
                    results.append(manager.get_access_token(config))
                except Exception as exc:
                    errors.append(exc)

            threads = [threading.Thread(target=get_token) for _ in range(worker_count)]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

            assert not errors, f"worker threads raised: {errors!r}"
            # Without the length check, all() below passes on an empty list.
            assert len(results) == worker_count
            assert all(token == "token_123" for token in results)
            assert call_count == 1
|
|
|
|
|
|
class TestOAuth2ErrorClasses:
    """Exercise the message and hierarchy of the custom OAuth2 exceptions."""

    def test_oauth2_configuration_error(self):
        """``OAuth2ConfigurationError`` keeps its message and is an ``OAuth2Error``."""
        config_error = OAuth2ConfigurationError("Config error")

        assert isinstance(config_error, OAuth2Error)
        assert str(config_error) == "Config error"

    def test_oauth2_authentication_error_with_original(self):
        """``OAuth2AuthenticationError`` exposes the exception passed as ``original_error``."""
        underlying = requests.RequestException("Network error")

        auth_error = OAuth2AuthenticationError("Auth failed", original_error=underlying)

        assert str(auth_error) == "Auth failed"
        assert auth_error.original_error == underlying

    def test_oauth2_validation_error(self):
        """``OAuth2ValidationError`` keeps its message and is an ``OAuth2Error``."""
        validation_error = OAuth2ValidationError("Validation failed")

        assert isinstance(validation_error, OAuth2Error)
        assert str(validation_error) == "Validation failed"
|