"""Tests for OAuth2 support in the crewAI LLM layer.

Covers configuration loading and validation, token acquisition with caching,
retry and concurrent access, injection of the OAuth2 token into LLM calls,
and the OAuth2 exception hierarchy.
"""

import json
import tempfile
import threading
import time
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import Mock, patch

import pytest
import requests

from crewai.llm import LLM
from crewai.llms.oauth2_config import OAuth2Config, OAuth2ConfigLoader
from crewai.llms.oauth2_errors import (
    OAuth2AuthenticationError,
    OAuth2ConfigurationError,
    OAuth2Error,
    OAuth2ValidationError,
)
from crewai.llms.oauth2_token_manager import OAuth2TokenManager


def _make_config(**overrides):
    """Build the OAuth2Config shared by most tests.

    Keyword arguments replace or extend the default constructor arguments.
    """
    params = {
        "client_id": "test_client",
        "client_secret": "test_secret",
        "token_url": "https://example.com/token",
        "provider_name": "test_provider",
    }
    params.update(overrides)
    return OAuth2Config(**params)


def _providers_json(provider_name, **extra_fields):
    """Return the JSON text of a config file declaring a single provider."""
    provider = {
        "client_id": "test_client",
        "client_secret": "test_secret",
        "token_url": "https://example.com/token",
    }
    provider.update(extra_fields)
    return json.dumps({"oauth2_providers": {provider_name: provider}})


def _mock_response(payload):
    """Return a Mock HTTP response whose ``json()`` yields *payload*."""
    response = Mock()
    response.json.return_value = payload
    response.raise_for_status.return_value = None
    return response


@contextmanager
def _temp_config_file(text):
    """Write *text* to a temporary ``.json`` file and yield its path.

    The file is closed before the path is yielded so the code under test can
    reopen it, and it is removed on exit even when the test body fails.
    """
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as handle:
        handle.write(text)
    path = Path(handle.name)
    try:
        yield str(path)
    finally:
        path.unlink()


class TestOAuth2Config:
    """OAuth2Config construction and OAuth2ConfigLoader file handling."""

    def test_oauth2_config_creation(self):
        """Constructor arguments are exposed as attributes."""
        config = _make_config(scope="read write")

        assert config.client_id == "test_client"
        assert config.provider_name == "test_provider"
        assert config.scope == "read write"

    def test_oauth2_config_loader_missing_file(self):
        """A missing config file yields an empty mapping instead of raising."""
        loader = OAuth2ConfigLoader("nonexistent.json")

        assert loader.load_config() == {}

    def test_oauth2_config_loader_valid_config(self):
        """Each provider entry is loaded and named after its key."""
        with _temp_config_file(
            _providers_json("custom_provider", scope="read")
        ) as config_path:
            configs = OAuth2ConfigLoader(config_path).load_config()

            assert "custom_provider" in configs
            config = configs["custom_provider"]
            assert config.client_id == "test_client"
            assert config.provider_name == "custom_provider"

    def test_oauth2_config_loader_invalid_json(self):
        """Malformed JSON is reported as OAuth2ConfigurationError."""
        with _temp_config_file("invalid json") as config_path:
            loader = OAuth2ConfigLoader(config_path)
            with pytest.raises(
                OAuth2ConfigurationError, match="Invalid OAuth2 configuration"
            ):
                loader.load_config()


