mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
Implement comprehensive OAuth2 improvements based on code review feedback
- Add custom OAuth2 error classes (OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError) - Implement URL and scope validation in OAuth2Config using pydantic field_validator - Add retry mechanism with exponential backoff (3 attempts, 1s/2s/4s delays) to OAuth2TokenManager - Implement thread safety using threading.Lock for concurrent token access - Add provider validation in LLM integration with _validate_oauth2_provider method - Enhance testing with validation, retry, thread safety, and error class tests - Improve documentation with comprehensive error handling examples and feature descriptions - Maintain backward compatibility while significantly improving robustness and security Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -2,6 +2,74 @@
|
||||
|
||||
CrewAI supports OAuth2 authentication for custom LiteLLM providers through configuration files.
|
||||
|
||||
## Features
|
||||
|
||||
- **Automatic Token Management**: Handles OAuth2 token acquisition and refresh automatically
|
||||
- **Multiple Provider Support**: Configure multiple OAuth2 providers in a single configuration file
|
||||
- **Secure Credential Storage**: Keep OAuth2 credentials separate from your code
|
||||
- **Seamless Integration**: Works with existing LiteLLM provider configurations
|
||||
- **Error Handling**: Comprehensive error handling for authentication failures
|
||||
- **Retry Mechanism**: Automatic retry with exponential backoff for token acquisition
|
||||
- **Thread Safety**: Concurrent access protection for token caching
|
||||
- **Configuration Validation**: Automatic validation of URLs and scope formats
|
||||
|
||||
## Error Handling
|
||||
|
||||
The OAuth2 implementation provides specific error classes for different failure scenarios:
|
||||
|
||||
### OAuth2ConfigurationError
|
||||
Raised when there are issues with the OAuth2 configuration:
|
||||
- Invalid configuration file format
|
||||
- Missing required fields
|
||||
- Unknown OAuth2 providers
|
||||
|
||||
### OAuth2AuthenticationError
|
||||
Raised when OAuth2 authentication fails:
|
||||
- Network errors during token acquisition
|
||||
- Invalid credentials
|
||||
- Token endpoint errors
|
||||
|
||||
### OAuth2ValidationError
|
||||
Raised when configuration validation fails:
|
||||
- Invalid URL formats
|
||||
- Malformed scope values
|
||||
|
||||
### Example Error Handling
|
||||
|
||||
```python
|
||||
from crewai import LLM
|
||||
from crewai.llms.oauth2_errors import OAuth2ConfigurationError, OAuth2AuthenticationError
|
||||
|
||||
try:
|
||||
llm = LLM(
|
||||
model="my_custom_provider/my-model",
|
||||
oauth2_config_path="./litellm_config.json"
|
||||
)
|
||||
result = llm.call("test message")
|
||||
except OAuth2ConfigurationError as e:
|
||||
print(f"Configuration error: {e}")
|
||||
except OAuth2AuthenticationError as e:
|
||||
print(f"Authentication failed: {e}")
|
||||
if e.original_error:
|
||||
print(f"Original error: {e.original_error}")
|
||||
```
|
||||
|
||||
## Configuration Validation
|
||||
|
||||
The OAuth2 configuration is automatically validated:
|
||||
|
||||
- **URL Validation**: `token_url` must be a valid HTTP/HTTPS URL
|
||||
- **Scope Validation**: `scope` cannot contain empty values when split by spaces
|
||||
- **Required Fields**: `client_id`, `client_secret`, `token_url`, and `provider_name` are required
|
||||
|
||||
## Retry Mechanism
|
||||
|
||||
Token acquisition includes automatic retry with exponential backoff:
|
||||
- 3 retry attempts by default
|
||||
- Exponential backoff: 1s, 2s, 4s delays
|
||||
- Detailed logging of retry attempts
|
||||
- Thread-safe token caching
|
||||
|
||||
## Configuration
|
||||
|
||||
Create a `litellm_config.json` file in your project directory:
|
||||
@@ -98,11 +166,12 @@ The `litellm_config.json` file should follow this schema:
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
## Legacy Error Handling
|
||||
|
||||
- If OAuth2 authentication fails, a `RuntimeError` will be raised with details
|
||||
- Invalid configuration files will raise a `ValueError` with specifics
|
||||
- Network errors during token acquisition are wrapped in `RuntimeError`
|
||||
For backward compatibility, the following error types are still supported:
|
||||
- OAuth2 authentication failures raise `OAuth2AuthenticationError` (previously `RuntimeError`)
|
||||
- Invalid configuration files raise `OAuth2ConfigurationError` (previously `ValueError`)
|
||||
- Network errors during token acquisition are wrapped in `OAuth2AuthenticationError`
|
||||
|
||||
## Examples
|
||||
|
||||
|
||||
@@ -54,6 +54,7 @@ from typing import TextIO
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.llms.oauth2_config import OAuth2ConfigLoader
|
||||
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
|
||||
from crewai.llms.oauth2_errors import OAuth2ConfigurationError
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
@@ -397,9 +398,11 @@ class LLM(BaseLLM):
|
||||
if provider and provider in self.oauth2_configs:
|
||||
oauth2_config = self.oauth2_configs[provider]
|
||||
try:
|
||||
self._validate_oauth2_provider(provider)
|
||||
access_token = self.oauth2_token_manager.get_access_token(oauth2_config)
|
||||
api_key = access_token
|
||||
except RuntimeError as e:
|
||||
logging.debug(f"Using OAuth2 authentication for provider {provider}")
|
||||
except Exception as e:
|
||||
logging.error(f"OAuth2 authentication failed for provider {provider}: {e}")
|
||||
raise
|
||||
|
||||
@@ -1095,6 +1098,12 @@ class LLM(BaseLLM):
|
||||
return self.model.split("/")[0]
|
||||
return None
|
||||
|
||||
def _validate_oauth2_provider(self, provider: str) -> bool:
|
||||
"""Validate that OAuth2 provider exists in configuration"""
|
||||
if provider not in self.oauth2_configs:
|
||||
raise OAuth2ConfigurationError(f"Unknown OAuth2 provider: {provider}")
|
||||
return True
|
||||
|
||||
def _validate_call_params(self) -> None:
|
||||
"""
|
||||
Validate parameters before making a call. Currently this only checks if
|
||||
|
||||
@@ -3,10 +3,15 @@
|
||||
from .base_llm import BaseLLM
|
||||
from .oauth2_config import OAuth2Config, OAuth2ConfigLoader
|
||||
from .oauth2_token_manager import OAuth2TokenManager
|
||||
from .oauth2_errors import OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError
|
||||
|
||||
__all__ = [
|
||||
"BaseLLM",
|
||||
"OAuth2Config",
|
||||
"OAuth2ConfigLoader",
|
||||
"OAuth2TokenManager"
|
||||
"OAuth2TokenManager",
|
||||
"OAuth2Error",
|
||||
"OAuth2ConfigurationError",
|
||||
"OAuth2AuthenticationError",
|
||||
"OAuth2ValidationError"
|
||||
]
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
import re
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from .oauth2_errors import OAuth2ConfigurationError, OAuth2ValidationError
|
||||
|
||||
|
||||
class OAuth2Config(BaseModel):
|
||||
@@ -12,6 +14,31 @@ class OAuth2Config(BaseModel):
|
||||
provider_name: str = Field(description="Custom provider name")
|
||||
refresh_token: Optional[str] = Field(default=None, description="OAuth2 refresh token")
|
||||
|
||||
@field_validator('token_url')
|
||||
@classmethod
|
||||
def validate_token_url(cls, v: str) -> str:
|
||||
"""Validate that token_url is a valid HTTP/HTTPS URL."""
|
||||
url_pattern = re.compile(
|
||||
r'^https?://' # http:// or https://
|
||||
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|' # domain...
|
||||
r'localhost|' # localhost...
|
||||
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
|
||||
r'(?::\d+)?' # optional port
|
||||
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
|
||||
|
||||
if not url_pattern.match(v):
|
||||
raise OAuth2ValidationError(f"Invalid token URL format: {v}")
|
||||
return v
|
||||
|
||||
@field_validator('scope')
|
||||
@classmethod
|
||||
def validate_scope(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate OAuth2 scope format."""
|
||||
if v:
|
||||
if ' ' in v:
|
||||
raise OAuth2ValidationError("Invalid scope format: scope cannot contain empty values")
|
||||
return v
|
||||
|
||||
|
||||
class OAuth2ConfigLoader:
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
@@ -35,4 +62,4 @@ class OAuth2ConfigLoader:
|
||||
|
||||
return oauth2_configs
|
||||
except (json.JSONDecodeError, KeyError, ValueError) as e:
|
||||
raise ValueError(f"Invalid OAuth2 configuration in {self.config_path}: {e}")
|
||||
raise OAuth2ConfigurationError(f"Invalid OAuth2 configuration in {self.config_path}: {e}")
|
||||
|
||||
26
src/crewai/llms/oauth2_errors.py
Normal file
26
src/crewai/llms/oauth2_errors.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""OAuth2 error classes for CrewAI."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class OAuth2Error(Exception):
|
||||
"""Base exception class for OAuth2 operation errors."""
|
||||
|
||||
def __init__(self, message: str, original_error: Optional[Exception] = None):
|
||||
super().__init__(message)
|
||||
self.original_error = original_error
|
||||
|
||||
|
||||
class OAuth2ConfigurationError(OAuth2Error):
|
||||
"""Exception raised for OAuth2 configuration errors."""
|
||||
pass
|
||||
|
||||
|
||||
class OAuth2AuthenticationError(OAuth2Error):
|
||||
"""Exception raised for OAuth2 authentication failures."""
|
||||
pass
|
||||
|
||||
|
||||
class OAuth2ValidationError(OAuth2Error):
|
||||
"""Exception raised for OAuth2 validation errors."""
|
||||
pass
|
||||
@@ -1,22 +1,33 @@
|
||||
import time
|
||||
import logging
|
||||
import requests
|
||||
from threading import Lock
|
||||
from typing import Dict, Any
|
||||
from .oauth2_config import OAuth2Config
|
||||
from .oauth2_errors import OAuth2AuthenticationError
|
||||
|
||||
|
||||
class OAuth2TokenManager:
|
||||
def __init__(self):
|
||||
self._tokens: Dict[str, Dict[str, any]] = {}
|
||||
self._tokens: Dict[str, Dict[str, Any]] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def get_access_token(self, config: OAuth2Config) -> str:
|
||||
"""Get valid access token for the provider, refreshing if necessary"""
|
||||
with self._lock:
|
||||
return self._get_access_token_internal(config)
|
||||
|
||||
def _get_access_token_internal(self, config: OAuth2Config) -> str:
|
||||
"""Internal method to get access token (called within lock)"""
|
||||
provider_name = config.provider_name
|
||||
|
||||
if provider_name in self._tokens:
|
||||
token_data = self._tokens[provider_name]
|
||||
if self._is_token_valid(token_data):
|
||||
logging.debug(f"Using cached OAuth2 token for provider {provider_name}")
|
||||
return token_data["access_token"]
|
||||
|
||||
logging.info(f"Acquiring new OAuth2 token for provider {provider_name}")
|
||||
return self._acquire_new_token(config)
|
||||
|
||||
def _is_token_valid(self, token_data: Dict[str, Any]) -> bool:
|
||||
@@ -26,8 +37,25 @@ class OAuth2TokenManager:
|
||||
|
||||
return time.time() < (token_data["expires_at"] - 60)
|
||||
|
||||
def _acquire_new_token(self, config: OAuth2Config) -> str:
|
||||
"""Acquire new access token using client credentials flow"""
|
||||
def _acquire_new_token(self, config: OAuth2Config, retry_count: int = 3) -> str:
|
||||
"""Acquire new access token using client credentials flow with retry logic"""
|
||||
for attempt in range(retry_count):
|
||||
try:
|
||||
return self._perform_token_request(config)
|
||||
except requests.RequestException as e:
|
||||
if attempt == retry_count - 1:
|
||||
raise OAuth2AuthenticationError(
|
||||
f"Failed to acquire OAuth2 token for {config.provider_name} after {retry_count} attempts: {e}",
|
||||
original_error=e
|
||||
)
|
||||
wait_time = 2 ** attempt
|
||||
logging.warning(f"OAuth2 token request failed for {config.provider_name}, retrying in {wait_time}s (attempt {attempt + 1}/{retry_count}): {e}")
|
||||
time.sleep(wait_time)
|
||||
|
||||
raise OAuth2AuthenticationError(f"Unexpected error in token acquisition for {config.provider_name}")
|
||||
|
||||
def _perform_token_request(self, config: OAuth2Config) -> str:
|
||||
"""Perform the actual token request"""
|
||||
payload = {
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": config.client_id,
|
||||
@@ -56,9 +84,10 @@ class OAuth2TokenManager:
|
||||
"token_type": token_data.get("token_type", "Bearer")
|
||||
}
|
||||
|
||||
logging.info(f"Successfully acquired OAuth2 token for {config.provider_name}, expires in {expires_in}s")
|
||||
return access_token
|
||||
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"Failed to acquire OAuth2 token for {config.provider_name}: {e}")
|
||||
except requests.RequestException:
|
||||
raise
|
||||
except KeyError as e:
|
||||
raise RuntimeError(f"Invalid token response from {config.provider_name}: missing {e}")
|
||||
raise OAuth2AuthenticationError(f"Invalid token response from {config.provider_name}: missing {e}")
|
||||
|
||||
@@ -3,11 +3,13 @@ import pytest
|
||||
import requests
|
||||
import tempfile
|
||||
import time
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.oauth2_config import OAuth2Config, OAuth2ConfigLoader
|
||||
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
|
||||
from crewai.llms.oauth2_errors import OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError
|
||||
|
||||
|
||||
class TestOAuth2Config:
|
||||
@@ -62,7 +64,7 @@ class TestOAuth2Config:
|
||||
|
||||
try:
|
||||
loader = OAuth2ConfigLoader(config_path)
|
||||
with pytest.raises(ValueError, match="Invalid OAuth2 configuration"):
|
||||
with pytest.raises(OAuth2ConfigurationError, match="Invalid OAuth2 configuration"):
|
||||
loader.load_config()
|
||||
finally:
|
||||
Path(config_path).unlink()
|
||||
@@ -155,10 +157,11 @@ class TestOAuth2TokenManager:
|
||||
)
|
||||
|
||||
with patch('requests.post', side_effect=requests.RequestException("Network error")):
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to acquire OAuth2 token"):
|
||||
manager.get_access_token(config)
|
||||
with patch('time.sleep'):
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token"):
|
||||
manager.get_access_token(config)
|
||||
|
||||
def test_invalid_token_response(self):
|
||||
config = OAuth2Config(
|
||||
@@ -175,7 +178,7 @@ class TestOAuth2TokenManager:
|
||||
with patch('requests.post', return_value=mock_response):
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Invalid token response"):
|
||||
with pytest.raises(OAuth2AuthenticationError, match="Invalid token response"):
|
||||
manager.get_access_token(config)
|
||||
|
||||
|
||||
@@ -259,13 +262,13 @@ class TestLLMOAuth2Integration:
|
||||
config_path = f.name
|
||||
|
||||
try:
|
||||
with patch.object(OAuth2TokenManager, 'get_access_token', side_effect=RuntimeError("Auth failed")):
|
||||
with patch.object(OAuth2TokenManager, 'get_access_token', side_effect=OAuth2AuthenticationError("Auth failed")):
|
||||
llm = LLM(
|
||||
model="custom/test-model",
|
||||
oauth2_config_path=config_path
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Auth failed"):
|
||||
with pytest.raises(OAuth2AuthenticationError, match="Auth failed"):
|
||||
llm.call("test message")
|
||||
finally:
|
||||
Path(config_path).unlink()
|
||||
@@ -283,3 +286,135 @@ class TestLLMOAuth2Integration:
|
||||
|
||||
call_args = mock_completion.call_args
|
||||
assert call_args[1]['api_key'] == "original_key"
|
||||
|
||||
|
||||
class TestOAuth2Validation:
|
||||
def test_valid_url_validation(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
assert config.token_url == "https://example.com/token"
|
||||
|
||||
def test_invalid_url_validation(self):
|
||||
with pytest.raises(OAuth2ValidationError, match="Invalid token URL format"):
|
||||
OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="not-a-url",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
|
||||
def test_valid_scope_validation(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider",
|
||||
scope="read write"
|
||||
)
|
||||
assert config.scope == "read write"
|
||||
|
||||
def test_invalid_scope_validation(self):
|
||||
with pytest.raises(OAuth2ValidationError, match="Invalid scope format"):
|
||||
OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider",
|
||||
scope="read write"
|
||||
)
|
||||
|
||||
|
||||
class TestOAuth2RetryMechanism:
|
||||
def test_retry_mechanism_success_on_second_attempt(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
|
||||
responses = [
|
||||
requests.RequestException("Network error"),
|
||||
Mock(json=lambda: {"access_token": "token_123", "expires_in": 3600}, raise_for_status=lambda: None)
|
||||
]
|
||||
|
||||
with patch('requests.post', side_effect=responses):
|
||||
with patch('time.sleep'):
|
||||
manager = OAuth2TokenManager()
|
||||
token = manager.get_access_token(config)
|
||||
assert token == "token_123"
|
||||
|
||||
def test_retry_mechanism_exhausted(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
|
||||
with patch('requests.post', side_effect=requests.RequestException("Network error")):
|
||||
with patch('time.sleep'):
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token.*after 3 attempts"):
|
||||
manager.get_access_token(config)
|
||||
|
||||
|
||||
class TestOAuth2ThreadSafety:
|
||||
def test_concurrent_token_requests(self):
|
||||
config = OAuth2Config(
|
||||
client_id="test_client",
|
||||
client_secret="test_secret",
|
||||
token_url="https://example.com/token",
|
||||
provider_name="test_provider"
|
||||
)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {"access_token": "token_123", "expires_in": 3600}
|
||||
mock_response.raise_for_status.return_value = None
|
||||
|
||||
call_count = 0
|
||||
def mock_post(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
time.sleep(0.1)
|
||||
return mock_response
|
||||
|
||||
with patch('requests.post', side_effect=mock_post):
|
||||
manager = OAuth2TokenManager()
|
||||
|
||||
results = []
|
||||
def get_token():
|
||||
token = manager.get_access_token(config)
|
||||
results.append(token)
|
||||
|
||||
threads = [threading.Thread(target=get_token) for _ in range(5)]
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
assert all(token == "token_123" for token in results)
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
class TestOAuth2ErrorClasses:
|
||||
def test_oauth2_configuration_error(self):
|
||||
error = OAuth2ConfigurationError("Config error")
|
||||
assert str(error) == "Config error"
|
||||
assert isinstance(error, OAuth2Error)
|
||||
|
||||
def test_oauth2_authentication_error_with_original(self):
|
||||
original = requests.RequestException("Network error")
|
||||
error = OAuth2AuthenticationError("Auth failed", original_error=original)
|
||||
assert str(error) == "Auth failed"
|
||||
assert error.original_error == original
|
||||
|
||||
def test_oauth2_validation_error(self):
|
||||
error = OAuth2ValidationError("Validation failed")
|
||||
assert str(error) == "Validation failed"
|
||||
assert isinstance(error, OAuth2Error)
|
||||
|
||||
Reference in New Issue
Block a user