Compare commits

..

3 Commits

Author SHA1 Message Date
Devin AI
4379ad26d1 Implement comprehensive OAuth2 improvements based on code review feedback
- Add custom OAuth2 error classes (OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError)
- Implement URL and scope validation in OAuth2Config using pydantic field_validator
- Add retry mechanism with exponential backoff (3 attempts, 1s/2s/4s delays) to OAuth2TokenManager
- Implement thread safety using threading.Lock for concurrent token access
- Add provider validation in LLM integration with _validate_oauth2_provider method
- Enhance testing with validation, retry, thread safety, and error class tests
- Improve documentation with comprehensive error handling examples and feature descriptions
- Maintain backward compatibility while significantly improving robustness and security

Co-Authored-By: João <joao@crewai.com>
2025-07-07 18:10:30 +00:00
Devin AI
03ee4c59eb Fix lint and type-checker issues in OAuth2 implementation
- Remove unused imports (List, Optional, MagicMock)
- Fix type annotation: change 'any' to 'Any' in oauth2_token_manager.py
- All OAuth2 tests still pass after fixes

Co-Authored-By: João <joao@crewai.com>
2025-07-07 18:01:35 +00:00
Devin AI
ac1080afd8 Implement OAuth2 authentication support for custom LiteLLM providers
- Add OAuth2Config and OAuth2ConfigLoader for litellm_config.json configuration
- Add OAuth2TokenManager for token acquisition, caching, and refresh
- Extend LLM class to support OAuth2 authentication with custom providers
- Add comprehensive tests covering OAuth2 flow and error handling
- Add documentation and usage examples
- Support Client Credentials OAuth2 flow for server-to-server authentication
- Maintain backward compatibility with existing LLM providers

Fixes #3114

Co-Authored-By: João <joao@crewai.com>
2025-07-07 17:57:40 +00:00
20 changed files with 941 additions and 633 deletions

View File

@@ -5,7 +5,3 @@ repos:
- id: ruff
args: ["--fix"]
- id: ruff-format
- repo: https://github.com/commitizen-tools/commitizen
rev: v3.13.0
hooks:
- id: commitizen

View File

@@ -94,7 +94,7 @@
"pages": [
"en/guides/advanced/customizing-prompts",
"en/guides/advanced/fingerprinting"
]
}
]
@@ -296,8 +296,7 @@
"en/enterprise/features/webhook-streaming",
"en/enterprise/features/traces",
"en/enterprise/features/hallucination-guardrail",
"en/enterprise/features/integrations",
"en/enterprise/features/agent-repositories"
"en/enterprise/features/integrations"
]
},
{
@@ -374,7 +373,7 @@
}
]
}
]
},
{
@@ -731,7 +730,7 @@
}
]
}
]
}
]
@@ -775,7 +774,7 @@
"destination": "/en/introduction"
},
{
"source": "/installation",
"source": "/installation",
"destination": "/en/installation"
},
{

View File

@@ -526,103 +526,6 @@ agent = Agent(
The context window management feature works automatically in the background. You don't need to call any special functions - just set `respect_context_window` to your preferred behavior and CrewAI handles the rest!
</Note>
## Direct Agent Interaction with `kickoff()`
Agents can be used directly through the `kickoff()` method, without going through a task or crew workflow. This provides a simpler way to interact with an agent when you don't need the full crew orchestration capabilities.
### How `kickoff()` Works
The `kickoff()` method allows you to send messages directly to an agent and get a response, similar to how you would interact with an LLM but with all the agent's capabilities (tools, reasoning, etc.).
```python Code
from crewai import Agent
from crewai_tools import SerperDevTool

# Create an agent
researcher = Agent(
    role="AI Technology Researcher",
    goal="Research the latest AI developments",
    tools=[SerperDevTool()],
    verbose=True
)

# Use kickoff() to interact directly with the agent
result = researcher.kickoff("What are the latest developments in language models?")

# Access the raw response
print(result.raw)
```
### Parameters and Return Values
| Parameter | Type | Description |
| :---------------- | :---------------------------------- | :------------------------------------------------------------------------ |
| `messages` | `Union[str, List[Dict[str, str]]]` | Either a string query or a list of message dictionaries with role/content |
| `response_format` | `Optional[Type[Any]]` | Optional Pydantic model for structured output |
The method returns a `LiteAgentOutput` object with the following properties:
- `raw`: String containing the raw output text
- `pydantic`: Parsed Pydantic model (if a `response_format` was provided)
- `agent_role`: Role of the agent that produced the output
- `usage_metrics`: Token usage metrics for the execution
### Structured Output
You can get structured output by providing a Pydantic model as the `response_format`:
```python Code
from pydantic import BaseModel
from typing import List

class ResearchFindings(BaseModel):
    main_points: List[str]
    key_technologies: List[str]
    future_predictions: str

# Get structured output
result = researcher.kickoff(
    "Summarize the latest developments in AI for 2025",
    response_format=ResearchFindings
)

# Access structured data
print(result.pydantic.main_points)
print(result.pydantic.future_predictions)
```
### Multiple Messages
You can also provide a conversation history as a list of message dictionaries:
```python Code
messages = [
    {"role": "user", "content": "I need information about large language models"},
    {"role": "assistant", "content": "I'd be happy to help with that! What specifically would you like to know?"},
    {"role": "user", "content": "What are the latest developments in 2025?"}
]

result = researcher.kickoff(messages)
```
### Async Support
An asynchronous version is available via `kickoff_async()` with the same parameters:
```python Code
import asyncio

async def main():
    result = await researcher.kickoff_async("What are the latest developments in AI?")
    print(result.raw)

asyncio.run(main())
```
<Note>
The `kickoff()` method uses a `LiteAgent` internally, which provides a simpler execution flow while preserving all of the agent's configuration (role, goal, backstory, tools, etc.).
</Note>
## Important Considerations and Best Practices
### Security and Code Execution

View File

@@ -1,155 +0,0 @@
---
title: 'Agent Repositories'
description: 'Learn how to use Agent Repositories to share and reuse your agents across teams and projects'
icon: 'database'
---
Agent Repositories allow enterprise users to store, share, and reuse agent definitions across teams and projects. This feature enables organizations to maintain a centralized library of standardized agents, promoting consistency and reducing duplication of effort.
## Benefits of Agent Repositories
- **Standardization**: Maintain consistent agent definitions across your organization
- **Reusability**: Create an agent once and use it in multiple crews and projects
- **Governance**: Implement organization-wide policies for agent configurations
- **Collaboration**: Enable teams to share and build upon each other's work
## Using Agent Repositories
### Prerequisites
1. You must have a CrewAI account; you can start with the [free plan](https://app.crewai.com).
2. You need to be authenticated using the CrewAI CLI.
3. If you have more than one organization, make sure you are switched to the correct organization using the CLI command:
```bash
crewai org switch <org_id>
```
### Creating and Managing Agents in Repositories
To create and manage agents in repositories, use the Enterprise Dashboard.
### Loading Agents from Repositories
You can load agents from repositories in your code using the `from_repository` parameter:
```python
from crewai import Agent

# Create an agent by loading it from a repository
# The agent is loaded with all its predefined configurations
researcher = Agent(
    from_repository="market-research-agent"
)
```
### Overriding Repository Settings
You can override specific settings from the repository by providing them in the configuration:
```python
researcher = Agent(
    from_repository="market-research-agent",
    goal="Research the latest trends in AI development",  # Override the repository goal
    verbose=True  # Add a setting not in the repository
)
```
### Example: Creating a Crew with Repository Agents
```python
from crewai import Crew, Agent, Task

# Load agents from repositories
researcher = Agent(
    from_repository="market-research-agent"
)
writer = Agent(
    from_repository="content-writer-agent"
)

# Create tasks
research_task = Task(
    description="Research the latest trends in AI",
    agent=researcher
)
writing_task = Task(
    description="Write a comprehensive report based on the research",
    agent=writer
)

# Create the crew
crew = Crew(
    agents=[researcher, writer],
    tasks=[research_task, writing_task],
    verbose=True
)

# Run the crew
result = crew.kickoff()
```
### Example: Using `kickoff()` with Repository Agents
You can also use repository agents directly with the `kickoff()` method for simpler interactions:
```python
from crewai import Agent
from pydantic import BaseModel
from typing import List

# Define a structured output format
class MarketAnalysis(BaseModel):
    key_trends: List[str]
    opportunities: List[str]
    recommendation: str

# Load an agent from repository
analyst = Agent(
    from_repository="market-analyst-agent",
    verbose=True
)

# Get a free-form response
result = analyst.kickoff("Analyze the AI market in 2025")
print(result.raw)  # Access the raw response

# Get structured output
structured_result = analyst.kickoff(
    "Provide a structured analysis of the AI market in 2025",
    response_format=MarketAnalysis
)

# Access structured data
print(f"Key Trends: {structured_result.pydantic.key_trends}")
print(f"Recommendation: {structured_result.pydantic.recommendation}")
```
## Best Practices
1. **Naming Convention**: Use clear, descriptive names for your repository agents
2. **Documentation**: Include comprehensive descriptions for each agent
3. **Tool Management**: Ensure that tools referenced by repository agents are available in your environment
4. **Access Control**: Manage permissions to ensure only authorized team members can modify repository agents
## Organization Management
To switch between organizations or see your current organization, use the CrewAI CLI:
```bash
# View current organization
crewai org current
# Switch to a different organization
crewai org switch <org_id>
# List all available organizations
crewai org list
```
<Note>
When loading agents from repositories, you must be authenticated and switched to the correct organization. If you receive errors, check your authentication status and organization settings using the CLI commands above.
</Note>

View File

@@ -0,0 +1,216 @@
# OAuth2 LLM Providers
CrewAI supports OAuth2 authentication for custom LiteLLM providers through configuration files.
## Features
- **Automatic Token Management**: Handles OAuth2 token acquisition and refresh automatically
- **Multiple Provider Support**: Configure multiple OAuth2 providers in a single configuration file
- **Secure Credential Storage**: Keep OAuth2 credentials separate from your code
- **Seamless Integration**: Works with existing LiteLLM provider configurations
- **Error Handling**: Comprehensive error handling for authentication failures
- **Retry Mechanism**: Automatic retry with exponential backoff for token acquisition
- **Thread Safety**: Concurrent access protection for token caching
- **Configuration Validation**: Automatic validation of URLs and scope formats
## Error Handling
The OAuth2 implementation provides specific error classes for different failure scenarios:
### OAuth2ConfigurationError
Raised when there are issues with the OAuth2 configuration:
- Invalid configuration file format
- Missing required fields
- Unknown OAuth2 providers
### OAuth2AuthenticationError
Raised when OAuth2 authentication fails:
- Network errors during token acquisition
- Invalid credentials
- Token endpoint errors
### OAuth2ValidationError
Raised when configuration validation fails:
- Invalid URL formats
- Malformed scope values
### Example Error Handling
```python
from crewai import LLM
from crewai.llms.oauth2_errors import OAuth2ConfigurationError, OAuth2AuthenticationError

try:
    llm = LLM(
        model="my_custom_provider/my-model",
        oauth2_config_path="./litellm_config.json"
    )
    result = llm.call("test message")
except OAuth2ConfigurationError as e:
    print(f"Configuration error: {e}")
except OAuth2AuthenticationError as e:
    print(f"Authentication failed: {e}")
    if e.original_error:
        print(f"Original error: {e.original_error}")
```
## Configuration Validation
The OAuth2 configuration is automatically validated against the following rules (a short example follows this list):
- **URL Validation**: `token_url` must be a valid HTTP/HTTPS URL
- **Scope Validation**: `scope` cannot contain empty values when split by spaces
- **Required Fields**: `client_id`, `client_secret`, `token_url`, and `provider_name` are required
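A minimal sketch of how a violation surfaces in practice, using the `OAuth2Config` model and error classes added in this diff:
```python
from crewai.llms.oauth2_config import OAuth2Config
from crewai.llms.oauth2_errors import OAuth2ValidationError

try:
    OAuth2Config(
        client_id="my_client",
        client_secret="my_secret",
        token_url="not-a-url",  # fails the HTTP/HTTPS URL check
        provider_name="my_custom_provider",
    )
except OAuth2ValidationError as e:
    print(f"Validation failed: {e}")
```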
## Retry Mechanism
Token acquisition includes automatic retry with exponential backoff (a simplified sketch follows this list):
- 3 retry attempts by default
- Exponential backoff: 1s, 2s, 4s delays
- Detailed logging of retry attempts
- Thread-safe token caching
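This simplified sketch shows the shape of the backoff loop; the real `OAuth2TokenManager._acquire_new_token` (shown later in this diff) additionally wraps the final failure in `OAuth2AuthenticationError` and logs each retry:
```python
import time
from typing import Callable

import requests

def acquire_with_backoff(request_token: Callable[[], str], retry_count: int = 3) -> str:
    """Retry a token request, doubling the delay after each failed attempt."""
    for attempt in range(retry_count):
        try:
            return request_token()
        except requests.RequestException:
            if attempt == retry_count - 1:
                raise  # attempts exhausted; the real code raises OAuth2AuthenticationError here
            time.sleep(2 ** attempt)  # 1s, then 2s, doubling each retry
    raise AssertionError("unreachable")  # defensive, mirrors the implementation
```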
## Configuration
Create a `litellm_config.json` file in your project directory:
```json
{
  "oauth2_providers": {
    "my_custom_provider": {
      "client_id": "your_client_id",
      "client_secret": "your_client_secret",
      "token_url": "https://your-provider.com/oauth/token",
      "scope": "llm.read llm.write"
    },
    "another_provider": {
      "client_id": "another_client_id",
      "client_secret": "another_client_secret",
      "token_url": "https://another-provider.com/token"
    }
  }
}
```
## Usage
```python
from crewai import LLM

# Initialize LLM with OAuth2 support
llm = LLM(
    model="my_custom_provider/my-model",
    oauth2_config_path="./litellm_config.json"  # Optional, defaults to ./litellm_config.json
)

# Use in CrewAI
from crewai import Agent, Task, Crew

agent = Agent(
    role="Data Analyst",
    goal="Analyze data trends",
    backstory="Expert in data analysis",
    llm=llm
)
task = Task(
    description="Analyze the latest sales data",
    agent=agent
)
crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
```
## Environment Variables
You can also use environment variables in your configuration:
```json
{
  "oauth2_providers": {
    "my_provider": {
      "client_id": "os.environ/MY_CLIENT_ID",
      "client_secret": "os.environ/MY_CLIENT_SECRET",
      "token_url": "https://my-provider.com/token"
    }
  }
}
```
## Supported OAuth2 Flow
Currently supports the **Client Credentials** OAuth2 flow, which is suitable for server-to-server authentication.
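Under the hood, the grant is a single form-encoded POST to the token endpoint. This sketch mirrors `OAuth2TokenManager._perform_token_request` from this diff; the URL, credentials, and scope are placeholders:
```python
import requests

response = requests.post(
    "https://your-provider.com/oauth/token",
    data={
        "grant_type": "client_credentials",
        "client_id": "your_client_id",
        "client_secret": "your_client_secret",
        "scope": "llm.read llm.write",  # optional
    },
    headers={"Content-Type": "application/x-www-form-urlencoded"},
    timeout=30,
)
response.raise_for_status()
access_token = response.json()["access_token"]
```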
## Token Management
- Tokens are automatically cached and refreshed when they expire
- A 60-second buffer is applied before token expiration to ensure reliability, as sketched below
- Failed token acquisition raises an `OAuth2AuthenticationError` with details (previously a `RuntimeError`; see Legacy Error Handling)
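The expiry check itself is a one-liner; this sketch mirrors `OAuth2TokenManager._is_token_valid` from this diff:
```python
import time
from typing import Any, Dict

def is_token_valid(token_data: Dict[str, Any]) -> bool:
    """A cached token counts as valid only while more than 60s remain before expiry."""
    if "expires_at" not in token_data:
        return False
    return time.time() < (token_data["expires_at"] - 60)
```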
## Configuration Schema
The `litellm_config.json` file should follow this schema:
```json
{
  "oauth2_providers": {
    "<provider_name>": {
      "client_id": "string (required)",
      "client_secret": "string (required)",
      "token_url": "string (required)",
      "scope": "string (optional)",
      "refresh_token": "string (optional)"
    }
  }
}
```
## Legacy Error Handling
For backward compatibility, the following error-type mappings apply (a migration sketch follows this list):
- OAuth2 authentication failures raise `OAuth2AuthenticationError` (previously `RuntimeError`)
- Invalid configuration files raise `OAuth2ConfigurationError` (previously `ValueError`)
- Network errors during token acquisition are wrapped in `OAuth2AuthenticationError`
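Note that the new classes derive from `Exception` rather than from `RuntimeError` or `ValueError` (see `oauth2_errors.py` in this diff), so handlers written against the old types should migrate to the new ones. A minimal sketch:
```python
from crewai import LLM
from crewai.llms.oauth2_errors import OAuth2AuthenticationError, OAuth2ConfigurationError

try:
    llm = LLM(model="my_provider/gpt-4", oauth2_config_path="./config.json")
    llm.call("Hello, world!")
except OAuth2ConfigurationError:
    ...  # previously surfaced as ValueError
except OAuth2AuthenticationError:
    ...  # previously surfaced as RuntimeError
```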
## Examples
### Basic OAuth2 Provider
```python
from crewai import LLM

llm = LLM(
    model="my_provider/gpt-4",
    oauth2_config_path="./config.json"
)
response = llm.call("Hello, world!")
print(response)
```
### Multiple Providers
```json
{
  "oauth2_providers": {
    "provider_a": {
      "client_id": "client_a",
      "client_secret": "secret_a",
      "token_url": "https://provider-a.com/token"
    },
    "provider_b": {
      "client_id": "client_b",
      "client_secret": "secret_b",
      "token_url": "https://provider-b.com/oauth/token",
      "scope": "read write"
    }
  }
}
```
```python
# Use different providers
llm_a = LLM(model="provider_a/model-1", oauth2_config_path="./config.json")
llm_b = LLM(model="provider_b/model-2", oauth2_config_path="./config.json")
```

View File

@@ -0,0 +1,64 @@
"""
Example demonstrating OAuth2 authentication with custom LLM providers in CrewAI.
This example shows how to configure and use OAuth2-authenticated LLM providers.
"""
import json
from pathlib import Path
from crewai import Agent, Task, Crew, LLM
def create_example_config():
"""Create an example OAuth2 configuration file."""
config = {
"oauth2_providers": {
"my_custom_provider": {
"client_id": "your_client_id_here",
"client_secret": "your_client_secret_here",
"token_url": "https://your-provider.com/oauth/token",
"scope": "llm.read llm.write"
}
}
}
config_path = Path("example_oauth2_config.json")
with open(config_path, 'w') as f:
json.dump(config, f, indent=2)
print(f"Created example config at {config_path}")
return config_path
def main():
config_path = create_example_config()
try:
llm = LLM(
model="my_custom_provider/my-model",
oauth2_config_path=str(config_path)
)
agent = Agent(
role="Research Assistant",
goal="Provide helpful research insights",
backstory="An AI assistant specialized in research and analysis",
llm=llm
)
task = Task(
description="Research the benefits of OAuth2 authentication in AI systems",
agent=agent,
expected_output="A comprehensive summary of OAuth2 benefits"
)
crew = Crew(agents=[agent], tasks=[task])
print("Running crew with OAuth2-authenticated LLM...")
result = crew.kickoff()
print(f"Result: {result}")
finally:
if config_path.exists():
config_path.unlink()
if __name__ == "__main__":
main()

View File

@@ -18,11 +18,6 @@ from typing import (
cast,
)
from opentelemetry import baggage
from opentelemetry.context import attach, detach
from crewai.utilities.crew.models import CrewContext
from pydantic import (
UUID4,
BaseModel,
@@ -621,11 +616,6 @@ class Crew(FlowTrackable, BaseModel):
self,
inputs: Optional[Dict[str, Any]] = None,
) -> CrewOutput:
ctx = baggage.set_baggage(
"crew_context", CrewContext(id=str(self.id), key=self.key)
)
token = attach(ctx)
try:
for before_callback in self.before_kickoff_callbacks:
if inputs is None:
@@ -686,8 +676,6 @@ class Crew(FlowTrackable, BaseModel):
CrewKickoffFailedEvent(error=str(e), crew_name=self.name or "crew"),
)
raise
finally:
detach(token)
def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
"""Executes the Crew's workflow for each input in the list and aggregates results."""

View File

@@ -52,6 +52,9 @@ import io
from typing import TextIO
from crewai.llms.base_llm import BaseLLM
from crewai.llms.oauth2_config import OAuth2ConfigLoader
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
from crewai.llms.oauth2_errors import OAuth2ConfigurationError
from crewai.utilities.events import crewai_event_bus
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
@@ -311,6 +314,7 @@ class LLM(BaseLLM):
callbacks: List[Any] = [],
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
stream: bool = False,
oauth2_config_path: Optional[str] = None,
**kwargs,
):
self.model = model
@@ -338,6 +342,10 @@ class LLM(BaseLLM):
self.is_anthropic = self._is_anthropic_model(model)
self.stream = stream
self.oauth2_config_loader = OAuth2ConfigLoader(oauth2_config_path)
self.oauth2_token_manager = OAuth2TokenManager()
self.oauth2_configs = self.oauth2_config_loader.load_config()
litellm.drop_params = True
# Normalize self.stop to always be a List[str]
@@ -384,7 +392,21 @@ class LLM(BaseLLM):
messages = [{"role": "user", "content": messages}]
formatted_messages = self._format_messages_for_provider(messages)
# --- 2) Prepare the parameters for the completion call
api_key = self.api_key
provider = self._get_custom_llm_provider()
if provider and provider in self.oauth2_configs:
oauth2_config = self.oauth2_configs[provider]
try:
self._validate_oauth2_provider(provider)
access_token = self.oauth2_token_manager.get_access_token(oauth2_config)
api_key = access_token
logging.debug(f"Using OAuth2 authentication for provider {provider}")
except Exception as e:
logging.error(f"OAuth2 authentication failed for provider {provider}: {e}")
raise
# --- 3) Prepare the parameters for the completion call
params = {
"model": self.model,
"messages": formatted_messages,
@@ -404,7 +426,7 @@ class LLM(BaseLLM):
"api_base": self.api_base,
"base_url": self.base_url,
"api_version": self.api_version,
"api_key": self.api_key,
"api_key": api_key,
"stream": self.stream,
"tools": tools,
"reasoning_effort": self.reasoning_effort,
@@ -1076,6 +1098,12 @@ class LLM(BaseLLM):
return self.model.split("/")[0]
return None
def _validate_oauth2_provider(self, provider: str) -> bool:
"""Validate that OAuth2 provider exists in configuration"""
if provider not in self.oauth2_configs:
raise OAuth2ConfigurationError(f"Unknown OAuth2 provider: {provider}")
return True
def _validate_call_params(self) -> None:
"""
Validate parameters before making a call. Currently this only checks if

View File

@@ -1 +1,17 @@
"""LLM implementations for crewAI."""
from .base_llm import BaseLLM
from .oauth2_config import OAuth2Config, OAuth2ConfigLoader
from .oauth2_token_manager import OAuth2TokenManager
from .oauth2_errors import OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError
__all__ = [
"BaseLLM",
"OAuth2Config",
"OAuth2ConfigLoader",
"OAuth2TokenManager",
"OAuth2Error",
"OAuth2ConfigurationError",
"OAuth2AuthenticationError",
"OAuth2ValidationError"
]

View File

@@ -0,0 +1,65 @@
from pathlib import Path
from typing import Dict, Optional
import json
import re

from pydantic import BaseModel, Field, field_validator

from .oauth2_errors import OAuth2ConfigurationError, OAuth2ValidationError


class OAuth2Config(BaseModel):
    client_id: str = Field(description="OAuth2 client ID")
    client_secret: str = Field(description="OAuth2 client secret")
    token_url: str = Field(description="OAuth2 token endpoint URL")
    scope: Optional[str] = Field(default=None, description="OAuth2 scope")
    provider_name: str = Field(description="Custom provider name")
    refresh_token: Optional[str] = Field(default=None, description="OAuth2 refresh token")

    @field_validator('token_url')
    @classmethod
    def validate_token_url(cls, v: str) -> str:
        """Validate that token_url is a valid HTTP/HTTPS URL."""
        url_pattern = re.compile(
            r'^https?://'  # http:// or https://
            r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|'  # domain...
            r'localhost|'  # localhost...
            r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or ip
            r'(?::\d+)?'  # optional port
            r'(?:/?|[/?]\S+)$', re.IGNORECASE)
        if not url_pattern.match(v):
            raise OAuth2ValidationError(f"Invalid token URL format: {v}")
        return v

    @field_validator('scope')
    @classmethod
    def validate_scope(cls, v: Optional[str]) -> Optional[str]:
        """Validate OAuth2 scope format."""
        if v:
            # Reject scopes that yield empty values when split on spaces
            # (e.g. "read  write" or " read"); multi-scope values such as
            # "llm.read llm.write" remain valid.
            if any(not part for part in v.split(' ')):
                raise OAuth2ValidationError("Invalid scope format: scope cannot contain empty values")
        return v


class OAuth2ConfigLoader:
    def __init__(self, config_path: Optional[str] = None):
        self.config_path = Path(config_path) if config_path else Path("litellm_config.json")

    def load_config(self) -> Dict[str, OAuth2Config]:
        """Load OAuth2 configurations from litellm_config.json"""
        if not self.config_path.exists():
            return {}
        try:
            with open(self.config_path, 'r') as f:
                data = json.load(f)
            oauth2_configs = {}
            for provider_name, config_data in data.get("oauth2_providers", {}).items():
                oauth2_configs[provider_name] = OAuth2Config(
                    provider_name=provider_name,
                    **config_data
                )
            return oauth2_configs
        except (json.JSONDecodeError, KeyError, ValueError) as e:
            raise OAuth2ConfigurationError(f"Invalid OAuth2 configuration in {self.config_path}: {e}")

View File

@@ -0,0 +1,26 @@
"""OAuth2 error classes for CrewAI."""
from typing import Optional
class OAuth2Error(Exception):
"""Base exception class for OAuth2 operation errors."""
def __init__(self, message: str, original_error: Optional[Exception] = None):
super().__init__(message)
self.original_error = original_error
class OAuth2ConfigurationError(OAuth2Error):
"""Exception raised for OAuth2 configuration errors."""
pass
class OAuth2AuthenticationError(OAuth2Error):
"""Exception raised for OAuth2 authentication failures."""
pass
class OAuth2ValidationError(OAuth2Error):
"""Exception raised for OAuth2 validation errors."""
pass

View File

@@ -0,0 +1,93 @@
import time
import logging
import requests
from threading import Lock
from typing import Dict, Any

from .oauth2_config import OAuth2Config
from .oauth2_errors import OAuth2AuthenticationError


class OAuth2TokenManager:
    def __init__(self):
        self._tokens: Dict[str, Dict[str, Any]] = {}
        self._lock = Lock()

    def get_access_token(self, config: OAuth2Config) -> str:
        """Get valid access token for the provider, refreshing if necessary"""
        with self._lock:
            return self._get_access_token_internal(config)

    def _get_access_token_internal(self, config: OAuth2Config) -> str:
        """Internal method to get access token (called within lock)"""
        provider_name = config.provider_name
        if provider_name in self._tokens:
            token_data = self._tokens[provider_name]
            if self._is_token_valid(token_data):
                logging.debug(f"Using cached OAuth2 token for provider {provider_name}")
                return token_data["access_token"]
        logging.info(f"Acquiring new OAuth2 token for provider {provider_name}")
        return self._acquire_new_token(config)

    def _is_token_valid(self, token_data: Dict[str, Any]) -> bool:
        """Check if token is still valid (not expired)"""
        if "expires_at" not in token_data:
            return False
        return time.time() < (token_data["expires_at"] - 60)

    def _acquire_new_token(self, config: OAuth2Config, retry_count: int = 3) -> str:
        """Acquire new access token using client credentials flow with retry logic"""
        for attempt in range(retry_count):
            try:
                return self._perform_token_request(config)
            except requests.RequestException as e:
                if attempt == retry_count - 1:
                    raise OAuth2AuthenticationError(
                        f"Failed to acquire OAuth2 token for {config.provider_name} after {retry_count} attempts: {e}",
                        original_error=e
                    )
                wait_time = 2 ** attempt
                logging.warning(f"OAuth2 token request failed for {config.provider_name}, retrying in {wait_time}s (attempt {attempt + 1}/{retry_count}): {e}")
                time.sleep(wait_time)
        raise OAuth2AuthenticationError(f"Unexpected error in token acquisition for {config.provider_name}")

    def _perform_token_request(self, config: OAuth2Config) -> str:
        """Perform the actual token request"""
        payload = {
            "grant_type": "client_credentials",
            "client_id": config.client_id,
            "client_secret": config.client_secret,
        }
        if config.scope:
            payload["scope"] = config.scope
        try:
            response = requests.post(
                config.token_url,
                data=payload,
                timeout=30,
                headers={"Content-Type": "application/x-www-form-urlencoded"}
            )
            response.raise_for_status()
            token_data = response.json()
            access_token = token_data["access_token"]
            expires_in = token_data.get("expires_in", 3600)
            self._tokens[config.provider_name] = {
                "access_token": access_token,
                "expires_at": time.time() + expires_in,
                "token_type": token_data.get("token_type", "Bearer")
            }
            logging.info(f"Successfully acquired OAuth2 token for {config.provider_name}, expires in {expires_in}s")
            return access_token
        except requests.RequestException:
            raise
        except KeyError as e:
            raise OAuth2AuthenticationError(f"Invalid token response from {config.provider_name}: missing {e}")

View File

@@ -1 +0,0 @@
"""Crew-specific utilities."""

View File

@@ -1,16 +0,0 @@
"""Context management utilities for tracking crew and task execution context using OpenTelemetry baggage."""
from typing import Optional
from opentelemetry import baggage
from crewai.utilities.crew.models import CrewContext
def get_crew_context() -> Optional[CrewContext]:
"""Get the current crew context from OpenTelemetry baggage.
Returns:
CrewContext instance containing crew context information, or None if no context is set
"""
return baggage.get_baggage("crew_context")

View File

@@ -1,16 +0,0 @@
"""Models for crew-related data structures."""
from typing import Optional
from pydantic import BaseModel, Field
class CrewContext(BaseModel):
"""Model representing crew context information."""
id: Optional[str] = Field(
default=None, description="Unique identifier for the crew"
)
key: Optional[str] = Field(
default=None, description="Optional crew key/name for identification"
)

View File

@@ -1,4 +1,3 @@
from inspect import getsource
from typing import Any, Callable, Optional, Union
from crewai.utilities.events.base_events import BaseEvent
@@ -17,26 +16,23 @@ class LLMGuardrailStartedEvent(BaseEvent):
retry_count: int
def __init__(self, **data):
from inspect import getsource
from crewai.tasks.llm_guardrail import LLMGuardrail
from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
super().__init__(**data)
if isinstance(self.guardrail, (LLMGuardrail, HallucinationGuardrail)):
if isinstance(self.guardrail, LLMGuardrail) or isinstance(
self.guardrail, HallucinationGuardrail
):
self.guardrail = self.guardrail.description.strip()
elif isinstance(self.guardrail, Callable):
self.guardrail = getsource(self.guardrail).strip()
class LLMGuardrailCompletedEvent(BaseEvent):
"""Event emitted when a guardrail task completes
Attributes:
success: Whether the guardrail validation passed
result: The validation result
error: Error message if validation failed
retry_count: The number of times the guardrail has been retried
"""
"""Event emitted when a guardrail task completes"""
type: str = "llm_guardrail_completed"
success: bool

View File

@@ -1,226 +0,0 @@
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, Callable
from unittest.mock import patch

import pytest

from crewai import Agent, Crew, Task
from crewai.utilities.crew.crew_context import get_crew_context


@pytest.fixture
def simple_agent_factory():
    def create_agent(name: str) -> Agent:
        return Agent(
            role=f"{name} Agent",
            goal=f"Complete {name} task",
            backstory=f"I am agent for {name}",
        )
    return create_agent


@pytest.fixture
def simple_task_factory():
    def create_task(name: str, callback: Callable = None) -> Task:
        return Task(
            description=f"Task for {name}", expected_output="Done", callback=callback
        )
    return create_task


@pytest.fixture
def crew_factory(simple_agent_factory, simple_task_factory):
    def create_crew(name: str, task_callback: Callable = None) -> Crew:
        agent = simple_agent_factory(name)
        task = simple_task_factory(name, callback=task_callback)
        task.agent = agent
        return Crew(agents=[agent], tasks=[task], verbose=False)
    return create_crew


class TestCrewThreadSafety:
    @patch("crewai.Agent.execute_task")
    def test_parallel_crews_thread_safety(self, mock_execute_task, crew_factory):
        mock_execute_task.return_value = "Task completed"
        num_crews = 5

        def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]:
            results = {"crew_id": crew_id, "contexts": []}

            def check_context_task(output):
                context = get_crew_context()
                results["contexts"].append(
                    {
                        "stage": "task_callback",
                        "crew_id": context.id if context else None,
                        "crew_key": context.key if context else None,
                        "thread": threading.current_thread().name,
                    }
                )
                return output

            context_before = get_crew_context()
            results["contexts"].append(
                {
                    "stage": "before_kickoff",
                    "crew_id": context_before.id if context_before else None,
                    "thread": threading.current_thread().name,
                }
            )
            crew = crew_factory(crew_id, task_callback=check_context_task)
            output = crew.kickoff()
            context_after = get_crew_context()
            results["contexts"].append(
                {
                    "stage": "after_kickoff",
                    "crew_id": context_after.id if context_after else None,
                    "thread": threading.current_thread().name,
                }
            )
            results["crew_uuid"] = str(crew.id)
            results["output"] = output.raw
            return results

        with ThreadPoolExecutor(max_workers=num_crews) as executor:
            futures = []
            for i in range(num_crews):
                future = executor.submit(run_crew_with_context_check, f"crew_{i}")
                futures.append(future)
            results = [f.result() for f in futures]

        for result in results:
            crew_uuid = result["crew_uuid"]
            before_ctx = next(
                ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff"
            )
            assert (
                before_ctx["crew_id"] is None
            ), f"Context should be None before kickoff for {result['crew_id']}"
            task_ctx = next(
                ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback"
            )
            assert (
                task_ctx["crew_id"] == crew_uuid
            ), f"Context mismatch during task for {result['crew_id']}"
            after_ctx = next(
                ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff"
            )
            assert (
                after_ctx["crew_id"] is None
            ), f"Context should be None after kickoff for {result['crew_id']}"
            thread_name = before_ctx["thread"]
            assert (
                "ThreadPoolExecutor" in thread_name
            ), f"Should run in thread pool for {result['crew_id']}"

    @pytest.mark.asyncio
    @patch("crewai.Agent.execute_task")
    async def test_async_crews_thread_safety(self, mock_execute_task, crew_factory):
        mock_execute_task.return_value = "Task completed"
        num_crews = 5

        async def run_crew_async(crew_id: str) -> Dict[str, Any]:
            task_context = {"crew_id": crew_id, "context": None}

            def capture_context(output):
                ctx = get_crew_context()
                task_context["context"] = {
                    "crew_id": ctx.id if ctx else None,
                    "crew_key": ctx.key if ctx else None,
                }
                return output

            crew = crew_factory(crew_id, task_callback=capture_context)
            output = await crew.kickoff_async()
            return {
                "crew_id": crew_id,
                "crew_uuid": str(crew.id),
                "output": output.raw,
                "task_context": task_context,
            }

        tasks = [run_crew_async(f"async_crew_{i}") for i in range(num_crews)]
        results = await asyncio.gather(*tasks)
        for result in results:
            crew_uuid = result["crew_uuid"]
            task_ctx = result["task_context"]["context"]
            assert (
                task_ctx is not None
            ), f"Context should exist during task for {result['crew_id']}"
            assert (
                task_ctx["crew_id"] == crew_uuid
            ), f"Context mismatch for {result['crew_id']}"

    @patch("crewai.Agent.execute_task")
    def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory):
        mock_execute_task.return_value = "Task completed"
        contexts_captured = []

        def capture_context(output):
            ctx = get_crew_context()
            contexts_captured.append(
                {
                    "context_id": ctx.id if ctx else None,
                    "thread": threading.current_thread().name,
                }
            )
            return output

        crew = crew_factory("for_each_test", task_callback=capture_context)
        inputs = [{"item": f"input_{i}"} for i in range(3)]
        results = crew.kickoff_for_each(inputs=inputs)
        assert len(results) == len(inputs)
        assert len(contexts_captured) == len(inputs)
        context_ids = [ctx["context_id"] for ctx in contexts_captured]
        assert len(set(context_ids)) == len(
            inputs
        ), "Each execution should have unique context"

    @patch("crewai.Agent.execute_task")
    def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):
        mock_execute_task.return_value = "Task completed"
        contexts = []

        def check_context(output):
            ctx = get_crew_context()
            contexts.append(
                {
                    "context_id": ctx.id if ctx else None,
                    "context_key": ctx.key if ctx else None,
                }
            )
            return output

        def run_crew(name: str):
            crew = crew_factory(name, task_callback=check_context)
            crew.kickoff()
            return str(crew.id)

        crew1_id = run_crew("First")
        crew2_id = run_crew("Second")
        assert len(contexts) == 2
        assert contexts[0]["context_id"] == crew1_id
        assert contexts[1]["context_id"] == crew2_id
        assert contexts[0]["context_id"] != contexts[1]["context_id"]

tests/test_oauth2_llm.py Normal file
View File

@@ -0,0 +1,420 @@
import json
import pytest
import requests
import tempfile
import time
import threading
from pathlib import Path
from unittest.mock import Mock, patch

from crewai.llm import LLM
from crewai.llms.oauth2_config import OAuth2Config, OAuth2ConfigLoader
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
from crewai.llms.oauth2_errors import OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError


class TestOAuth2Config:
    def test_oauth2_config_creation(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
            scope="read write"
        )
        assert config.client_id == "test_client"
        assert config.provider_name == "test_provider"
        assert config.scope == "read write"

    def test_oauth2_config_loader_missing_file(self):
        loader = OAuth2ConfigLoader("nonexistent.json")
        configs = loader.load_config()
        assert configs == {}

    def test_oauth2_config_loader_valid_config(self):
        config_data = {
            "oauth2_providers": {
                "custom_provider": {
                    "client_id": "test_client",
                    "client_secret": "test_secret",
                    "token_url": "https://example.com/token",
                    "scope": "read"
                }
            }
        }
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(config_data, f)
            config_path = f.name
        try:
            loader = OAuth2ConfigLoader(config_path)
            configs = loader.load_config()
            assert "custom_provider" in configs
            config = configs["custom_provider"]
            assert config.client_id == "test_client"
            assert config.provider_name == "custom_provider"
        finally:
            Path(config_path).unlink()

    def test_oauth2_config_loader_invalid_json(self):
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            f.write("invalid json")
            config_path = f.name
        try:
            loader = OAuth2ConfigLoader(config_path)
            with pytest.raises(OAuth2ConfigurationError, match="Invalid OAuth2 configuration"):
                loader.load_config()
        finally:
            Path(config_path).unlink()


class TestOAuth2TokenManager:
    def test_token_acquisition_success(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )
        mock_response = Mock()
        mock_response.json.return_value = {
            "access_token": "test_token_123",
            "token_type": "Bearer",
            "expires_in": 3600
        }
        mock_response.raise_for_status.return_value = None
        with patch('requests.post', return_value=mock_response) as mock_post:
            manager = OAuth2TokenManager()
            token = manager.get_access_token(config)
            assert token == "test_token_123"
            mock_post.assert_called_once()

    def test_token_caching(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )
        mock_response = Mock()
        mock_response.json.return_value = {
            "access_token": "test_token_123",
            "token_type": "Bearer",
            "expires_in": 3600
        }
        mock_response.raise_for_status.return_value = None
        with patch('requests.post', return_value=mock_response) as mock_post:
            manager = OAuth2TokenManager()
            token1 = manager.get_access_token(config)
            assert token1 == "test_token_123"
            assert mock_post.call_count == 1
            token2 = manager.get_access_token(config)
            assert token2 == "test_token_123"
            assert mock_post.call_count == 1

    def test_token_refresh_on_expiry(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )
        mock_response = Mock()
        mock_response.json.return_value = {
            "access_token": "new_token_456",
            "token_type": "Bearer",
            "expires_in": 3600
        }
        mock_response.raise_for_status.return_value = None
        with patch('requests.post', return_value=mock_response):
            manager = OAuth2TokenManager()
            manager._tokens["test_provider"] = {
                "access_token": "old_token",
                "expires_at": time.time() - 100
            }
            token = manager.get_access_token(config)
            assert token == "new_token_456"

    def test_token_acquisition_failure(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )
        with patch('requests.post', side_effect=requests.RequestException("Network error")):
            with patch('time.sleep'):
                manager = OAuth2TokenManager()
                with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token"):
                    manager.get_access_token(config)

    def test_invalid_token_response(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )
        mock_response = Mock()
        mock_response.json.return_value = {"invalid": "response"}
        mock_response.raise_for_status.return_value = None
        with patch('requests.post', return_value=mock_response):
            manager = OAuth2TokenManager()
            with pytest.raises(OAuth2AuthenticationError, match="Invalid token response"):
                manager.get_access_token(config)


class TestLLMOAuth2Integration:
    def test_llm_with_oauth2_config(self):
        config_data = {
            "oauth2_providers": {
                "custom": {
                    "client_id": "test_client",
                    "client_secret": "test_secret",
                    "token_url": "https://example.com/token"
                }
            }
        }
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(config_data, f)
            config_path = f.name
        try:
            llm = LLM(
                model="custom/test-model",
                oauth2_config_path=config_path
            )
            assert "custom" in llm.oauth2_configs
            assert llm.oauth2_configs["custom"].client_id == "test_client"
        finally:
            Path(config_path).unlink()

    def test_llm_without_oauth2_config(self):
        llm = LLM(model="openai/gpt-3.5-turbo")
        assert llm.oauth2_configs == {}

    @patch('crewai.llm.litellm.completion')
    def test_llm_oauth2_token_injection(self, mock_completion):
        config_data = {
            "oauth2_providers": {
                "custom": {
                    "client_id": "test_client",
                    "client_secret": "test_secret",
                    "token_url": "https://example.com/token"
                }
            }
        }
        mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))])
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(config_data, f)
            config_path = f.name
        try:
            with patch.object(OAuth2TokenManager, 'get_access_token', return_value="oauth_token_123"):
                llm = LLM(
                    model="custom/test-model",
                    oauth2_config_path=config_path
                )
                llm.call("test message")
                call_args = mock_completion.call_args
                assert call_args[1]['api_key'] == "oauth_token_123"
        finally:
            Path(config_path).unlink()

    @patch('crewai.llm.litellm.completion')
    def test_llm_oauth2_authentication_failure(self, mock_completion):
        config_data = {
            "oauth2_providers": {
                "custom": {
                    "client_id": "test_client",
                    "client_secret": "test_secret",
                    "token_url": "https://example.com/token"
                }
            }
        }
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(config_data, f)
            config_path = f.name
        try:
            with patch.object(OAuth2TokenManager, 'get_access_token', side_effect=OAuth2AuthenticationError("Auth failed")):
                llm = LLM(
                    model="custom/test-model",
                    oauth2_config_path=config_path
                )
                with pytest.raises(OAuth2AuthenticationError, match="Auth failed"):
                    llm.call("test message")
        finally:
            Path(config_path).unlink()

    @patch('crewai.llm.litellm.completion')
    def test_llm_non_oauth2_provider_unchanged(self, mock_completion):
        mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))])
        llm = LLM(
            model="openai/gpt-3.5-turbo",
            api_key="original_key"
        )
        llm.call("test message")
        call_args = mock_completion.call_args
        assert call_args[1]['api_key'] == "original_key"


class TestOAuth2Validation:
    def test_valid_url_validation(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )
        assert config.token_url == "https://example.com/token"

    def test_invalid_url_validation(self):
        with pytest.raises(OAuth2ValidationError, match="Invalid token URL format"):
            OAuth2Config(
                client_id="test_client",
                client_secret="test_secret",
                token_url="not-a-url",
                provider_name="test_provider"
            )

    def test_valid_scope_validation(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
            scope="read write"
        )
        assert config.scope == "read write"

    def test_invalid_scope_validation(self):
        # A double space yields an empty value when the scope is split on spaces
        with pytest.raises(OAuth2ValidationError, match="Invalid scope format"):
            OAuth2Config(
                client_id="test_client",
                client_secret="test_secret",
                token_url="https://example.com/token",
                provider_name="test_provider",
                scope="read  write"
            )


class TestOAuth2RetryMechanism:
    def test_retry_mechanism_success_on_second_attempt(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )
        responses = [
            requests.RequestException("Network error"),
            Mock(json=lambda: {"access_token": "token_123", "expires_in": 3600}, raise_for_status=lambda: None)
        ]
        with patch('requests.post', side_effect=responses):
            with patch('time.sleep'):
                manager = OAuth2TokenManager()
                token = manager.get_access_token(config)
                assert token == "token_123"

    def test_retry_mechanism_exhausted(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )
        with patch('requests.post', side_effect=requests.RequestException("Network error")):
            with patch('time.sleep'):
                manager = OAuth2TokenManager()
                with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token.*after 3 attempts"):
                    manager.get_access_token(config)


class TestOAuth2ThreadSafety:
    def test_concurrent_token_requests(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )
        mock_response = Mock()
        mock_response.json.return_value = {"access_token": "token_123", "expires_in": 3600}
        mock_response.raise_for_status.return_value = None
        call_count = 0

        def mock_post(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            time.sleep(0.1)
            return mock_response

        with patch('requests.post', side_effect=mock_post):
            manager = OAuth2TokenManager()
            results = []

            def get_token():
                token = manager.get_access_token(config)
                results.append(token)

            threads = [threading.Thread(target=get_token) for _ in range(5)]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()
            assert all(token == "token_123" for token in results)
            assert call_count == 1


class TestOAuth2ErrorClasses:
    def test_oauth2_configuration_error(self):
        error = OAuth2ConfigurationError("Config error")
        assert str(error) == "Config error"
        assert isinstance(error, OAuth2Error)

    def test_oauth2_authentication_error_with_original(self):
        original = requests.RequestException("Network error")
        error = OAuth2AuthenticationError("Auth failed", original_error=original)
        assert str(error) == "Auth failed"
        assert error.original_error == original

    def test_oauth2_validation_error(self):
        error = OAuth2ValidationError("Validation failed")
        assert str(error) == "Validation failed"
        assert isinstance(error, OAuth2Error)

View File

@@ -1,88 +0,0 @@
import uuid

import pytest
from opentelemetry import baggage
from opentelemetry.context import attach, detach

from crewai.utilities.crew.crew_context import get_crew_context
from crewai.utilities.crew.models import CrewContext


def test_crew_context_creation():
    crew_id = str(uuid.uuid4())
    context = CrewContext(id=crew_id, key="test-crew")
    assert context.id == crew_id
    assert context.key == "test-crew"


def test_get_crew_context_with_baggage():
    crew_id = str(uuid.uuid4())
    assert get_crew_context() is None
    crew_ctx = CrewContext(id=crew_id, key="test-key")
    ctx = baggage.set_baggage("crew_context", crew_ctx)
    token = attach(ctx)
    try:
        context = get_crew_context()
        assert context is not None
        assert context.id == crew_id
        assert context.key == "test-key"
    finally:
        detach(token)
    assert get_crew_context() is None


def test_get_crew_context_empty():
    assert get_crew_context() is None


def test_baggage_nested_contexts():
    crew_id1 = str(uuid.uuid4())
    crew_id2 = str(uuid.uuid4())
    crew_ctx1 = CrewContext(id=crew_id1, key="outer")
    ctx1 = baggage.set_baggage("crew_context", crew_ctx1)
    token1 = attach(ctx1)
    try:
        outer_context = get_crew_context()
        assert outer_context.id == crew_id1
        assert outer_context.key == "outer"
        crew_ctx2 = CrewContext(id=crew_id2, key="inner")
        ctx2 = baggage.set_baggage("crew_context", crew_ctx2)
        token2 = attach(ctx2)
        try:
            inner_context = get_crew_context()
            assert inner_context.id == crew_id2
            assert inner_context.key == "inner"
        finally:
            detach(token2)
        restored_context = get_crew_context()
        assert restored_context.id == crew_id1
        assert restored_context.key == "outer"
    finally:
        detach(token1)
    assert get_crew_context() is None


def test_baggage_exception_handling():
    crew_id = str(uuid.uuid4())
    crew_ctx = CrewContext(id=crew_id, key="test")
    ctx = baggage.set_baggage("crew_context", crew_ctx)
    token = attach(ctx)
    with pytest.raises(ValueError):
        try:
            assert get_crew_context() is not None
            raise ValueError("Test exception")
        finally:
            detach(token)
    assert get_crew_context() is None