mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Implement OAuth2 authentication support for custom LiteLLM providers
- Add OAuth2Config and OAuth2ConfigLoader for litellm_config.json configuration - Add OAuth2TokenManager for token acquisition, caching, and refresh - Extend LLM class to support OAuth2 authentication with custom providers - Add comprehensive tests covering OAuth2 flow and error handling - Add documentation and usage examples - Support Client Credentials OAuth2 flow for server-to-server authentication - Maintain backward compatibility with existing LLM providers Fixes #3114 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
285
tests/test_oauth2_llm.py
Normal file
285
tests/test_oauth2_llm.py
Normal file
@@ -0,0 +1,285 @@
|
||||
import json
|
||||
import pytest
|
||||
import requests
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.oauth2_config import OAuth2Config, OAuth2ConfigLoader
|
||||
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
|
||||
|
||||
|
||||
class TestOAuth2Config:
|
||||
def test_oauth2_config_creation(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider",
|
||||
scope="read write"
|
||||
)
|
||||
assert config.client_id == "test_client"
|
||||
assert config.provider_name == "test_provider"
|
||||
assert config.scope == "read write"
|
||||
|
||||
def test_oauth2_config_loader_missing_file(self):
|
||||
loader = OAuth2ConfigLoader("nonexistent.json")
|
||||
configs = loader.load_config()
|
||||
assert configs == {}
|
||||
|
||||
def test_oauth2_config_loader_valid_config(self):
|
||||
config_data = {
|
||||
"oauth2_providers": {
|
||||
"custom_provider": {
|
||||
"client_id": "test_client",
|
||||
"client_secret": "test_secret",
|
||||
"token_url": "https://example.com/token",
|
||||
"scope": "read"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
config_path = f.name
|
||||
|
||||
try:
|
||||
loader = OAuth2ConfigLoader(config_path)
|
||||
configs = loader.load_config()
|
||||
|
||||
assert "custom_provider" in configs
|
||||
config = configs["custom_provider"]
|
||||
assert config.client_id == "test_client"
|
||||
assert config.provider_name == "custom_provider"
|
||||
finally:
|
||||
Path(config_path).unlink()
|
||||
|
||||
def test_oauth2_config_loader_invalid_json(self):
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
f.write("invalid json")
|
||||
config_path = f.name
|
||||
|
||||
try:
|
||||
loader = OAuth2ConfigLoader(config_path)
|
||||
with pytest.raises(ValueError, match="Invalid OAuth2 configuration"):
|
||||
loader.load_config()
|
||||
finally:
|
||||
Path(config_path).unlink()
|
||||
|
||||
|
||||
class TestOAuth2TokenManager:
|
||||
def test_token_acquisition_success(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "test_token_123",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch('requests.post', return_value=mock_response) as mock_post:
|
||||
manager = OAuth2TokenManager()
|
||||
token = manager.get_access_token(config)
|
||||
|
||||
assert token == "test_token_123"
|
||||
mock_post.assert_called_once()
|
||||
|
||||
def test_token_caching(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "test_token_123",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch('requests.post', return_value=mock_response) as mock_post:
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
token1 = manager.get_access_token(config)
|
||||
assert token1 == "test_token_123"
|
||||
assert mock_post.call_count == 1
|
||||
|
||||
token2 = manager.get_access_token(config)
|
||||
assert token2 == "test_token_123"
|
||||
assert mock_post.call_count == 1
|
||||
|
||||
def test_token_refresh_on_expiry(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "new_token_456",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600
|
||||
}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch('requests.post', return_value=mock_response):
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
manager._tokens["test_provider"] = {
|
||||
"access_token": "old_token",
|
||||
"expires_at": time.time() - 100
|
||||
}
|
||||
|
||||
token = manager.get_access_token(config)
|
||||
assert token == "new_token_456"
|
||||
|
||||
def test_token_acquisition_failure(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
|
||||
with patch('requests.post', side_effect=requests.RequestException("Network error")):
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to acquire OAuth2 token"):
|
||||
manager.get_access_token(config)
|
||||
|
||||
def test_invalid_token_response(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"invalid": "response"}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
with patch('requests.post', return_value=mock_response):
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid token response"):
|
||||
manager.get_access_token(config)
|
||||
|
||||
|
||||
class TestLLMOAuth2Integration:
|
||||
def test_llm_with_oauth2_config(self):
|
||||
config_data = {
|
||||
"oauth2_providers": {
|
||||
"custom": {
|
||||
"client_id": "test_client",
|
||||
"client_secret": "test_secret",
|
||||
"token_url": "https://example.com/token"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
config_path = f.name
|
||||
|
||||
try:
|
||||
llm = LLM(
|
||||
model="custom/test-model",
|
||||
oauth2_config_path=config_path
|
||||
)
|
||||
|
||||
assert "custom" in llm.oauth2_configs
|
||||
assert llm.oauth2_configs["custom"].client_id == "test_client"
|
||||
finally:
|
||||
Path(config_path).unlink()
|
||||
|
||||
def test_llm_without_oauth2_config(self):
|
||||
llm = LLM(model="openai/gpt-3.5-turbo")
|
||||
assert llm.oauth2_configs == {}
|
||||
|
||||
@patch('crewai.llm.litellm.completion')
|
||||
def test_llm_oauth2_token_injection(self, mock_completion):
|
||||
config_data = {
|
||||
"oauth2_providers": {
|
||||
"custom": {
|
||||
"client_id": "test_client",
|
||||
"client_secret": "test_secret",
|
||||
"token_url": "https://example.com/token"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))])
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
config_path = f.name
|
||||
|
||||
try:
|
||||
with patch.object(OAuth2TokenManager, 'get_access_token', return_value="oauth_token_123"):
|
||||
llm = LLM(
|
||||
model="custom/test-model",
|
||||
oauth2_config_path=config_path
|
||||
)
|
||||
|
||||
llm.call("test message")
|
||||
|
||||
call_args = mock_completion.call_args
|
||||
assert call_args[1]['api_key'] == "oauth_token_123"
|
||||
finally:
|
||||
Path(config_path).unlink()
|
||||
|
||||
@patch('crewai.llm.litellm.completion')
|
||||
def test_llm_oauth2_authentication_failure(self, mock_completion):
|
||||
config_data = {
|
||||
"oauth2_providers": {
|
||||
"custom": {
|
||||
"client_id": "test_client",
|
||||
"client_secret": "test_secret",
|
||||
"token_url": "https://example.com/token"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||
json.dump(config_data, f)
|
||||
config_path = f.name
|
||||
|
||||
try:
|
||||
with patch.object(OAuth2TokenManager, 'get_access_token', side_effect=RuntimeError("Auth failed")):
|
||||
llm = LLM(
|
||||
model="custom/test-model",
|
||||
oauth2_config_path=config_path
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Auth failed"):
|
||||
llm.call("test message")
|
||||
finally:
|
||||
Path(config_path).unlink()
|
||||
|
||||
@patch('crewai.llm.litellm.completion')
|
||||
def test_llm_non_oauth2_provider_unchanged(self, mock_completion):
|
||||
mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))])
|
||||
|
||||
llm = LLM(
|
||||
model="openai/gpt-3.5-turbo",
|
||||
api_key="original_key"
|
||||
)
|
||||
|
||||
llm.call("test message")
|
||||
|
||||
call_args = mock_completion.call_args
|
||||
assert call_args[1]['api_key'] == "original_key"
|
||||
Reference in New Issue
Block a user