mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
Implement OAuth2 authentication support for custom LiteLLM providers
- Add OAuth2Config and OAuth2ConfigLoader for litellm_config.json configuration - Add OAuth2TokenManager for token acquisition, caching, and refresh - Extend LLM class to support OAuth2 authentication with custom providers - Add comprehensive tests covering OAuth2 flow and error handling - Add documentation and usage examples - Support Client Credentials OAuth2 flow for server-to-server authentication - Maintain backward compatibility with existing LLM providers Fixes #3114 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
147
docs/oauth2_llm_providers.md
Normal file
147
docs/oauth2_llm_providers.md
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
# OAuth2 LLM Providers
|
||||||
|
|
||||||
|
CrewAI supports OAuth2 authentication for custom LiteLLM providers through configuration files.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Create a `litellm_config.json` file in your project directory:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"oauth2_providers": {
|
||||||
|
"my_custom_provider": {
|
||||||
|
"client_id": "your_client_id",
|
||||||
|
"client_secret": "your_client_secret",
|
||||||
|
"token_url": "https://your-provider.com/oauth/token",
|
||||||
|
"scope": "llm.read llm.write"
|
||||||
|
},
|
||||||
|
"another_provider": {
|
||||||
|
"client_id": "another_client_id",
|
||||||
|
"client_secret": "another_client_secret",
|
||||||
|
"token_url": "https://another-provider.com/token"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import LLM
|
||||||
|
|
||||||
|
# Initialize LLM with OAuth2 support
|
||||||
|
llm = LLM(
|
||||||
|
model="my_custom_provider/my-model",
|
||||||
|
oauth2_config_path="./litellm_config.json" # Optional, defaults to ./litellm_config.json
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use in CrewAI
|
||||||
|
from crewai import Agent, Task, Crew
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Data Analyst",
|
||||||
|
goal="Analyze data trends",
|
||||||
|
backstory="Expert in data analysis",
|
||||||
|
llm=llm
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Analyze the latest sales data",
|
||||||
|
agent=agent
|
||||||
|
)
|
||||||
|
|
||||||
|
crew = Crew(agents=[agent], tasks=[task])
|
||||||
|
result = crew.kickoff()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
You can also use environment variables in your configuration:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"oauth2_providers": {
|
||||||
|
"my_provider": {
|
||||||
|
"client_id": "os.environ/MY_CLIENT_ID",
|
||||||
|
"client_secret": "os.environ/MY_CLIENT_SECRET",
|
||||||
|
"token_url": "https://my-provider.com/token"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported OAuth2 Flow
|
||||||
|
|
||||||
|
Currently supports the **Client Credentials** OAuth2 flow, which is suitable for server-to-server authentication.
|
||||||
|
|
||||||
|
## Token Management
|
||||||
|
|
||||||
|
- Tokens are automatically cached and refreshed when they expire
|
||||||
|
- A 60-second buffer is used before token expiration to ensure reliability
|
||||||
|
- Failed token acquisition will raise a `RuntimeError` with details
|
||||||
|
|
||||||
|
## Configuration Schema
|
||||||
|
|
||||||
|
The `litellm_config.json` file should follow this schema:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"oauth2_providers": {
|
||||||
|
"<provider_name>": {
|
||||||
|
"client_id": "string (required)",
|
||||||
|
"client_secret": "string (required)",
|
||||||
|
"token_url": "string (required)",
|
||||||
|
"scope": "string (optional)",
|
||||||
|
"refresh_token": "string (optional)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
- If OAuth2 authentication fails, a `RuntimeError` will be raised with details
|
||||||
|
- Invalid configuration files will raise a `ValueError` with specifics
|
||||||
|
- Network errors during token acquisition are wrapped in `RuntimeError`
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
### Basic OAuth2 Provider
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import LLM
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model="my_provider/gpt-4",
|
||||||
|
oauth2_config_path="./config.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = llm.call("Hello, world!")
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple Providers
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"oauth2_providers": {
|
||||||
|
"provider_a": {
|
||||||
|
"client_id": "client_a",
|
||||||
|
"client_secret": "secret_a",
|
||||||
|
"token_url": "https://provider-a.com/token"
|
||||||
|
},
|
||||||
|
"provider_b": {
|
||||||
|
"client_id": "client_b",
|
||||||
|
"client_secret": "secret_b",
|
||||||
|
"token_url": "https://provider-b.com/oauth/token",
|
||||||
|
"scope": "read write"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Use different providers
|
||||||
|
llm_a = LLM(model="provider_a/model-1", oauth2_config_path="./config.json")
|
||||||
|
llm_b = LLM(model="provider_b/model-2", oauth2_config_path="./config.json")
|
||||||
|
```
|
||||||
64
examples/oauth2_llm_example.py
Normal file
64
examples/oauth2_llm_example.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""
|
||||||
|
Example demonstrating OAuth2 authentication with custom LLM providers in CrewAI.
|
||||||
|
|
||||||
|
This example shows how to configure and use OAuth2-authenticated LLM providers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from crewai import Agent, Task, Crew, LLM
|
||||||
|
|
||||||
|
def create_example_config():
|
||||||
|
"""Create an example OAuth2 configuration file."""
|
||||||
|
config = {
|
||||||
|
"oauth2_providers": {
|
||||||
|
"my_custom_provider": {
|
||||||
|
"client_id": "your_client_id_here",
|
||||||
|
"client_secret": "your_client_secret_here",
|
||||||
|
"token_url": "https://your-provider.com/oauth/token",
|
||||||
|
"scope": "llm.read llm.write"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config_path = Path("example_oauth2_config.json")
|
||||||
|
with open(config_path, 'w') as f:
|
||||||
|
json.dump(config, f, indent=2)
|
||||||
|
|
||||||
|
print(f"Created example config at {config_path}")
|
||||||
|
return config_path
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config_path = create_example_config()
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm = LLM(
|
||||||
|
model="my_custom_provider/my-model",
|
||||||
|
oauth2_config_path=str(config_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Research Assistant",
|
||||||
|
goal="Provide helpful research insights",
|
||||||
|
backstory="An AI assistant specialized in research and analysis",
|
||||||
|
llm=llm
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Research the benefits of OAuth2 authentication in AI systems",
|
||||||
|
agent=agent,
|
||||||
|
expected_output="A comprehensive summary of OAuth2 benefits"
|
||||||
|
)
|
||||||
|
|
||||||
|
crew = Crew(agents=[agent], tasks=[task])
|
||||||
|
|
||||||
|
print("Running crew with OAuth2-authenticated LLM...")
|
||||||
|
result = crew.kickoff()
|
||||||
|
print(f"Result: {result}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if config_path.exists():
|
||||||
|
config_path.unlink()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -52,6 +52,8 @@ import io
|
|||||||
from typing import TextIO
|
from typing import TextIO
|
||||||
|
|
||||||
from crewai.llms.base_llm import BaseLLM
|
from crewai.llms.base_llm import BaseLLM
|
||||||
|
from crewai.llms.oauth2_config import OAuth2ConfigLoader
|
||||||
|
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
|
||||||
from crewai.utilities.events import crewai_event_bus
|
from crewai.utilities.events import crewai_event_bus
|
||||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||||
LLMContextLengthExceededException,
|
LLMContextLengthExceededException,
|
||||||
@@ -311,6 +313,7 @@ class LLM(BaseLLM):
|
|||||||
callbacks: List[Any] = [],
|
callbacks: List[Any] = [],
|
||||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
oauth2_config_path: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -338,6 +341,10 @@ class LLM(BaseLLM):
|
|||||||
self.is_anthropic = self._is_anthropic_model(model)
|
self.is_anthropic = self._is_anthropic_model(model)
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
|
|
||||||
|
self.oauth2_config_loader = OAuth2ConfigLoader(oauth2_config_path)
|
||||||
|
self.oauth2_token_manager = OAuth2TokenManager()
|
||||||
|
self.oauth2_configs = self.oauth2_config_loader.load_config()
|
||||||
|
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
# Normalize self.stop to always be a List[str]
|
# Normalize self.stop to always be a List[str]
|
||||||
@@ -384,7 +391,19 @@ class LLM(BaseLLM):
|
|||||||
messages = [{"role": "user", "content": messages}]
|
messages = [{"role": "user", "content": messages}]
|
||||||
formatted_messages = self._format_messages_for_provider(messages)
|
formatted_messages = self._format_messages_for_provider(messages)
|
||||||
|
|
||||||
# --- 2) Prepare the parameters for the completion call
|
api_key = self.api_key
|
||||||
|
provider = self._get_custom_llm_provider()
|
||||||
|
|
||||||
|
if provider and provider in self.oauth2_configs:
|
||||||
|
oauth2_config = self.oauth2_configs[provider]
|
||||||
|
try:
|
||||||
|
access_token = self.oauth2_token_manager.get_access_token(oauth2_config)
|
||||||
|
api_key = access_token
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.error(f"OAuth2 authentication failed for provider {provider}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# --- 3) Prepare the parameters for the completion call
|
||||||
params = {
|
params = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": formatted_messages,
|
"messages": formatted_messages,
|
||||||
@@ -404,7 +423,7 @@ class LLM(BaseLLM):
|
|||||||
"api_base": self.api_base,
|
"api_base": self.api_base,
|
||||||
"base_url": self.base_url,
|
"base_url": self.base_url,
|
||||||
"api_version": self.api_version,
|
"api_version": self.api_version,
|
||||||
"api_key": self.api_key,
|
"api_key": api_key,
|
||||||
"stream": self.stream,
|
"stream": self.stream,
|
||||||
"tools": tools,
|
"tools": tools,
|
||||||
"reasoning_effort": self.reasoning_effort,
|
"reasoning_effort": self.reasoning_effort,
|
||||||
|
|||||||
@@ -1 +1,12 @@
|
|||||||
"""LLM implementations for crewAI."""
|
"""LLM implementations for crewAI."""
|
||||||
|
|
||||||
|
from .base_llm import BaseLLM
|
||||||
|
from .oauth2_config import OAuth2Config, OAuth2ConfigLoader
|
||||||
|
from .oauth2_token_manager import OAuth2TokenManager
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseLLM",
|
||||||
|
"OAuth2Config",
|
||||||
|
"OAuth2ConfigLoader",
|
||||||
|
"OAuth2TokenManager"
|
||||||
|
]
|
||||||
|
|||||||
38
src/crewai/llms/oauth2_config.py
Normal file
38
src/crewai/llms/oauth2_config.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
import json
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2Config(BaseModel):
|
||||||
|
client_id: str = Field(description="OAuth2 client ID")
|
||||||
|
client_secret: str = Field(description="OAuth2 client secret")
|
||||||
|
token_url: str = Field(description="OAuth2 token endpoint URL")
|
||||||
|
scope: Optional[str] = Field(default=None, description="OAuth2 scope")
|
||||||
|
provider_name: str = Field(description="Custom provider name")
|
||||||
|
refresh_token: Optional[str] = Field(default=None, description="OAuth2 refresh token")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2ConfigLoader:
|
||||||
|
def __init__(self, config_path: Optional[str] = None):
|
||||||
|
self.config_path = Path(config_path) if config_path else Path("litellm_config.json")
|
||||||
|
|
||||||
|
def load_config(self) -> Dict[str, OAuth2Config]:
|
||||||
|
"""Load OAuth2 configurations from litellm_config.json"""
|
||||||
|
if not self.config_path.exists():
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(self.config_path, 'r') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
oauth2_configs = {}
|
||||||
|
for provider_name, config_data in data.get("oauth2_providers", {}).items():
|
||||||
|
oauth2_configs[provider_name] = OAuth2Config(
|
||||||
|
provider_name=provider_name,
|
||||||
|
**config_data
|
||||||
|
)
|
||||||
|
|
||||||
|
return oauth2_configs
|
||||||
|
except (json.JSONDecodeError, KeyError, ValueError) as e:
|
||||||
|
raise ValueError(f"Invalid OAuth2 configuration in {self.config_path}: {e}")
|
||||||
64
src/crewai/llms/oauth2_token_manager.py
Normal file
64
src/crewai/llms/oauth2_token_manager.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
import time
|
||||||
|
import requests
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from .oauth2_config import OAuth2Config
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2TokenManager:
|
||||||
|
def __init__(self):
|
||||||
|
self._tokens: Dict[str, Dict[str, any]] = {}
|
||||||
|
|
||||||
|
def get_access_token(self, config: OAuth2Config) -> str:
|
||||||
|
"""Get valid access token for the provider, refreshing if necessary"""
|
||||||
|
provider_name = config.provider_name
|
||||||
|
|
||||||
|
if provider_name in self._tokens:
|
||||||
|
token_data = self._tokens[provider_name]
|
||||||
|
if self._is_token_valid(token_data):
|
||||||
|
return token_data["access_token"]
|
||||||
|
|
||||||
|
return self._acquire_new_token(config)
|
||||||
|
|
||||||
|
def _is_token_valid(self, token_data: Dict[str, any]) -> bool:
|
||||||
|
"""Check if token is still valid (not expired)"""
|
||||||
|
if "expires_at" not in token_data:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return time.time() < (token_data["expires_at"] - 60)
|
||||||
|
|
||||||
|
def _acquire_new_token(self, config: OAuth2Config) -> str:
|
||||||
|
"""Acquire new access token using client credentials flow"""
|
||||||
|
payload = {
|
||||||
|
"grant_type": "client_credentials",
|
||||||
|
"client_id": config.client_id,
|
||||||
|
"client_secret": config.client_secret,
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.scope:
|
||||||
|
payload["scope"] = config.scope
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
config.token_url,
|
||||||
|
data=payload,
|
||||||
|
timeout=30,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"}
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
token_data = response.json()
|
||||||
|
access_token = token_data["access_token"]
|
||||||
|
|
||||||
|
expires_in = token_data.get("expires_in", 3600)
|
||||||
|
self._tokens[config.provider_name] = {
|
||||||
|
"access_token": access_token,
|
||||||
|
"expires_at": time.time() + expires_in,
|
||||||
|
"token_type": token_data.get("token_type", "Bearer")
|
||||||
|
}
|
||||||
|
|
||||||
|
return access_token
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
raise RuntimeError(f"Failed to acquire OAuth2 token for {config.provider_name}: {e}")
|
||||||
|
except KeyError as e:
|
||||||
|
raise RuntimeError(f"Invalid token response from {config.provider_name}: missing {e}")
|
||||||
285
tests/test_oauth2_llm.py
Normal file
285
tests/test_oauth2_llm.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import Mock, patch, MagicMock
|
||||||
|
from crewai.llm import LLM
|
||||||
|
from crewai.llms.oauth2_config import OAuth2Config, OAuth2ConfigLoader
|
||||||
|
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuth2Config:
|
||||||
|
def test_oauth2_config_creation(self):
|
||||||
|
config = OAuth2Config(
|
||||||
|
client_id="test_client",
|
||||||
|
client_secret="test_secret",
|
||||||
|
token_url="https://example.com/token",
|
||||||
|
provider_name="test_provider",
|
||||||
|
scope="read write"
|
||||||
|
)
|
||||||
|
assert config.client_id == "test_client"
|
||||||
|
assert config.provider_name == "test_provider"
|
||||||
|
assert config.scope == "read write"
|
||||||
|
|
||||||
|
def test_oauth2_config_loader_missing_file(self):
|
||||||
|
loader = OAuth2ConfigLoader("nonexistent.json")
|
||||||
|
configs = loader.load_config()
|
||||||
|
assert configs == {}
|
||||||
|
|
||||||
|
def test_oauth2_config_loader_valid_config(self):
|
||||||
|
config_data = {
|
||||||
|
"oauth2_providers": {
|
||||||
|
"custom_provider": {
|
||||||
|
"client_id": "test_client",
|
||||||
|
"client_secret": "test_secret",
|
||||||
|
"token_url": "https://example.com/token",
|
||||||
|
"scope": "read"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||||
|
json.dump(config_data, f)
|
||||||
|
config_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
loader = OAuth2ConfigLoader(config_path)
|
||||||
|
configs = loader.load_config()
|
||||||
|
|
||||||
|
assert "custom_provider" in configs
|
||||||
|
config = configs["custom_provider"]
|
||||||
|
assert config.client_id == "test_client"
|
||||||
|
assert config.provider_name == "custom_provider"
|
||||||
|
finally:
|
||||||
|
Path(config_path).unlink()
|
||||||
|
|
||||||
|
def test_oauth2_config_loader_invalid_json(self):
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||||
|
f.write("invalid json")
|
||||||
|
config_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
loader = OAuth2ConfigLoader(config_path)
|
||||||
|
with pytest.raises(ValueError, match="Invalid OAuth2 configuration"):
|
||||||
|
loader.load_config()
|
||||||
|
finally:
|
||||||
|
Path(config_path).unlink()
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuth2TokenManager:
|
||||||
|
def test_token_acquisition_success(self):
|
||||||
|
config = OAuth2Config(
|
||||||
|
client_id="test_client",
|
||||||
|
client_secret="test_secret",
|
||||||
|
token_url="https://example.com/token",
|
||||||
|
provider_name="test_provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"access_token": "test_token_123",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
|
||||||
|
with patch('requests.post', return_value=mock_response) as mock_post:
|
||||||
|
manager = OAuth2TokenManager()
|
||||||
|
token = manager.get_access_token(config)
|
||||||
|
|
||||||
|
assert token == "test_token_123"
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
|
||||||
|
def test_token_caching(self):
|
||||||
|
config = OAuth2Config(
|
||||||
|
client_id="test_client",
|
||||||
|
client_secret="test_secret",
|
||||||
|
token_url="https://example.com/token",
|
||||||
|
provider_name="test_provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"access_token": "test_token_123",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
|
||||||
|
with patch('requests.post', return_value=mock_response) as mock_post:
|
||||||
|
manager = OAuth2TokenManager()
|
||||||
|
|
||||||
|
token1 = manager.get_access_token(config)
|
||||||
|
assert token1 == "test_token_123"
|
||||||
|
assert mock_post.call_count == 1
|
||||||
|
|
||||||
|
token2 = manager.get_access_token(config)
|
||||||
|
assert token2 == "test_token_123"
|
||||||
|
assert mock_post.call_count == 1
|
||||||
|
|
||||||
|
def test_token_refresh_on_expiry(self):
|
||||||
|
config = OAuth2Config(
|
||||||
|
client_id="test_client",
|
||||||
|
client_secret="test_secret",
|
||||||
|
token_url="https://example.com/token",
|
||||||
|
provider_name="test_provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"access_token": "new_token_456",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
|
||||||
|
with patch('requests.post', return_value=mock_response):
|
||||||
|
manager = OAuth2TokenManager()
|
||||||
|
|
||||||
|
manager._tokens["test_provider"] = {
|
||||||
|
"access_token": "old_token",
|
||||||
|
"expires_at": time.time() - 100
|
||||||
|
}
|
||||||
|
|
||||||
|
token = manager.get_access_token(config)
|
||||||
|
assert token == "new_token_456"
|
||||||
|
|
||||||
|
def test_token_acquisition_failure(self):
|
||||||
|
config = OAuth2Config(
|
||||||
|
client_id="test_client",
|
||||||
|
client_secret="test_secret",
|
||||||
|
token_url="https://example.com/token",
|
||||||
|
provider_name="test_provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch('requests.post', side_effect=requests.RequestException("Network error")):
|
||||||
|
manager = OAuth2TokenManager()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="Failed to acquire OAuth2 token"):
|
||||||
|
manager.get_access_token(config)
|
||||||
|
|
||||||
|
def test_invalid_token_response(self):
|
||||||
|
config = OAuth2Config(
|
||||||
|
client_id="test_client",
|
||||||
|
client_secret="test_secret",
|
||||||
|
token_url="https://example.com/token",
|
||||||
|
provider_name="test_provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {"invalid": "response"}
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
|
||||||
|
with patch('requests.post', return_value=mock_response):
|
||||||
|
manager = OAuth2TokenManager()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="Invalid token response"):
|
||||||
|
manager.get_access_token(config)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMOAuth2Integration:
|
||||||
|
def test_llm_with_oauth2_config(self):
|
||||||
|
config_data = {
|
||||||
|
"oauth2_providers": {
|
||||||
|
"custom": {
|
||||||
|
"client_id": "test_client",
|
||||||
|
"client_secret": "test_secret",
|
||||||
|
"token_url": "https://example.com/token"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||||
|
json.dump(config_data, f)
|
||||||
|
config_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm = LLM(
|
||||||
|
model="custom/test-model",
|
||||||
|
oauth2_config_path=config_path
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "custom" in llm.oauth2_configs
|
||||||
|
assert llm.oauth2_configs["custom"].client_id == "test_client"
|
||||||
|
finally:
|
||||||
|
Path(config_path).unlink()
|
||||||
|
|
||||||
|
def test_llm_without_oauth2_config(self):
|
||||||
|
llm = LLM(model="openai/gpt-3.5-turbo")
|
||||||
|
assert llm.oauth2_configs == {}
|
||||||
|
|
||||||
|
@patch('crewai.llm.litellm.completion')
|
||||||
|
def test_llm_oauth2_token_injection(self, mock_completion):
|
||||||
|
config_data = {
|
||||||
|
"oauth2_providers": {
|
||||||
|
"custom": {
|
||||||
|
"client_id": "test_client",
|
||||||
|
"client_secret": "test_secret",
|
||||||
|
"token_url": "https://example.com/token"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))])
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||||
|
json.dump(config_data, f)
|
||||||
|
config_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch.object(OAuth2TokenManager, 'get_access_token', return_value="oauth_token_123"):
|
||||||
|
llm = LLM(
|
||||||
|
model="custom/test-model",
|
||||||
|
oauth2_config_path=config_path
|
||||||
|
)
|
||||||
|
|
||||||
|
llm.call("test message")
|
||||||
|
|
||||||
|
call_args = mock_completion.call_args
|
||||||
|
assert call_args[1]['api_key'] == "oauth_token_123"
|
||||||
|
finally:
|
||||||
|
Path(config_path).unlink()
|
||||||
|
|
||||||
|
@patch('crewai.llm.litellm.completion')
|
||||||
|
def test_llm_oauth2_authentication_failure(self, mock_completion):
|
||||||
|
config_data = {
|
||||||
|
"oauth2_providers": {
|
||||||
|
"custom": {
|
||||||
|
"client_id": "test_client",
|
||||||
|
"client_secret": "test_secret",
|
||||||
|
"token_url": "https://example.com/token"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
|
||||||
|
json.dump(config_data, f)
|
||||||
|
config_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch.object(OAuth2TokenManager, 'get_access_token', side_effect=RuntimeError("Auth failed")):
|
||||||
|
llm = LLM(
|
||||||
|
model="custom/test-model",
|
||||||
|
oauth2_config_path=config_path
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="Auth failed"):
|
||||||
|
llm.call("test message")
|
||||||
|
finally:
|
||||||
|
Path(config_path).unlink()
|
||||||
|
|
||||||
|
@patch('crewai.llm.litellm.completion')
|
||||||
|
def test_llm_non_oauth2_provider_unchanged(self, mock_completion):
|
||||||
|
mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))])
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model="openai/gpt-3.5-turbo",
|
||||||
|
api_key="original_key"
|
||||||
|
)
|
||||||
|
|
||||||
|
llm.call("test message")
|
||||||
|
|
||||||
|
call_args = mock_completion.call_args
|
||||||
|
assert call_args[1]['api_key'] == "original_key"
|
||||||
Reference in New Issue
Block a user