Files
crewAI/src/crewai/llms/oauth2_token_manager.py
Devin AI 4379ad26d1 Implement comprehensive OAuth2 improvements based on code review feedback
- Add custom OAuth2 error classes (OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError)
- Implement URL and scope validation in OAuth2Config using pydantic field_validator
- Add retry mechanism with exponential backoff (3 attempts, with 1s and 2s delays between them) to OAuth2TokenManager
- Implement thread safety using threading.Lock for concurrent token access
- Add provider validation in LLM integration with _validate_oauth2_provider method
- Enhance testing with validation, retry, thread safety, and error class tests
- Improve documentation with comprehensive error handling examples and feature descriptions
- Maintain backward compatibility while significantly improving robustness and security

Co-Authored-By: João <joao@crewai.com>
2025-07-07 18:10:30 +00:00

94 lines
3.8 KiB
Python

import time
import logging
import requests
from threading import Lock
from typing import Dict, Any
from .oauth2_config import OAuth2Config
from .oauth2_errors import OAuth2AuthenticationError
class OAuth2TokenManager:
    """Acquire and cache OAuth2 client-credentials access tokens.

    Tokens are cached in memory, keyed by ``config.provider_name``, and are
    reused until shortly before their recorded expiry. One lock guards the
    whole lookup/acquire sequence, so it stays held while a token request
    (including its retry sleeps) is in progress.
    """

    # Lifetime (seconds) assumed when the response has no usable ``expires_in``.
    _DEFAULT_EXPIRES_IN = 3600
    # A cached token is considered stale this many seconds before it expires,
    # so callers never receive a token that is about to lapse mid-request.
    _EXPIRY_MARGIN_SECONDS = 60

    def __init__(self):
        # provider_name -> {"access_token", "expires_at", "token_type"}
        self._tokens: Dict[str, Dict[str, Any]] = {}
        self._lock = Lock()

    def get_access_token(self, config: OAuth2Config) -> str:
        """Return a valid access token for ``config``, fetching one if needed.

        Raises:
            OAuth2AuthenticationError: if the token endpoint cannot be reached
                after all retries or returns an unusable response.
        """
        with self._lock:
            return self._get_access_token_internal(config)

    def _get_access_token_internal(self, config: OAuth2Config) -> str:
        """Return the cached token or acquire a new one (caller holds the lock)."""
        provider_name = config.provider_name

        token_data = self._tokens.get(provider_name)
        if token_data is not None and self._is_token_valid(token_data):
            logging.debug(f"Using cached OAuth2 token for provider {provider_name}")
            return token_data["access_token"]

        logging.info(f"Acquiring new OAuth2 token for provider {provider_name}")
        return self._acquire_new_token(config)

    def _is_token_valid(self, token_data: Dict[str, Any]) -> bool:
        """Return True if the cached entry is not within the expiry margin."""
        if "expires_at" not in token_data:
            return False
        return time.time() < (token_data["expires_at"] - self._EXPIRY_MARGIN_SECONDS)

    def _acquire_new_token(self, config: OAuth2Config, retry_count: int = 3) -> str:
        """Request a new token, retrying transport failures with backoff.

        Every ``requests.RequestException`` is retried. The wait before the
        next attempt is ``2 ** attempt`` seconds (1s, 2s, ...); there is no
        wait after the final attempt.

        Raises:
            OAuth2AuthenticationError: when the last attempt fails, or when the
                response payload is unusable.
        """
        for attempt in range(retry_count):
            try:
                return self._perform_token_request(config)
            except requests.RequestException as e:
                if attempt == retry_count - 1:
                    raise OAuth2AuthenticationError(
                        f"Failed to acquire OAuth2 token for {config.provider_name} after {retry_count} attempts: {e}",
                        original_error=e
                    ) from e
                wait_time = 2 ** attempt
                logging.warning(f"OAuth2 token request failed for {config.provider_name}, retrying in {wait_time}s (attempt {attempt + 1}/{retry_count}): {e}")
                time.sleep(wait_time)

        # Only reachable when retry_count < 1, i.e. no attempt was made at all.
        raise OAuth2AuthenticationError(f"Unexpected error in token acquisition for {config.provider_name}")

    def _perform_token_request(self, config: OAuth2Config) -> str:
        """POST a client-credentials grant to the token URL and cache the result.

        Raises:
            requests.RequestException: on transport or HTTP-status failures
                (left for the caller's retry loop to handle).
            OAuth2AuthenticationError: if the response body is not usable.
        """
        payload = {
            "grant_type": "client_credentials",
            "client_id": config.client_id,
            "client_secret": config.client_secret,
        }
        if config.scope:
            payload["scope"] = config.scope

        response = requests.post(
            config.token_url,
            data=payload,
            timeout=30,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        response.raise_for_status()

        try:
            token_data = response.json()
        except requests.RequestException:
            # Newer requests releases raise a JSONDecodeError that is also a
            # RequestException; let it propagate so it is retried as before.
            raise
        except ValueError as e:
            # Older requests releases raise a plain ValueError for a non-JSON
            # body; surface it as a domain error instead of leaking it.
            raise OAuth2AuthenticationError(
                f"Invalid token response from {config.provider_name}: body is not valid JSON"
            ) from e

        return self._store_token(config.provider_name, token_data)

    def _store_token(self, provider_name: str, token_data: Any) -> str:
        """Validate a decoded token response, cache it, and return the token."""
        if not isinstance(token_data, dict):
            raise OAuth2AuthenticationError(
                f"Invalid token response from {provider_name}: expected a JSON object"
            )

        try:
            access_token = token_data["access_token"]
        except KeyError as e:
            raise OAuth2AuthenticationError(f"Invalid token response from {provider_name}: missing {e}") from e

        expires_in = self._coerce_expires_in(token_data.get("expires_in"), provider_name)

        self._tokens[provider_name] = {
            "access_token": access_token,
            "expires_at": time.time() + expires_in,
            "token_type": token_data.get("token_type", "Bearer")
        }
        logging.info(f"Successfully acquired OAuth2 token for {provider_name}, expires in {expires_in}s")
        return access_token

    def _coerce_expires_in(self, raw: Any, provider_name: str) -> float:
        """Turn the response's ``expires_in`` into a number of seconds.

        A missing or null value falls back to the default lifetime; numeric
        strings are converted, since the value is later added to a timestamp.
        """
        if raw is None:
            return self._DEFAULT_EXPIRES_IN
        if isinstance(raw, (int, float)):
            return raw
        try:
            return float(raw)
        except (TypeError, ValueError) as e:
            raise OAuth2AuthenticationError(
                f"Invalid token response from {provider_name}: invalid expires_in {raw!r}"
            ) from e