Files
crewAI/tests/test_oauth2_llm.py
Devin AI ac1080afd8 Implement OAuth2 authentication support for custom LiteLLM providers
- Add OAuth2Config and OAuth2ConfigLoader for litellm_config.json configuration
- Add OAuth2TokenManager for token acquisition, caching, and refresh
- Extend LLM class to support OAuth2 authentication with custom providers
- Add comprehensive tests covering OAuth2 flow and error handling
- Add documentation and usage examples
- Support Client Credentials OAuth2 flow for server-to-server authentication
- Maintain backward compatibility with existing LLM providers

Fixes #3114

Co-Authored-By: João <joao@crewai.com>
2025-07-07 17:57:40 +00:00

286 lines
9.8 KiB
Python

import json
import pytest
import requests
import tempfile
import time
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
from crewai.llm import LLM
from crewai.llms.oauth2_config import OAuth2Config, OAuth2ConfigLoader
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
class TestOAuth2Config:
    """Tests for ``OAuth2Config`` construction and ``OAuth2ConfigLoader.load_config``."""

    @staticmethod
    def _write_file(directory, content):
        """Write ``content`` to a ``.json`` file inside ``directory``.

        Args:
            directory: Existing directory that will hold the file.
            content: Text written verbatim to the file.

        Returns:
            The path of the written file as a ``str`` (the loader is given a
            string path, as in the original tests).
        """
        config_path = Path(directory) / "oauth2_config.json"
        config_path.write_text(content, encoding="utf-8")
        return str(config_path)

    def test_oauth2_config_creation(self):
        """Values passed to the constructor are readable back as attributes."""
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
            scope="read write",
        )

        assert config.client_id == "test_client"
        assert config.provider_name == "test_provider"
        assert config.scope == "read write"

    def test_oauth2_config_loader_missing_file(self):
        """A loader pointed at a path that does not exist is expected to return ``{}``."""
        # Resolve the missing path inside a fresh temporary directory so the
        # outcome does not depend on the current working directory's contents.
        with tempfile.TemporaryDirectory() as tmp_dir:
            missing_path = Path(tmp_dir) / "nonexistent.json"
            loader = OAuth2ConfigLoader(str(missing_path))
            configs = loader.load_config()

        assert configs == {}

    def test_oauth2_config_loader_valid_config(self):
        """A well-formed file yields one config per entry under ``oauth2_providers``."""
        config_data = {
            "oauth2_providers": {
                "custom_provider": {
                    "client_id": "test_client",
                    "client_secret": "test_secret",
                    "token_url": "https://example.com/token",
                    "scope": "read",
                }
            }
        }

        # The directory context manager removes the file even if writing or
        # loading raises, so nothing is left behind on failure.
        with tempfile.TemporaryDirectory() as tmp_dir:
            config_path = self._write_file(tmp_dir, json.dumps(config_data))
            loader = OAuth2ConfigLoader(config_path)
            configs = loader.load_config()

        assert "custom_provider" in configs
        config = configs["custom_provider"]
        assert config.client_id == "test_client"
        # The test expects the provider name to be taken from the mapping key.
        assert config.provider_name == "custom_provider"

    def test_oauth2_config_loader_invalid_json(self):
        """Content that is not valid JSON is expected to raise ``ValueError``."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            config_path = self._write_file(tmp_dir, "invalid json")
            loader = OAuth2ConfigLoader(config_path)
            with pytest.raises(ValueError, match="Invalid OAuth2 configuration"):
                loader.load_config()
class TestOAuth2TokenManager:
    """Tests for ``OAuth2TokenManager.get_access_token`` with ``requests.post`` patched."""

    @staticmethod
    def _provider_config():
        """Build the ``OAuth2Config`` for ``test_provider`` used by every test here."""
        return OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
        )

    @staticmethod
    def _http_response(payload):
        """Return a ``Mock`` response whose ``json()`` gives ``payload`` and whose
        ``raise_for_status()`` returns ``None``."""
        response = Mock()
        response.json.return_value = payload
        response.raise_for_status.return_value = None
        return response

    @staticmethod
    def _bearer_payload(access_token):
        """Return a token-endpoint JSON body carrying ``access_token``."""
        return {
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": 3600,
        }

    def test_token_acquisition_success(self):
        """The ``access_token`` from the response is returned after exactly one POST."""
        provider_config = self._provider_config()
        response = self._http_response(self._bearer_payload("test_token_123"))

        with patch('requests.post', return_value=response) as post_mock:
            token_manager = OAuth2TokenManager()
            acquired = token_manager.get_access_token(provider_config)

            assert acquired == "test_token_123"
            post_mock.assert_called_once()

    def test_token_caching(self):
        """A second request for the same config is served without another POST."""
        provider_config = self._provider_config()
        response = self._http_response(self._bearer_payload("test_token_123"))

        with patch('requests.post', return_value=response) as post_mock:
            token_manager = OAuth2TokenManager()

            first = token_manager.get_access_token(provider_config)
            assert first == "test_token_123"
            assert post_mock.call_count == 1

            second = token_manager.get_access_token(provider_config)
            assert second == "test_token_123"
            # Still one call: the second token did not come from the endpoint.
            assert post_mock.call_count == 1

    def test_token_refresh_on_expiry(self):
        """A stored token whose ``expires_at`` is in the past is replaced by a new one."""
        provider_config = self._provider_config()
        response = self._http_response(self._bearer_payload("new_token_456"))

        with patch('requests.post', return_value=response):
            token_manager = OAuth2TokenManager()
            # Seed the manager's private store with an entry that expired
            # 100 seconds ago, keyed by the provider name.
            token_manager._tokens["test_provider"] = {
                "access_token": "old_token",
                "expires_at": time.time() - 100,
            }

            refreshed = token_manager.get_access_token(provider_config)

            assert refreshed == "new_token_456"

    def test_token_acquisition_failure(self):
        """A ``requests.RequestException`` from the POST surfaces as ``RuntimeError``."""
        provider_config = self._provider_config()
        network_error = requests.RequestException("Network error")

        with patch('requests.post', side_effect=network_error):
            token_manager = OAuth2TokenManager()
            with pytest.raises(RuntimeError, match="Failed to acquire OAuth2 token"):
                token_manager.get_access_token(provider_config)

    def test_invalid_token_response(self):
        """A response body without ``access_token`` is expected to raise ``RuntimeError``."""
        provider_config = self._provider_config()
        response = self._http_response({"invalid": "response"})

        with patch('requests.post', return_value=response):
            token_manager = OAuth2TokenManager()
            with pytest.raises(RuntimeError, match="Invalid token response"):
                token_manager.get_access_token(provider_config)
class TestLLMOAuth2Integration:
    """Tests for how ``LLM`` loads OAuth2 provider configs and passes tokens to litellm."""

    # Configuration with a single provider named "custom"; it matches the
    # "custom/" prefix of the model name used by the tests that load a file.
    _CUSTOM_PROVIDER_CONFIG = {
        "oauth2_providers": {
            "custom": {
                "client_id": "test_client",
                "client_secret": "test_secret",
                "token_url": "https://example.com/token",
            }
        }
    }

    @classmethod
    def _write_config(cls, directory):
        """Serialize ``_CUSTOM_PROVIDER_CONFIG`` to a JSON file inside ``directory``.

        Args:
            directory: Existing directory that will hold the file.

        Returns:
            The path of the written file as a ``str``.
        """
        config_path = Path(directory) / "oauth2_config.json"
        config_path.write_text(json.dumps(cls._CUSTOM_PROVIDER_CONFIG), encoding="utf-8")
        return str(config_path)

    def test_llm_with_oauth2_config(self):
        """Providers from ``oauth2_config_path`` are available on ``llm.oauth2_configs``."""
        # The directory context manager cleans up the file even when writing
        # the config or constructing the LLM raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            llm = LLM(
                model="custom/test-model",
                oauth2_config_path=self._write_config(tmp_dir),
            )

            assert "custom" in llm.oauth2_configs
            assert llm.oauth2_configs["custom"].client_id == "test_client"

    def test_llm_without_oauth2_config(self):
        """An ``LLM`` built without ``oauth2_config_path`` is expected to have no configs."""
        llm = LLM(model="openai/gpt-3.5-turbo")

        assert llm.oauth2_configs == {}

    @patch('crewai.llm.litellm.completion')
    def test_llm_oauth2_token_injection(self, mock_completion):
        """The token from the token manager is passed to litellm as ``api_key``."""
        mock_completion.return_value = Mock(
            choices=[Mock(message=Mock(content="test response"))]
        )

        with tempfile.TemporaryDirectory() as tmp_dir:
            config_path = self._write_config(tmp_dir)
            with patch.object(
                OAuth2TokenManager, 'get_access_token', return_value="oauth_token_123"
            ):
                llm = LLM(
                    model="custom/test-model",
                    oauth2_config_path=config_path,
                )
                llm.call("test message")

            # call_args is an (args, kwargs) pair for the most recent call.
            _, completion_kwargs = mock_completion.call_args
            assert completion_kwargs['api_key'] == "oauth_token_123"

    @patch('crewai.llm.litellm.completion')
    def test_llm_oauth2_authentication_failure(self, mock_completion):
        """An error raised by the token manager is expected to propagate from ``call``."""
        # NOTE(review): mock_completion is patched but never inspected here —
        # presumably only to keep the test from reaching a real backend.
        with tempfile.TemporaryDirectory() as tmp_dir:
            config_path = self._write_config(tmp_dir)
            with patch.object(
                OAuth2TokenManager, 'get_access_token', side_effect=RuntimeError("Auth failed")
            ):
                llm = LLM(
                    model="custom/test-model",
                    oauth2_config_path=config_path,
                )
                with pytest.raises(RuntimeError, match="Auth failed"):
                    llm.call("test message")

    @patch('crewai.llm.litellm.completion')
    def test_llm_non_oauth2_provider_unchanged(self, mock_completion):
        """Without an OAuth2 config, the explicitly supplied ``api_key`` reaches litellm."""
        mock_completion.return_value = Mock(
            choices=[Mock(message=Mock(content="test response"))]
        )

        llm = LLM(
            model="openai/gpt-3.5-turbo",
            api_key="original_key",
        )
        llm.call("test message")

        _, completion_kwargs = mock_completion.call_args
        assert completion_kwargs['api_key'] == "original_key"