diff --git a/docs/oauth2_llm_providers.md b/docs/oauth2_llm_providers.md new file mode 100644 index 000000000..98e34ca7a --- /dev/null +++ b/docs/oauth2_llm_providers.md @@ -0,0 +1,147 @@ +# OAuth2 LLM Providers + +CrewAI supports OAuth2 authentication for custom LiteLLM providers through configuration files. + +## Configuration + +Create a `litellm_config.json` file in your project directory: + +```json +{ + "oauth2_providers": { + "my_custom_provider": { + "client_id": "your_client_id", + "client_secret": "your_client_secret", + "token_url": "https://your-provider.com/oauth/token", + "scope": "llm.read llm.write" + }, + "another_provider": { + "client_id": "another_client_id", + "client_secret": "another_client_secret", + "token_url": "https://another-provider.com/token" + } + } +} +``` + +## Usage + +```python +from crewai import LLM + +# Initialize LLM with OAuth2 support +llm = LLM( + model="my_custom_provider/my-model", + oauth2_config_path="./litellm_config.json" # Optional, defaults to ./litellm_config.json +) + +# Use in CrewAI +from crewai import Agent, Task, Crew + +agent = Agent( + role="Data Analyst", + goal="Analyze data trends", + backstory="Expert in data analysis", + llm=llm +) + +task = Task( + description="Analyze the latest sales data", + agent=agent +) + +crew = Crew(agents=[agent], tasks=[task]) +result = crew.kickoff() +``` + +## Environment Variables + +You can also use environment variables in your configuration: + +```json +{ + "oauth2_providers": { + "my_provider": { + "client_id": "os.environ/MY_CLIENT_ID", + "client_secret": "os.environ/MY_CLIENT_SECRET", + "token_url": "https://my-provider.com/token" + } + } +} +``` + +## Supported OAuth2 Flow + +Currently supports the **Client Credentials** OAuth2 flow, which is suitable for server-to-server authentication. + +## Token Management + +- Tokens are automatically cached and refreshed when they expire +- A 60-second buffer is used before token expiration to ensure reliability +- Failed token acquisition will raise a `RuntimeError` with details + +## Configuration Schema + +The `litellm_config.json` file should follow this schema: + +```json +{ + "oauth2_providers": { + "": { + "client_id": "string (required)", + "client_secret": "string (required)", + "token_url": "string (required)", + "scope": "string (optional)", + "refresh_token": "string (optional)" + } + } +} +``` + +## Error Handling + +- If OAuth2 authentication fails, a `RuntimeError` will be raised with details +- Invalid configuration files will raise a `ValueError` with specifics +- Network errors during token acquisition are wrapped in `RuntimeError` + +## Examples + +### Basic OAuth2 Provider + +```python +from crewai import LLM + +llm = LLM( + model="my_provider/gpt-4", + oauth2_config_path="./config.json" +) + +response = llm.call("Hello, world!") +print(response) +``` + +### Multiple Providers + +```json +{ + "oauth2_providers": { + "provider_a": { + "client_id": "client_a", + "client_secret": "secret_a", + "token_url": "https://provider-a.com/token" + }, + "provider_b": { + "client_id": "client_b", + "client_secret": "secret_b", + "token_url": "https://provider-b.com/oauth/token", + "scope": "read write" + } + } +} +``` + +```python +# Use different providers +llm_a = LLM(model="provider_a/model-1", oauth2_config_path="./config.json") +llm_b = LLM(model="provider_b/model-2", oauth2_config_path="./config.json") +``` diff --git a/examples/oauth2_llm_example.py b/examples/oauth2_llm_example.py new file mode 100644 index 000000000..1e86a6b42 --- /dev/null +++ b/examples/oauth2_llm_example.py @@ -0,0 +1,64 @@ +""" +Example demonstrating OAuth2 authentication with custom LLM providers in CrewAI. + +This example shows how to configure and use OAuth2-authenticated LLM providers. +""" + +import json +from pathlib import Path +from crewai import Agent, Task, Crew, LLM + +def create_example_config(): + """Create an example OAuth2 configuration file.""" + config = { + "oauth2_providers": { + "my_custom_provider": { + "client_id": "your_client_id_here", + "client_secret": "your_client_secret_here", + "token_url": "https://your-provider.com/oauth/token", + "scope": "llm.read llm.write" + } + } + } + + config_path = Path("example_oauth2_config.json") + with open(config_path, 'w') as f: + json.dump(config, f, indent=2) + + print(f"Created example config at {config_path}") + return config_path + +def main(): + config_path = create_example_config() + + try: + llm = LLM( + model="my_custom_provider/my-model", + oauth2_config_path=str(config_path) + ) + + agent = Agent( + role="Research Assistant", + goal="Provide helpful research insights", + backstory="An AI assistant specialized in research and analysis", + llm=llm + ) + + task = Task( + description="Research the benefits of OAuth2 authentication in AI systems", + agent=agent, + expected_output="A comprehensive summary of OAuth2 benefits" + ) + + crew = Crew(agents=[agent], tasks=[task]) + + print("Running crew with OAuth2-authenticated LLM...") + result = crew.kickoff() + print(f"Result: {result}") + + finally: + if config_path.exists(): + config_path.unlink() + +if __name__ == "__main__": + main() diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 88edb5ec5..79861b968 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -52,6 +52,8 @@ import io from typing import TextIO from crewai.llms.base_llm import BaseLLM +from crewai.llms.oauth2_config import OAuth2ConfigLoader +from crewai.llms.oauth2_token_manager import OAuth2TokenManager from crewai.utilities.events import crewai_event_bus from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededException, @@ -311,6 +313,7 @@ class LLM(BaseLLM): callbacks: List[Any] = [], reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None, stream: bool = False, + oauth2_config_path: Optional[str] = None, **kwargs, ): self.model = model @@ -338,6 +341,10 @@ class LLM(BaseLLM): self.is_anthropic = self._is_anthropic_model(model) self.stream = stream + self.oauth2_config_loader = OAuth2ConfigLoader(oauth2_config_path) + self.oauth2_token_manager = OAuth2TokenManager() + self.oauth2_configs = self.oauth2_config_loader.load_config() + litellm.drop_params = True # Normalize self.stop to always be a List[str] @@ -384,7 +391,19 @@ class LLM(BaseLLM): messages = [{"role": "user", "content": messages}] formatted_messages = self._format_messages_for_provider(messages) - # --- 2) Prepare the parameters for the completion call + api_key = self.api_key + provider = self._get_custom_llm_provider() + + if provider and provider in self.oauth2_configs: + oauth2_config = self.oauth2_configs[provider] + try: + access_token = self.oauth2_token_manager.get_access_token(oauth2_config) + api_key = access_token + except RuntimeError as e: + logging.error(f"OAuth2 authentication failed for provider {provider}: {e}") + raise + + # --- 3) Prepare the parameters for the completion call params = { "model": self.model, "messages": formatted_messages, @@ -404,7 +423,7 @@ class LLM(BaseLLM): "api_base": self.api_base, "base_url": self.base_url, "api_version": self.api_version, - "api_key": self.api_key, + "api_key": api_key, "stream": self.stream, "tools": tools, "reasoning_effort": self.reasoning_effort, diff --git a/src/crewai/llms/__init__.py b/src/crewai/llms/__init__.py index fda1e6a3b..0cc89748a 100644 --- a/src/crewai/llms/__init__.py +++ b/src/crewai/llms/__init__.py @@ -1 +1,12 @@ """LLM implementations for crewAI.""" + +from .base_llm import BaseLLM +from .oauth2_config import OAuth2Config, OAuth2ConfigLoader +from .oauth2_token_manager import OAuth2TokenManager + +__all__ = [ + "BaseLLM", + "OAuth2Config", + "OAuth2ConfigLoader", + "OAuth2TokenManager" +] diff --git a/src/crewai/llms/oauth2_config.py b/src/crewai/llms/oauth2_config.py new file mode 100644 index 000000000..5373fcefc --- /dev/null +++ b/src/crewai/llms/oauth2_config.py @@ -0,0 +1,38 @@ +from pathlib import Path +from typing import Dict, List, Optional +import json +from pydantic import BaseModel, Field + + +class OAuth2Config(BaseModel): + client_id: str = Field(description="OAuth2 client ID") + client_secret: str = Field(description="OAuth2 client secret") + token_url: str = Field(description="OAuth2 token endpoint URL") + scope: Optional[str] = Field(default=None, description="OAuth2 scope") + provider_name: str = Field(description="Custom provider name") + refresh_token: Optional[str] = Field(default=None, description="OAuth2 refresh token") + + +class OAuth2ConfigLoader: + def __init__(self, config_path: Optional[str] = None): + self.config_path = Path(config_path) if config_path else Path("litellm_config.json") + + def load_config(self) -> Dict[str, OAuth2Config]: + """Load OAuth2 configurations from litellm_config.json""" + if not self.config_path.exists(): + return {} + + try: + with open(self.config_path, 'r') as f: + data = json.load(f) + + oauth2_configs = {} + for provider_name, config_data in data.get("oauth2_providers", {}).items(): + oauth2_configs[provider_name] = OAuth2Config( + provider_name=provider_name, + **config_data + ) + + return oauth2_configs + except (json.JSONDecodeError, KeyError, ValueError) as e: + raise ValueError(f"Invalid OAuth2 configuration in {self.config_path}: {e}") diff --git a/src/crewai/llms/oauth2_token_manager.py b/src/crewai/llms/oauth2_token_manager.py new file mode 100644 index 000000000..af14117f2 --- /dev/null +++ b/src/crewai/llms/oauth2_token_manager.py @@ -0,0 +1,64 @@ +import time +import requests +from typing import Dict, Optional +from .oauth2_config import OAuth2Config + + +class OAuth2TokenManager: + def __init__(self): + self._tokens: Dict[str, Dict[str, any]] = {} + + def get_access_token(self, config: OAuth2Config) -> str: + """Get valid access token for the provider, refreshing if necessary""" + provider_name = config.provider_name + + if provider_name in self._tokens: + token_data = self._tokens[provider_name] + if self._is_token_valid(token_data): + return token_data["access_token"] + + return self._acquire_new_token(config) + + def _is_token_valid(self, token_data: Dict[str, any]) -> bool: + """Check if token is still valid (not expired)""" + if "expires_at" not in token_data: + return False + + return time.time() < (token_data["expires_at"] - 60) + + def _acquire_new_token(self, config: OAuth2Config) -> str: + """Acquire new access token using client credentials flow""" + payload = { + "grant_type": "client_credentials", + "client_id": config.client_id, + "client_secret": config.client_secret, + } + + if config.scope: + payload["scope"] = config.scope + + try: + response = requests.post( + config.token_url, + data=payload, + timeout=30, + headers={"Content-Type": "application/x-www-form-urlencoded"} + ) + response.raise_for_status() + + token_data = response.json() + access_token = token_data["access_token"] + + expires_in = token_data.get("expires_in", 3600) + self._tokens[config.provider_name] = { + "access_token": access_token, + "expires_at": time.time() + expires_in, + "token_type": token_data.get("token_type", "Bearer") + } + + return access_token + + except requests.RequestException as e: + raise RuntimeError(f"Failed to acquire OAuth2 token for {config.provider_name}: {e}") + except KeyError as e: + raise RuntimeError(f"Invalid token response from {config.provider_name}: missing {e}") diff --git a/tests/test_oauth2_llm.py b/tests/test_oauth2_llm.py new file mode 100644 index 000000000..e60f601a5 --- /dev/null +++ b/tests/test_oauth2_llm.py @@ -0,0 +1,285 @@ +import json +import pytest +import requests +import tempfile +import time +from pathlib import Path +from unittest.mock import Mock, patch, MagicMock +from crewai.llm import LLM +from crewai.llms.oauth2_config import OAuth2Config, OAuth2ConfigLoader +from crewai.llms.oauth2_token_manager import OAuth2TokenManager + + +class TestOAuth2Config: + def test_oauth2_config_creation(self): + config = OAuth2Config( + client_id="test_client", + client_secret="test_secret", + token_url="https://example.com/token", + provider_name="test_provider", + scope="read write" + ) + assert config.client_id == "test_client" + assert config.provider_name == "test_provider" + assert config.scope == "read write" + + def test_oauth2_config_loader_missing_file(self): + loader = OAuth2ConfigLoader("nonexistent.json") + configs = loader.load_config() + assert configs == {} + + def test_oauth2_config_loader_valid_config(self): + config_data = { + "oauth2_providers": { + "custom_provider": { + "client_id": "test_client", + "client_secret": "test_secret", + "token_url": "https://example.com/token", + "scope": "read" + } + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(config_data, f) + config_path = f.name + + try: + loader = OAuth2ConfigLoader(config_path) + configs = loader.load_config() + + assert "custom_provider" in configs + config = configs["custom_provider"] + assert config.client_id == "test_client" + assert config.provider_name == "custom_provider" + finally: + Path(config_path).unlink() + + def test_oauth2_config_loader_invalid_json(self): + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + f.write("invalid json") + config_path = f.name + + try: + loader = OAuth2ConfigLoader(config_path) + with pytest.raises(ValueError, match="Invalid OAuth2 configuration"): + loader.load_config() + finally: + Path(config_path).unlink() + + +class TestOAuth2TokenManager: + def test_token_acquisition_success(self): + config = OAuth2Config( + client_id="test_client", + client_secret="test_secret", + token_url="https://example.com/token", + provider_name="test_provider" + ) + + mock_response = Mock() + mock_response.json.return_value = { + "access_token": "test_token_123", + "token_type": "Bearer", + "expires_in": 3600 + } + mock_response.raise_for_status.return_value = None + + with patch('requests.post', return_value=mock_response) as mock_post: + manager = OAuth2TokenManager() + token = manager.get_access_token(config) + + assert token == "test_token_123" + mock_post.assert_called_once() + + def test_token_caching(self): + config = OAuth2Config( + client_id="test_client", + client_secret="test_secret", + token_url="https://example.com/token", + provider_name="test_provider" + ) + + mock_response = Mock() + mock_response.json.return_value = { + "access_token": "test_token_123", + "token_type": "Bearer", + "expires_in": 3600 + } + mock_response.raise_for_status.return_value = None + + with patch('requests.post', return_value=mock_response) as mock_post: + manager = OAuth2TokenManager() + + token1 = manager.get_access_token(config) + assert token1 == "test_token_123" + assert mock_post.call_count == 1 + + token2 = manager.get_access_token(config) + assert token2 == "test_token_123" + assert mock_post.call_count == 1 + + def test_token_refresh_on_expiry(self): + config = OAuth2Config( + client_id="test_client", + client_secret="test_secret", + token_url="https://example.com/token", + provider_name="test_provider" + ) + + mock_response = Mock() + mock_response.json.return_value = { + "access_token": "new_token_456", + "token_type": "Bearer", + "expires_in": 3600 + } + mock_response.raise_for_status.return_value = None + + with patch('requests.post', return_value=mock_response): + manager = OAuth2TokenManager() + + manager._tokens["test_provider"] = { + "access_token": "old_token", + "expires_at": time.time() - 100 + } + + token = manager.get_access_token(config) + assert token == "new_token_456" + + def test_token_acquisition_failure(self): + config = OAuth2Config( + client_id="test_client", + client_secret="test_secret", + token_url="https://example.com/token", + provider_name="test_provider" + ) + + with patch('requests.post', side_effect=requests.RequestException("Network error")): + manager = OAuth2TokenManager() + + with pytest.raises(RuntimeError, match="Failed to acquire OAuth2 token"): + manager.get_access_token(config) + + def test_invalid_token_response(self): + config = OAuth2Config( + client_id="test_client", + client_secret="test_secret", + token_url="https://example.com/token", + provider_name="test_provider" + ) + + mock_response = Mock() + mock_response.json.return_value = {"invalid": "response"} + mock_response.raise_for_status.return_value = None + + with patch('requests.post', return_value=mock_response): + manager = OAuth2TokenManager() + + with pytest.raises(RuntimeError, match="Invalid token response"): + manager.get_access_token(config) + + +class TestLLMOAuth2Integration: + def test_llm_with_oauth2_config(self): + config_data = { + "oauth2_providers": { + "custom": { + "client_id": "test_client", + "client_secret": "test_secret", + "token_url": "https://example.com/token" + } + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(config_data, f) + config_path = f.name + + try: + llm = LLM( + model="custom/test-model", + oauth2_config_path=config_path + ) + + assert "custom" in llm.oauth2_configs + assert llm.oauth2_configs["custom"].client_id == "test_client" + finally: + Path(config_path).unlink() + + def test_llm_without_oauth2_config(self): + llm = LLM(model="openai/gpt-3.5-turbo") + assert llm.oauth2_configs == {} + + @patch('crewai.llm.litellm.completion') + def test_llm_oauth2_token_injection(self, mock_completion): + config_data = { + "oauth2_providers": { + "custom": { + "client_id": "test_client", + "client_secret": "test_secret", + "token_url": "https://example.com/token" + } + } + } + + mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))]) + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(config_data, f) + config_path = f.name + + try: + with patch.object(OAuth2TokenManager, 'get_access_token', return_value="oauth_token_123"): + llm = LLM( + model="custom/test-model", + oauth2_config_path=config_path + ) + + llm.call("test message") + + call_args = mock_completion.call_args + assert call_args[1]['api_key'] == "oauth_token_123" + finally: + Path(config_path).unlink() + + @patch('crewai.llm.litellm.completion') + def test_llm_oauth2_authentication_failure(self, mock_completion): + config_data = { + "oauth2_providers": { + "custom": { + "client_id": "test_client", + "client_secret": "test_secret", + "token_url": "https://example.com/token" + } + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(config_data, f) + config_path = f.name + + try: + with patch.object(OAuth2TokenManager, 'get_access_token', side_effect=RuntimeError("Auth failed")): + llm = LLM( + model="custom/test-model", + oauth2_config_path=config_path + ) + + with pytest.raises(RuntimeError, match="Auth failed"): + llm.call("test message") + finally: + Path(config_path).unlink() + + @patch('crewai.llm.litellm.completion') + def test_llm_non_oauth2_provider_unchanged(self, mock_completion): + mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))]) + + llm = LLM( + model="openai/gpt-3.5-turbo", + api_key="original_key" + ) + + llm.call("test message") + + call_args = mock_completion.call_args + assert call_args[1]['api_key'] == "original_key"