Implement OAuth2 authentication support for custom LiteLLM providers

- Add OAuth2Config and OAuth2ConfigLoader for litellm_config.json configuration
- Add OAuth2TokenManager for token acquisition, caching, and refresh
- Extend LLM class to support OAuth2 authentication with custom providers
- Add comprehensive tests covering OAuth2 flow and error handling
- Add documentation and usage examples
- Support Client Credentials OAuth2 flow for server-to-server authentication
- Maintain backward compatibility with existing LLM providers

Fixes #3114

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-07-07 17:57:40 +00:00
parent a0fcc0c8d1
commit ac1080afd8
7 changed files with 630 additions and 2 deletions

285
tests/test_oauth2_llm.py Normal file
View File

@@ -0,0 +1,285 @@
import json
import pytest
import requests
import tempfile
import time
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
from crewai.llm import LLM
from crewai.llms.oauth2_config import OAuth2Config, OAuth2ConfigLoader
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
class TestOAuth2Config:
def test_oauth2_config_creation(self):
config = OAuth2Config(
client_id="test_client",
client_secret="test_secret",
token_url="https://example.com/token",
provider_name="test_provider",
scope="read write"
)
assert config.client_id == "test_client"
assert config.provider_name == "test_provider"
assert config.scope == "read write"
def test_oauth2_config_loader_missing_file(self):
loader = OAuth2ConfigLoader("nonexistent.json")
configs = loader.load_config()
assert configs == {}
def test_oauth2_config_loader_valid_config(self):
config_data = {
"oauth2_providers": {
"custom_provider": {
"client_id": "test_client",
"client_secret": "test_secret",
"token_url": "https://example.com/token",
"scope": "read"
}
}
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(config_data, f)
config_path = f.name
try:
loader = OAuth2ConfigLoader(config_path)
configs = loader.load_config()
assert "custom_provider" in configs
config = configs["custom_provider"]
assert config.client_id == "test_client"
assert config.provider_name == "custom_provider"
finally:
Path(config_path).unlink()
def test_oauth2_config_loader_invalid_json(self):
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
f.write("invalid json")
config_path = f.name
try:
loader = OAuth2ConfigLoader(config_path)
with pytest.raises(ValueError, match="Invalid OAuth2 configuration"):
loader.load_config()
finally:
Path(config_path).unlink()
class TestOAuth2TokenManager:
def test_token_acquisition_success(self):
config = OAuth2Config(
client_id="test_client",
client_secret="test_secret",
token_url="https://example.com/token",
provider_name="test_provider"
)
mock_response = Mock()
mock_response.json.return_value = {
"access_token": "test_token_123",
"token_type": "Bearer",
"expires_in": 3600
}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response) as mock_post:
manager = OAuth2TokenManager()
token = manager.get_access_token(config)
assert token == "test_token_123"
mock_post.assert_called_once()
def test_token_caching(self):
config = OAuth2Config(
client_id="test_client",
client_secret="test_secret",
token_url="https://example.com/token",
provider_name="test_provider"
)
mock_response = Mock()
mock_response.json.return_value = {
"access_token": "test_token_123",
"token_type": "Bearer",
"expires_in": 3600
}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response) as mock_post:
manager = OAuth2TokenManager()
token1 = manager.get_access_token(config)
assert token1 == "test_token_123"
assert mock_post.call_count == 1
token2 = manager.get_access_token(config)
assert token2 == "test_token_123"
assert mock_post.call_count == 1
def test_token_refresh_on_expiry(self):
config = OAuth2Config(
client_id="test_client",
client_secret="test_secret",
token_url="https://example.com/token",
provider_name="test_provider"
)
mock_response = Mock()
mock_response.json.return_value = {
"access_token": "new_token_456",
"token_type": "Bearer",
"expires_in": 3600
}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response):
manager = OAuth2TokenManager()
manager._tokens["test_provider"] = {
"access_token": "old_token",
"expires_at": time.time() - 100
}
token = manager.get_access_token(config)
assert token == "new_token_456"
def test_token_acquisition_failure(self):
config = OAuth2Config(
client_id="test_client",
client_secret="test_secret",
token_url="https://example.com/token",
provider_name="test_provider"
)
with patch('requests.post', side_effect=requests.RequestException("Network error")):
manager = OAuth2TokenManager()
with pytest.raises(RuntimeError, match="Failed to acquire OAuth2 token"):
manager.get_access_token(config)
def test_invalid_token_response(self):
config = OAuth2Config(
client_id="test_client",
client_secret="test_secret",
token_url="https://example.com/token",
provider_name="test_provider"
)
mock_response = Mock()
mock_response.json.return_value = {"invalid": "response"}
mock_response.raise_for_status.return_value = None
with patch('requests.post', return_value=mock_response):
manager = OAuth2TokenManager()
with pytest.raises(RuntimeError, match="Invalid token response"):
manager.get_access_token(config)
class TestLLMOAuth2Integration:
def test_llm_with_oauth2_config(self):
config_data = {
"oauth2_providers": {
"custom": {
"client_id": "test_client",
"client_secret": "test_secret",
"token_url": "https://example.com/token"
}
}
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(config_data, f)
config_path = f.name
try:
llm = LLM(
model="custom/test-model",
oauth2_config_path=config_path
)
assert "custom" in llm.oauth2_configs
assert llm.oauth2_configs["custom"].client_id == "test_client"
finally:
Path(config_path).unlink()
def test_llm_without_oauth2_config(self):
llm = LLM(model="openai/gpt-3.5-turbo")
assert llm.oauth2_configs == {}
@patch('crewai.llm.litellm.completion')
def test_llm_oauth2_token_injection(self, mock_completion):
config_data = {
"oauth2_providers": {
"custom": {
"client_id": "test_client",
"client_secret": "test_secret",
"token_url": "https://example.com/token"
}
}
}
mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))])
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(config_data, f)
config_path = f.name
try:
with patch.object(OAuth2TokenManager, 'get_access_token', return_value="oauth_token_123"):
llm = LLM(
model="custom/test-model",
oauth2_config_path=config_path
)
llm.call("test message")
call_args = mock_completion.call_args
assert call_args[1]['api_key'] == "oauth_token_123"
finally:
Path(config_path).unlink()
@patch('crewai.llm.litellm.completion')
def test_llm_oauth2_authentication_failure(self, mock_completion):
config_data = {
"oauth2_providers": {
"custom": {
"client_id": "test_client",
"client_secret": "test_secret",
"token_url": "https://example.com/token"
}
}
}
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(config_data, f)
config_path = f.name
try:
with patch.object(OAuth2TokenManager, 'get_access_token', side_effect=RuntimeError("Auth failed")):
llm = LLM(
model="custom/test-model",
oauth2_config_path=config_path
)
with pytest.raises(RuntimeError, match="Auth failed"):
llm.call("test message")
finally:
Path(config_path).unlink()
@patch('crewai.llm.litellm.completion')
def test_llm_non_oauth2_provider_unchanged(self, mock_completion):
mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))])
llm = LLM(
model="openai/gpt-3.5-turbo",
api_key="original_key"
)
llm.call("test message")
call_args = mock_completion.call_args
assert call_args[1]['api_key'] == "original_key"