class TestOAuth2TokenManager:
    """Token acquisition, caching, refresh and failure handling."""

    def test_token_acquisition_success(self):
        """A well-formed token response yields the access token."""
        response = _mock_response(
            {
                "access_token": "test_token_123",
                "token_type": "Bearer",
                "expires_in": 3600,
            }
        )

        with patch("requests.post", return_value=response) as mock_post:
            token = OAuth2TokenManager().get_access_token(_make_config())

            assert token == "test_token_123"
            mock_post.assert_called_once()

    def test_token_caching(self):
        """A second request for the same config does not hit the endpoint."""
        config = _make_config()
        response = _mock_response(
            {
                "access_token": "test_token_123",
                "token_type": "Bearer",
                "expires_in": 3600,
            }
        )

        with patch("requests.post", return_value=response) as mock_post:
            manager = OAuth2TokenManager()

            first = manager.get_access_token(config)
            assert first == "test_token_123"
            assert mock_post.call_count == 1

            second = manager.get_access_token(config)
            assert second == "test_token_123"
            assert mock_post.call_count == 1

    def test_token_refresh_on_expiry(self):
        """A cached token whose expiry is in the past is replaced."""
        response = _mock_response(
            {
                "access_token": "new_token_456",
                "token_type": "Bearer",
                "expires_in": 3600,
            }
        )

        with patch("requests.post", return_value=response):
            manager = OAuth2TokenManager()
            # Seed the cache with an entry that expired 100 seconds ago.
            manager._tokens["test_provider"] = {
                "access_token": "old_token",
                "expires_at": time.time() - 100,
            }

            assert manager.get_access_token(_make_config()) == "new_token_456"

    def test_token_acquisition_failure(self):
        """Network errors surface as OAuth2AuthenticationError."""
        network_error = requests.RequestException("Network error")

        # time.sleep is patched so retry back-off does not slow the test.
        with patch("requests.post", side_effect=network_error), \
                patch("time.sleep"):
            manager = OAuth2TokenManager()
            with pytest.raises(
                OAuth2AuthenticationError,
                match="Failed to acquire OAuth2 token",
            ):
                manager.get_access_token(_make_config())

    def test_invalid_token_response(self):
        """A response without an access token is rejected."""
        response = _mock_response({"invalid": "response"})

        with patch("requests.post", return_value=response):
            manager = OAuth2TokenManager()
            with pytest.raises(
                OAuth2AuthenticationError, match="Invalid token response"
            ):
                manager.get_access_token(_make_config())


class TestLLMOAuth2Integration:
    """OAuth2 wiring inside the LLM class."""

    def test_llm_with_oauth2_config(self):
        """Providers from the config file are available on the LLM."""
        with _temp_config_file(_providers_json("custom")) as config_path:
            llm = LLM(model="custom/test-model", oauth2_config_path=config_path)

            assert "custom" in llm.oauth2_configs
            assert llm.oauth2_configs["custom"].client_id == "test_client"

    def test_llm_without_oauth2_config(self):
        """Without a config path the LLM has no OAuth2 providers."""
        llm = LLM(model="openai/gpt-3.5-turbo")

        assert llm.oauth2_configs == {}

    @patch("crewai.llm.litellm.completion")
    def test_llm_oauth2_token_injection(self, mock_completion):
        """The acquired token is passed to litellm as ``api_key``."""
        mock_completion.return_value = Mock(
            choices=[Mock(message=Mock(content="test response"))]
        )

        with _temp_config_file(_providers_json("custom")) as config_path:
            with patch.object(
                OAuth2TokenManager,
                "get_access_token",
                return_value="oauth_token_123",
            ):
                llm = LLM(
                    model="custom/test-model", oauth2_config_path=config_path
                )
                llm.call("test message")

                call_args = mock_completion.call_args
                assert call_args[1]["api_key"] == "oauth_token_123"

    @patch("crewai.llm.litellm.completion")
    def test_llm_oauth2_authentication_failure(self, mock_completion):
        """Token acquisition errors propagate out of ``LLM.call``."""
        with _temp_config_file(_providers_json("custom")) as config_path:
            with patch.object(
                OAuth2TokenManager,
                "get_access_token",
                side_effect=OAuth2AuthenticationError("Auth failed"),
            ):
                llm = LLM(
                    model="custom/test-model", oauth2_config_path=config_path
                )
                with pytest.raises(
                    OAuth2AuthenticationError, match="Auth failed"
                ):
                    llm.call("test message")

    @patch("crewai.llm.litellm.completion")
    def test_llm_non_oauth2_provider_unchanged(self, mock_completion):
        """Providers without OAuth2 config keep their explicit API key."""
        mock_completion.return_value = Mock(
            choices=[Mock(message=Mock(content="test response"))]
        )

        llm = LLM(model="openai/gpt-3.5-turbo", api_key="original_key")
        llm.call("test message")

        call_args = mock_completion.call_args
        assert call_args[1]["api_key"] == "original_key"


