mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
- Add custom OAuth2 error classes (OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError) - Implement URL and scope validation in OAuth2Config using pydantic field_validator - Add retry mechanism with exponential backoff (3 attempts, 1s/2s/4s delays) to OAuth2TokenManager - Implement thread safety using threading.Lock for concurrent token access - Add provider validation in LLM integration with _validate_oauth2_provider method - Enhance testing with validation, retry, thread safety, and error class tests - Improve documentation with comprehensive error handling examples and feature descriptions - Maintain backward compatibility while significantly improving robustness and security Co-Authored-By: João <joao@crewai.com>
94 lines
3.8 KiB
Python
94 lines
3.8 KiB
Python
import time
|
|
import logging
|
|
import requests
|
|
from threading import Lock
|
|
from typing import Dict, Any
|
|
from .oauth2_config import OAuth2Config
|
|
from .oauth2_errors import OAuth2AuthenticationError
|
|
|
|
|
|
class OAuth2TokenManager:
    """Cache OAuth2 access tokens per provider and fetch new ones on demand.

    Tokens are obtained with the client-credentials grant and cached under
    ``config.provider_name``. A cached token is reused until
    ``_EXPIRY_LEEWAY_SECONDS`` before its recorded expiry. A provider that
    issues tokens living 60 seconds or less therefore gets a fresh request on
    every call.

    Calls for the same provider are serialized by a per-provider lock. While
    one thread is fetching (including retry back-off sleeps), the others wait
    and then reuse the freshly cached token. A slow or failing provider does
    not hold up callers of other providers.
    """

    # Seconds before ``expires_at`` at which a cached token is no longer reused.
    _EXPIRY_LEEWAY_SECONDS = 60
    # Lifetime assumed when the token response omits ``expires_in`` (or sends null).
    _DEFAULT_EXPIRES_IN = 3600
    # 4xx statuses that still indicate a transient condition (request timeout, rate limit).
    _RETRYABLE_CLIENT_STATUSES = frozenset({408, 429})

    def __init__(self):
        # provider_name -> {"access_token", "expires_at", "token_type"}
        self._tokens: Dict[str, Dict[str, Any]] = {}
        # Guards ``_provider_locks`` only; it is never held across network I/O.
        self._lock = Lock()
        # provider_name -> lock serializing token lookups/fetches for that provider.
        self._provider_locks: Dict[str, Lock] = {}

    def get_access_token(self, config: OAuth2Config) -> str:
        """Get valid access token for the provider, refreshing if necessary.

        Args:
            config: Provider configuration; ``provider_name`` is the cache key.

        Returns:
            A cached token that is not within the expiry leeway, or a newly
            acquired one.

        Raises:
            OAuth2AuthenticationError: if a new token is needed and cannot be
                obtained.
        """
        with self._provider_lock(config.provider_name):
            return self._get_access_token_internal(config)

    def _provider_lock(self, provider_name: str) -> Lock:
        """Return the lock dedicated to ``provider_name``, creating it on first use."""
        with self._lock:
            lock = self._provider_locks.get(provider_name)
            if lock is None:
                lock = Lock()
                self._provider_locks[provider_name] = lock
            return lock

    def _get_access_token_internal(self, config: OAuth2Config) -> str:
        """Return the cached token if still valid, otherwise acquire a new one.

        Must be called while holding ``self._provider_lock(config.provider_name)``.
        """
        provider_name = config.provider_name
        token_data = self._tokens.get(provider_name)
        if token_data is not None and self._is_token_valid(token_data):
            logging.debug("Using cached OAuth2 token for provider %s", provider_name)
            return token_data["access_token"]

        logging.info("Acquiring new OAuth2 token for provider %s", provider_name)
        return self._acquire_new_token(config)

    def _is_token_valid(self, token_data: Dict[str, Any]) -> bool:
        """Return True if the token's expiry is more than the leeway in the future.

        Entries without an ``expires_at`` value are treated as expired.
        """
        expires_at = token_data.get("expires_at")
        if expires_at is None:
            return False
        return time.time() < expires_at - self._EXPIRY_LEEWAY_SECONDS

    def _acquire_new_token(self, config: OAuth2Config, retry_count: int = 3) -> str:
        """Acquire and cache a new token, retrying transient failures with back-off.

        Makes up to ``retry_count`` attempts and sleeps 1s, 2s, 4s, ... between
        them. Connection errors, timeouts, HTTP 408/429 and 5xx responses are
        retried. Other HTTP errors (e.g. 400/401 for bad client credentials)
        fail immediately, because repeating the identical request cannot
        succeed.

        Raises:
            OAuth2AuthenticationError: on a non-retryable error, when every
                attempt fails, on a malformed token response, or when
                ``retry_count`` < 1 (no attempt is made).
        """
        for attempt in range(retry_count):
            try:
                return self._perform_token_request(config)
            except requests.RequestException as e:
                if not self._is_retryable_error(e):
                    raise OAuth2AuthenticationError(
                        f"Failed to acquire OAuth2 token for {config.provider_name} (non-retryable error): {e}",
                        original_error=e,
                    ) from e
                if attempt == retry_count - 1:
                    raise OAuth2AuthenticationError(
                        f"Failed to acquire OAuth2 token for {config.provider_name} after {retry_count} attempts: {e}",
                        original_error=e,
                    ) from e

                wait_time = 2 ** attempt
                logging.warning(
                    "OAuth2 token request failed for %s, retrying in %ss (attempt %s/%s): %s",
                    config.provider_name,
                    wait_time,
                    attempt + 1,
                    retry_count,
                    e,
                )
                # Sleeping while holding this provider's lock is deliberate:
                # concurrent callers for the same provider wait for this retry
                # instead of firing their own token requests.
                time.sleep(wait_time)

        # Only reachable when retry_count < 1, i.e. the loop made no attempt.
        raise OAuth2AuthenticationError(f"Unexpected error in token acquisition for {config.provider_name}")

    @classmethod
    def _is_retryable_error(cls, error: Exception) -> bool:
        """Return True if ``error`` may be transient and the request worth repeating."""
        response = getattr(error, "response", None)
        status = getattr(response, "status_code", None)
        if not isinstance(status, int):
            # No usable HTTP status (connection error, timeout, ...): treat as transient.
            return True
        return status >= 500 or status in cls._RETRYABLE_CLIENT_STATUSES

    def _perform_token_request(self, config: OAuth2Config) -> str:
        """POST a client-credentials grant to ``config.token_url`` and cache the result.

        Returns:
            The new access token, which is also stored in ``self._tokens``.

        Raises:
            requests.RequestException: on transport errors or 4xx/5xx responses.
                These are left for ``_acquire_new_token`` to classify and retry.
            OAuth2AuthenticationError: if the body is not a JSON object with a
                non-empty ``access_token`` and a numeric ``expires_in``.
        """
        payload = {
            "grant_type": "client_credentials",
            "client_id": config.client_id,
            "client_secret": config.client_secret,
        }
        if config.scope:
            payload["scope"] = config.scope

        response = requests.post(
            config.token_url,
            data=payload,
            timeout=30,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        response.raise_for_status()

        provider_name = config.provider_name
        try:
            token_data = response.json()
        except ValueError as e:
            # Also catches requests' JSONDecodeError, which subclasses ValueError.
            raise OAuth2AuthenticationError(
                f"Invalid token response from {provider_name}: body is not valid JSON",
                original_error=e,
            ) from e

        # A non-object body, or a null/empty token, must not be cached and returned.
        if not isinstance(token_data, dict) or not token_data.get("access_token"):
            raise OAuth2AuthenticationError(
                f"Invalid token response from {provider_name}: missing 'access_token'"
            )
        access_token = token_data["access_token"]

        # Some providers send expires_in as a string or null; normalize to seconds.
        raw_expires_in = token_data.get("expires_in")
        try:
            expires_in = float(
                self._DEFAULT_EXPIRES_IN if raw_expires_in is None else raw_expires_in
            )
        except (TypeError, ValueError) as e:
            raise OAuth2AuthenticationError(
                f"Invalid token response from {provider_name}: invalid expires_in {raw_expires_in!r}",
                original_error=e,
            ) from e

        self._tokens[provider_name] = {
            "access_token": access_token,
            "expires_at": time.time() + expires_in,
            "token_type": token_data.get("token_type", "Bearer"),
        }

        logging.info(
            "Successfully acquired OAuth2 token for %s, expires in %.0fs",
            provider_name,
            expires_in,
        )
        return access_token
|