class TestOAuth2Validation:
    """Field validation performed by OAuth2Config."""

    def test_valid_url_validation(self):
        """An https token URL is accepted unchanged."""
        config = _make_config()

        assert config.token_url == "https://example.com/token"

    def test_invalid_url_validation(self):
        """A token URL without scheme and host is rejected."""
        with pytest.raises(
            OAuth2ValidationError, match="Invalid token URL format"
        ):
            _make_config(token_url="not-a-url")

    def test_valid_scope_validation(self):
        """Space-separated scope tokens are accepted."""
        config = _make_config(scope="read write")

        assert config.scope == "read write"

    def test_invalid_scope_validation(self):
        """A scope containing an empty token is rejected."""
        # The value must differ from the "read write" accepted by
        # test_valid_scope_validation; the doubled space produces an empty
        # scope token.
        # NOTE(review): assumes the validator rejects consecutive spaces -
        # confirm against OAuth2Config's scope validation.
        with pytest.raises(OAuth2ValidationError, match="Invalid scope format"):
            _make_config(scope="read  write")


class TestOAuth2RetryMechanism:
    """Retry behaviour of token acquisition."""

    def test_retry_mechanism_success_on_second_attempt(self):
        """A transient network error is retried and then succeeds."""
        outcomes = [
            requests.RequestException("Network error"),
            _mock_response({"access_token": "token_123", "expires_in": 3600}),
        ]

        with patch("requests.post", side_effect=outcomes), patch("time.sleep"):
            token = OAuth2TokenManager().get_access_token(_make_config())

            assert token == "token_123"

    def test_retry_mechanism_exhausted(self):
        """Persistent failure reports the number of attempts made."""
        network_error = requests.RequestException("Network error")

        with patch("requests.post", side_effect=network_error), \
                patch("time.sleep"):
            manager = OAuth2TokenManager()
            with pytest.raises(
                OAuth2AuthenticationError,
                match="Failed to acquire OAuth2 token.*after 3 attempts",
            ):
                manager.get_access_token(_make_config())


class TestOAuth2ThreadSafety:
    """Concurrent access to a shared OAuth2TokenManager."""

    def test_concurrent_token_requests(self):
        """Five concurrent callers share a single token request."""
        config = _make_config()
        response = _mock_response(
            {"access_token": "token_123", "expires_in": 3600}
        )
        call_count = 0

        def mock_post(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            # Hold the request open so the other threads arrive while the
            # first one is still in flight.
            time.sleep(0.1)
            return response

        with patch("requests.post", side_effect=mock_post):
            manager = OAuth2TokenManager()
            results = []

            def get_token():
                results.append(manager.get_access_token(config))

            threads = [threading.Thread(target=get_token) for _ in range(5)]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

            # all() is vacuously true for an empty list, so first make sure
            # every worker actually produced a token.
            assert len(results) == 5
            assert all(token == "token_123" for token in results)
            assert call_count == 1


class TestOAuth2ErrorClasses:
    """Exception hierarchy exposed by crewai.llms.oauth2_errors."""

    def test_oauth2_configuration_error(self):
        """OAuth2ConfigurationError keeps its message and base class."""
        error = OAuth2ConfigurationError("Config error")

        assert str(error) == "Config error"
        assert isinstance(error, OAuth2Error)

    def test_oauth2_authentication_error_with_original(self):
        """OAuth2AuthenticationError retains the wrapped exception."""
        original = requests.RequestException("Network error")
        error = OAuth2AuthenticationError("Auth failed", original_error=original)

        assert str(error) == "Auth failed"
        assert error.original_error == original

    def test_oauth2_validation_error(self):
        """OAuth2ValidationError keeps its message and base class."""
        error = OAuth2ValidationError("Validation failed")

        assert str(error) == "Validation failed"
        assert isinstance(error, OAuth2Error)