mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-01 04:08:30 +00:00

Compare commits: gl/chore/a ... devin/1751

3 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 4379ad26d1 |  |
|  | 03ee4c59eb |  |
|  | ac1080afd8 |  |
@@ -5,7 +5,3 @@ repos:
       - id: ruff
         args: ["--fix"]
       - id: ruff-format
-  - repo: https://github.com/commitizen-tools/commitizen
-    rev: v3.13.0
-    hooks:
-      - id: commitizen
@@ -94,7 +94,7 @@
            "pages": [
              "en/guides/advanced/customizing-prompts",
              "en/guides/advanced/fingerprinting"
            ]
          }
        ]
@@ -296,8 +296,7 @@
               "en/enterprise/features/webhook-streaming",
               "en/enterprise/features/traces",
               "en/enterprise/features/hallucination-guardrail",
-              "en/enterprise/features/integrations",
-              "en/enterprise/features/agent-repositories"
+              "en/enterprise/features/integrations"
             ]
           },
           {
@@ -374,7 +373,7 @@
            }
          ]
        }
      ]
    },
    {
@@ -731,7 +730,7 @@
            }
          ]
        }
      ]
    }
  ]
@@ -775,7 +774,7 @@
      "destination": "/en/introduction"
    },
    {
      "source": "/installation",
      "destination": "/en/installation"
    },
    {
@@ -526,103 +526,6 @@ agent = Agent(
The context window management feature works automatically in the background. You don't need to call any special functions - just set `respect_context_window` to your preferred behavior and CrewAI handles the rest!
</Note>
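A minimal sketch of opting into this behavior (the `respect_context_window` flag comes from the note above; the other agent fields are illustrative):

```python Code
from crewai import Agent

agent = Agent(
    role="Research Analyst",
    goal="Summarize very long documents",
    backstory="An analyst who regularly processes large inputs",
    respect_context_window=True,  # let CrewAI handle context overflow instead of failing
)
```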
## Direct Agent Interaction with `kickoff()`

Agents can be used directly, without going through a task or crew workflow, via the `kickoff()` method. This provides a simpler way to interact with an agent when you don't need full crew orchestration.

### How `kickoff()` Works

The `kickoff()` method lets you send messages directly to an agent and get a response, much like interacting with an LLM, but with all of the agent's capabilities (tools, reasoning, etc.).

```python Code
from crewai import Agent
from crewai_tools import SerperDevTool

# Create an agent
researcher = Agent(
    role="AI Technology Researcher",
    goal="Research the latest AI developments",
    tools=[SerperDevTool()],
    verbose=True
)

# Use kickoff() to interact directly with the agent
result = researcher.kickoff("What are the latest developments in language models?")

# Access the raw response
print(result.raw)
```
### Parameters and Return Values

| Parameter | Type | Description |
| :---------------- | :---------------------------------- | :------------------------------------------------------------------------ |
| `messages` | `Union[str, List[Dict[str, str]]]` | Either a string query or a list of message dictionaries with role/content |
| `response_format` | `Optional[Type[Any]]` | Optional Pydantic model for structured output |

The method returns a `LiteAgentOutput` object with the following properties:

- `raw`: String containing the raw output text
- `pydantic`: Parsed Pydantic model (if a `response_format` was provided)
- `agent_role`: Role of the agent that produced the output
- `usage_metrics`: Token usage metrics for the execution
### Structured Output

You can get structured output by providing a Pydantic model as the `response_format`:

```python Code
from pydantic import BaseModel
from typing import List

class ResearchFindings(BaseModel):
    main_points: List[str]
    key_technologies: List[str]
    future_predictions: str

# Get structured output
result = researcher.kickoff(
    "Summarize the latest developments in AI for 2025",
    response_format=ResearchFindings
)

# Access structured data
print(result.pydantic.main_points)
print(result.pydantic.future_predictions)
```
### Multiple Messages

You can also provide a conversation history as a list of message dictionaries:

```python Code
messages = [
    {"role": "user", "content": "I need information about large language models"},
    {"role": "assistant", "content": "I'd be happy to help with that! What specifically would you like to know?"},
    {"role": "user", "content": "What are the latest developments in 2025?"}
]

result = researcher.kickoff(messages)
```
### Async Support

An asynchronous version is available via `kickoff_async()` with the same parameters:

```python Code
import asyncio

async def main():
    result = await researcher.kickoff_async("What are the latest developments in AI?")
    print(result.raw)

asyncio.run(main())
```
<Note>
The `kickoff()` method uses a `LiteAgent` internally, which provides a simpler execution flow while preserving all of the agent's configuration (role, goal, backstory, tools, etc.).
</Note>

## Important Considerations and Best Practices

### Security and Code Execution
@@ -1,155 +0,0 @@
---
title: 'Agent Repositories'
description: 'Learn how to use Agent Repositories to share and reuse your agents across teams and projects'
icon: 'database'
---
Agent Repositories allow enterprise users to store, share, and reuse agent definitions across teams and projects. This feature enables organizations to maintain a centralized library of standardized agents, promoting consistency and reducing duplication of effort.

## Benefits of Agent Repositories

- **Standardization**: Maintain consistent agent definitions across your organization
- **Reusability**: Create an agent once and use it in multiple crews and projects
- **Governance**: Implement organization-wide policies for agent configurations
- **Collaboration**: Enable teams to share and build upon each other's work
## Using Agent Repositories

### Prerequisites

1. You must have a CrewAI account; you can start with the [free plan](https://app.crewai.com).
2. You need to be authenticated using the CrewAI CLI.
3. If you have more than one organization, make sure you are switched to the correct one using the CLI command:

```bash
crewai org switch <org_id>
```

### Creating and Managing Agents in Repositories

To create and manage agents in repositories, use the Enterprise Dashboard.
### Loading Agents from Repositories

You can load agents from repositories in your code using the `from_repository` parameter:

```python
from crewai import Agent

# Create an agent by loading it from a repository
# The agent is loaded with all its predefined configurations
researcher = Agent(
    from_repository="market-research-agent"
)
```
### Overriding Repository Settings

You can override specific settings from the repository by providing them in the configuration:

```python
researcher = Agent(
    from_repository="market-research-agent",
    goal="Research the latest trends in AI development",  # Override the repository goal
    verbose=True  # Add a setting not in the repository
)
```
### Example: Creating a Crew with Repository Agents

```python
from crewai import Crew, Agent, Task

# Load agents from repositories
researcher = Agent(
    from_repository="market-research-agent"
)

writer = Agent(
    from_repository="content-writer-agent"
)

# Create tasks
research_task = Task(
    description="Research the latest trends in AI",
    agent=researcher
)

writing_task = Task(
    description="Write a comprehensive report based on the research",
    agent=writer
)

# Create the crew
crew = Crew(
    agents=[researcher, writer],
    tasks=[research_task, writing_task],
    verbose=True
)

# Run the crew
result = crew.kickoff()
```
### Example: Using `kickoff()` with Repository Agents

You can also use repository agents directly with the `kickoff()` method for simpler interactions:

```python
from crewai import Agent
from pydantic import BaseModel
from typing import List

# Define a structured output format
class MarketAnalysis(BaseModel):
    key_trends: List[str]
    opportunities: List[str]
    recommendation: str

# Load an agent from a repository
analyst = Agent(
    from_repository="market-analyst-agent",
    verbose=True
)

# Get a free-form response
result = analyst.kickoff("Analyze the AI market in 2025")
print(result.raw)  # Access the raw response

# Get structured output
structured_result = analyst.kickoff(
    "Provide a structured analysis of the AI market in 2025",
    response_format=MarketAnalysis
)

# Access structured data
print(f"Key Trends: {structured_result.pydantic.key_trends}")
print(f"Recommendation: {structured_result.pydantic.recommendation}")
```
## Best Practices

1. **Naming Convention**: Use clear, descriptive names for your repository agents
2. **Documentation**: Include comprehensive descriptions for each agent
3. **Tool Management**: Ensure that tools referenced by repository agents are available in your environment
4. **Access Control**: Manage permissions to ensure only authorized team members can modify repository agents
## Organization Management

To switch between organizations or see your current organization, use the CrewAI CLI:

```bash
# View current organization
crewai org current

# Switch to a different organization
crewai org switch <org_id>

# List all available organizations
crewai org list
```
<Note>
When loading agents from repositories, you must be authenticated and switched to the correct organization. If you receive errors, check your authentication status and organization settings using the CLI commands above.
</Note>

docs/oauth2_llm_providers.md (new file, 216 lines)
@@ -0,0 +1,216 @@
# OAuth2 LLM Providers

CrewAI supports OAuth2 authentication for custom LiteLLM providers through configuration files.

## Features

- **Automatic Token Management**: Handles OAuth2 token acquisition and refresh automatically
- **Multiple Provider Support**: Configure multiple OAuth2 providers in a single configuration file
- **Secure Credential Storage**: Keep OAuth2 credentials separate from your code
- **Seamless Integration**: Works with existing LiteLLM provider configurations
- **Error Handling**: Comprehensive error handling for authentication failures
- **Retry Mechanism**: Automatic retry with exponential backoff for token acquisition
- **Thread Safety**: Concurrent access protection for token caching
- **Configuration Validation**: Automatic validation of URLs and scope formats
## Error Handling

The OAuth2 implementation provides specific error classes for different failure scenarios:

### OAuth2ConfigurationError
Raised when there are issues with the OAuth2 configuration:
- Invalid configuration file format
- Missing required fields
- Unknown OAuth2 providers

### OAuth2AuthenticationError
Raised when OAuth2 authentication fails:
- Network errors during token acquisition
- Invalid credentials
- Token endpoint errors

### OAuth2ValidationError
Raised when configuration validation fails:
- Invalid URL formats
- Malformed scope values
### Example Error Handling

```python
from crewai import LLM
from crewai.llms.oauth2_errors import OAuth2ConfigurationError, OAuth2AuthenticationError

try:
    llm = LLM(
        model="my_custom_provider/my-model",
        oauth2_config_path="./litellm_config.json"
    )
    result = llm.call("test message")
except OAuth2ConfigurationError as e:
    print(f"Configuration error: {e}")
except OAuth2AuthenticationError as e:
    print(f"Authentication failed: {e}")
    if e.original_error:
        print(f"Original error: {e.original_error}")
```
## Configuration Validation

The OAuth2 configuration is automatically validated:

- **URL Validation**: `token_url` must be a valid HTTP/HTTPS URL
- **Scope Validation**: `scope` cannot contain empty values when split by spaces
- **Required Fields**: `client_id`, `client_secret`, `token_url`, and `provider_name` are required

A sketch of a failing validation follows.
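This uses the `OAuth2Config` model and `OAuth2ValidationError` defined later in this diff; the field values are placeholders:

```python
from crewai.llms.oauth2_config import OAuth2Config
from crewai.llms.oauth2_errors import OAuth2ValidationError

try:
    OAuth2Config(
        client_id="id",
        client_secret="secret",
        token_url="not-a-url",  # fails the HTTP/HTTPS URL check
        provider_name="demo",
    )
except OAuth2ValidationError as e:
    print(f"Rejected: {e}")
```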
## Retry Mechanism

Token acquisition includes automatic retry with exponential backoff:

- 3 retry attempts by default
- Exponential backoff between attempts: 1s, then 2s (the final failure raises instead of sleeping)
- Detailed logging of retry attempts
- Thread-safe token caching

The resulting delay schedule is sketched below.
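The delay is computed as `2 ** attempt`, mirroring `OAuth2TokenManager._acquire_new_token` later in this diff:

```python
# Delays observed between the three attempts: 1s after the first failure,
# 2s after the second; the third failure raises OAuth2AuthenticationError.
retry_count = 3
delays = [2 ** attempt for attempt in range(retry_count - 1)]
print(delays)  # [1, 2]
```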
## Configuration

Create a `litellm_config.json` file in your project directory:

```json
{
  "oauth2_providers": {
    "my_custom_provider": {
      "client_id": "your_client_id",
      "client_secret": "your_client_secret",
      "token_url": "https://your-provider.com/oauth/token",
      "scope": "llm.read llm.write"
    },
    "another_provider": {
      "client_id": "another_client_id",
      "client_secret": "another_client_secret",
      "token_url": "https://another-provider.com/token"
    }
  }
}
```
## Usage

```python
from crewai import LLM

# Initialize LLM with OAuth2 support
llm = LLM(
    model="my_custom_provider/my-model",
    oauth2_config_path="./litellm_config.json"  # Optional, defaults to ./litellm_config.json
)

# Use in CrewAI
from crewai import Agent, Task, Crew

agent = Agent(
    role="Data Analyst",
    goal="Analyze data trends",
    backstory="Expert in data analysis",
    llm=llm
)

task = Task(
    description="Analyze the latest sales data",
    agent=agent
)

crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
```
## Environment Variables

You can also use environment variables in your configuration:

```json
{
  "oauth2_providers": {
    "my_provider": {
      "client_id": "os.environ/MY_CLIENT_ID",
      "client_secret": "os.environ/MY_CLIENT_SECRET",
      "token_url": "https://my-provider.com/token"
    }
  }
}
```
## Supported OAuth2 Flow

Currently only the **Client Credentials** OAuth2 flow is supported, which is suitable for server-to-server authentication. A sketch of the underlying token request follows.
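Under the hood this is a single form-encoded POST to the token endpoint, mirroring `_perform_token_request` in the token manager later in this diff (the URL and credentials here are placeholders):

```python
import requests

response = requests.post(
    "https://your-provider.com/oauth/token",  # placeholder token endpoint
    data={
        "grant_type": "client_credentials",
        "client_id": "your_client_id",
        "client_secret": "your_client_secret",
        "scope": "llm.read llm.write",  # optional
    },
    headers={"Content-Type": "application/x-www-form-urlencoded"},
    timeout=30,
)
response.raise_for_status()
access_token = response.json()["access_token"]
```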
## Token Management

- Tokens are automatically cached and refreshed when they expire
- A 60-second buffer is used before token expiration to ensure reliability
- Failed token acquisition raises an `OAuth2AuthenticationError` with details

The expiry check behind the cache is sketched below.
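A sketch of `_is_token_valid` from the token manager later in this diff:

```python
import time

def is_token_valid(token_data: dict) -> bool:
    # Treat tokens as expired 60 seconds early so a request never
    # starts with a token that lapses mid-flight.
    if "expires_at" not in token_data:
        return False
    return time.time() < token_data["expires_at"] - 60
```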
## Configuration Schema

The `litellm_config.json` file should follow this schema:

```json
{
  "oauth2_providers": {
    "<provider_name>": {
      "client_id": "string (required)",
      "client_secret": "string (required)",
      "token_url": "string (required)",
      "scope": "string (optional)",
      "refresh_token": "string (optional)"
    }
  }
}
```
## Legacy Error Handling

For backward compatibility, note how the new error types map onto the old ones:

- OAuth2 authentication failures raise `OAuth2AuthenticationError` (previously `RuntimeError`)
- Invalid configuration files raise `OAuth2ConfigurationError` (previously `ValueError`)
- Network errors during token acquisition are wrapped in `OAuth2AuthenticationError`

A migration sketch follows.
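For call sites written against the previous behavior, one hedged migration pattern is to catch both the old and new types until everything is updated (`OAuth2ConfigLoader` is the loader defined later in this diff):

```python
from crewai.llms.oauth2_config import OAuth2ConfigLoader
from crewai.llms.oauth2_errors import OAuth2ConfigurationError

try:
    configs = OAuth2ConfigLoader("./litellm_config.json").load_config()
except (OAuth2ConfigurationError, ValueError) as e:
    # New code raises OAuth2ConfigurationError; ValueError is kept for
    # code written against the previous behavior.
    print(f"Bad config: {e}")
```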
## Examples

### Basic OAuth2 Provider

```python
from crewai import LLM

llm = LLM(
    model="my_provider/gpt-4",
    oauth2_config_path="./config.json"
)

response = llm.call("Hello, world!")
print(response)
```
### Multiple Providers

```json
{
  "oauth2_providers": {
    "provider_a": {
      "client_id": "client_a",
      "client_secret": "secret_a",
      "token_url": "https://provider-a.com/token"
    },
    "provider_b": {
      "client_id": "client_b",
      "client_secret": "secret_b",
      "token_url": "https://provider-b.com/oauth/token",
      "scope": "read write"
    }
  }
}
```

```python
# Use different providers
llm_a = LLM(model="provider_a/model-1", oauth2_config_path="./config.json")
llm_b = LLM(model="provider_b/model-2", oauth2_config_path="./config.json")
```
examples/oauth2_llm_example.py (new file, 64 lines)

@@ -0,0 +1,64 @@
"""
|
||||
Example demonstrating OAuth2 authentication with custom LLM providers in CrewAI.
|
||||
|
||||
This example shows how to configure and use OAuth2-authenticated LLM providers.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from crewai import Agent, Task, Crew, LLM
|
||||
|
||||
def create_example_config():
|
||||
"""Create an example OAuth2 configuration file."""
|
||||
config = {
|
||||
"oauth2_providers": {
|
||||
"my_custom_provider": {
|
||||
"client_id": "your_client_id_here",
|
||||
"client_secret": "your_client_secret_here",
|
||||
"token_url": "https://your-provider.com/oauth/token",
|
||||
"scope": "llm.read llm.write"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config_path = Path("example_oauth2_config.json")
|
||||
with open(config_path, 'w') as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
print(f"Created example config at {config_path}")
|
||||
return config_path
|
||||
|
||||
def main():
|
||||
config_path = create_example_config()
|
||||
|
||||
try:
|
||||
llm = LLM(
|
||||
model="my_custom_provider/my-model",
|
||||
oauth2_config_path=str(config_path)
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
role="Research Assistant",
|
||||
goal="Provide helpful research insights",
|
||||
backstory="An AI assistant specialized in research and analysis",
|
||||
llm=llm
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Research the benefits of OAuth2 authentication in AI systems",
|
||||
agent=agent,
|
||||
expected_output="A comprehensive summary of OAuth2 benefits"
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
|
||||
print("Running crew with OAuth2-authenticated LLM...")
|
||||
result = crew.kickoff()
|
||||
print(f"Result: {result}")
|
||||
|
||||
finally:
|
||||
if config_path.exists():
|
||||
config_path.unlink()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -18,11 +18,6 @@ from typing import (
     cast,
 )
 
-from opentelemetry import baggage
-from opentelemetry.context import attach, detach
-
-from crewai.utilities.crew.models import CrewContext
-
 from pydantic import (
     UUID4,
     BaseModel,
@@ -621,11 +616,6 @@ class Crew(FlowTrackable, BaseModel):
         self,
         inputs: Optional[Dict[str, Any]] = None,
     ) -> CrewOutput:
-        ctx = baggage.set_baggage(
-            "crew_context", CrewContext(id=str(self.id), key=self.key)
-        )
-        token = attach(ctx)
-
         try:
             for before_callback in self.before_kickoff_callbacks:
                 if inputs is None:
@@ -686,8 +676,6 @@ class Crew(FlowTrackable, BaseModel):
                 CrewKickoffFailedEvent(error=str(e), crew_name=self.name or "crew"),
             )
             raise
-        finally:
-            detach(token)
 
     def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
         """Executes the Crew's workflow for each input in the list and aggregates results."""
@@ -52,6 +52,9 @@ import io
 from typing import TextIO
 
 from crewai.llms.base_llm import BaseLLM
+from crewai.llms.oauth2_config import OAuth2ConfigLoader
+from crewai.llms.oauth2_token_manager import OAuth2TokenManager
+from crewai.llms.oauth2_errors import OAuth2ConfigurationError
 from crewai.utilities.events import crewai_event_bus
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededException,
@@ -311,6 +314,7 @@ class LLM(BaseLLM):
         callbacks: List[Any] = [],
         reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
         stream: bool = False,
+        oauth2_config_path: Optional[str] = None,
         **kwargs,
     ):
         self.model = model
@@ -338,6 +342,10 @@ class LLM(BaseLLM):
         self.is_anthropic = self._is_anthropic_model(model)
         self.stream = stream
 
+        self.oauth2_config_loader = OAuth2ConfigLoader(oauth2_config_path)
+        self.oauth2_token_manager = OAuth2TokenManager()
+        self.oauth2_configs = self.oauth2_config_loader.load_config()
+
         litellm.drop_params = True
 
         # Normalize self.stop to always be a List[str]
@@ -384,7 +392,21 @@ class LLM(BaseLLM):
             messages = [{"role": "user", "content": messages}]
         formatted_messages = self._format_messages_for_provider(messages)
 
-        # --- 2) Prepare the parameters for the completion call
+        # --- 2) Resolve the API key, using OAuth2 when configured
+        api_key = self.api_key
+        provider = self._get_custom_llm_provider()
+
+        if provider and provider in self.oauth2_configs:
+            oauth2_config = self.oauth2_configs[provider]
+            try:
+                self._validate_oauth2_provider(provider)
+                access_token = self.oauth2_token_manager.get_access_token(oauth2_config)
+                api_key = access_token
+                logging.debug(f"Using OAuth2 authentication for provider {provider}")
+            except Exception as e:
+                logging.error(f"OAuth2 authentication failed for provider {provider}: {e}")
+                raise
+
+        # --- 3) Prepare the parameters for the completion call
         params = {
             "model": self.model,
             "messages": formatted_messages,
@@ -404,7 +426,7 @@ class LLM(BaseLLM):
             "api_base": self.api_base,
             "base_url": self.base_url,
             "api_version": self.api_version,
-            "api_key": self.api_key,
+            "api_key": api_key,
             "stream": self.stream,
             "tools": tools,
             "reasoning_effort": self.reasoning_effort,
@@ -1076,6 +1098,12 @@ class LLM(BaseLLM):
             return self.model.split("/")[0]
         return None
 
+    def _validate_oauth2_provider(self, provider: str) -> bool:
+        """Validate that the OAuth2 provider exists in the configuration."""
+        if provider not in self.oauth2_configs:
+            raise OAuth2ConfigurationError(f"Unknown OAuth2 provider: {provider}")
+        return True
+
     def _validate_call_params(self) -> None:
         """
         Validate parameters before making a call. Currently this only checks if
@@ -1 +1,17 @@
 """LLM implementations for crewAI."""
+
+from .base_llm import BaseLLM
+from .oauth2_config import OAuth2Config, OAuth2ConfigLoader
+from .oauth2_token_manager import OAuth2TokenManager
+from .oauth2_errors import OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError
+
+__all__ = [
+    "BaseLLM",
+    "OAuth2Config",
+    "OAuth2ConfigLoader",
+    "OAuth2TokenManager",
+    "OAuth2Error",
+    "OAuth2ConfigurationError",
+    "OAuth2AuthenticationError",
+    "OAuth2ValidationError",
+]
src/crewai/llms/oauth2_config.py (new file, 65 lines)

@@ -0,0 +1,65 @@
from pathlib import Path
from typing import Dict, Optional
import json
import re
from pydantic import BaseModel, Field, field_validator
from .oauth2_errors import OAuth2ConfigurationError, OAuth2ValidationError


class OAuth2Config(BaseModel):
    client_id: str = Field(description="OAuth2 client ID")
    client_secret: str = Field(description="OAuth2 client secret")
    token_url: str = Field(description="OAuth2 token endpoint URL")
    scope: Optional[str] = Field(default=None, description="OAuth2 scope")
    provider_name: str = Field(description="Custom provider name")
    refresh_token: Optional[str] = Field(default=None, description="OAuth2 refresh token")

    @field_validator('token_url')
    @classmethod
    def validate_token_url(cls, v: str) -> str:
        """Validate that token_url is a valid HTTP/HTTPS URL."""
        url_pattern = re.compile(
            r'^https?://'  # http:// or https://
            r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+[A-Z]{2,6}\.?|'  # domain...
            r'localhost|'  # localhost...
            r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})'  # ...or ip
            r'(?::\d+)?'  # optional port
            r'(?:/?|[/?]\S+)$', re.IGNORECASE)

        if not url_pattern.match(v):
            raise OAuth2ValidationError(f"Invalid token URL format: {v}")
        return v

    @field_validator('scope')
    @classmethod
    def validate_scope(cls, v: Optional[str]) -> Optional[str]:
        """Validate OAuth2 scope format: no empty values when split by spaces."""
        # A scope such as "llm.read llm.write" is valid; consecutive or
        # leading/trailing spaces produce empty values and are rejected.
        if v and any(not part for part in v.split(' ')):
            raise OAuth2ValidationError("Invalid scope format: scope cannot contain empty values")
        return v


class OAuth2ConfigLoader:
    def __init__(self, config_path: Optional[str] = None):
        self.config_path = Path(config_path) if config_path else Path("litellm_config.json")

    def load_config(self) -> Dict[str, OAuth2Config]:
        """Load OAuth2 configurations from litellm_config.json"""
        if not self.config_path.exists():
            return {}

        try:
            with open(self.config_path, 'r') as f:
                data = json.load(f)

            oauth2_configs = {}
            for provider_name, config_data in data.get("oauth2_providers", {}).items():
                oauth2_configs[provider_name] = OAuth2Config(
                    provider_name=provider_name,
                    **config_data
                )

            return oauth2_configs
        except (json.JSONDecodeError, KeyError, ValueError) as e:
            raise OAuth2ConfigurationError(f"Invalid OAuth2 configuration in {self.config_path}: {e}")
src/crewai/llms/oauth2_errors.py (new file, 26 lines)

@@ -0,0 +1,26 @@
"""OAuth2 error classes for CrewAI."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class OAuth2Error(Exception):
|
||||
"""Base exception class for OAuth2 operation errors."""
|
||||
|
||||
def __init__(self, message: str, original_error: Optional[Exception] = None):
|
||||
super().__init__(message)
|
||||
self.original_error = original_error
|
||||
|
||||
|
||||
class OAuth2ConfigurationError(OAuth2Error):
|
||||
"""Exception raised for OAuth2 configuration errors."""
|
||||
pass
|
||||
|
||||
|
||||
class OAuth2AuthenticationError(OAuth2Error):
|
||||
"""Exception raised for OAuth2 authentication failures."""
|
||||
pass
|
||||
|
||||
|
||||
class OAuth2ValidationError(OAuth2Error):
|
||||
"""Exception raised for OAuth2 validation errors."""
|
||||
pass
|
||||
src/crewai/llms/oauth2_token_manager.py (new file, 93 lines)

@@ -0,0 +1,93 @@
import time
import logging
import requests
from threading import Lock
from typing import Dict, Any
from .oauth2_config import OAuth2Config
from .oauth2_errors import OAuth2AuthenticationError


class OAuth2TokenManager:
    def __init__(self):
        self._tokens: Dict[str, Dict[str, Any]] = {}
        self._lock = Lock()

    def get_access_token(self, config: OAuth2Config) -> str:
        """Get valid access token for the provider, refreshing if necessary"""
        with self._lock:
            return self._get_access_token_internal(config)

    def _get_access_token_internal(self, config: OAuth2Config) -> str:
        """Internal method to get access token (called within lock)"""
        provider_name = config.provider_name

        if provider_name in self._tokens:
            token_data = self._tokens[provider_name]
            if self._is_token_valid(token_data):
                logging.debug(f"Using cached OAuth2 token for provider {provider_name}")
                return token_data["access_token"]

        logging.info(f"Acquiring new OAuth2 token for provider {provider_name}")
        return self._acquire_new_token(config)

    def _is_token_valid(self, token_data: Dict[str, Any]) -> bool:
        """Check if token is still valid (not expired)"""
        if "expires_at" not in token_data:
            return False

        return time.time() < (token_data["expires_at"] - 60)

    def _acquire_new_token(self, config: OAuth2Config, retry_count: int = 3) -> str:
        """Acquire new access token using client credentials flow with retry logic"""
        for attempt in range(retry_count):
            try:
                return self._perform_token_request(config)
            except requests.RequestException as e:
                if attempt == retry_count - 1:
                    raise OAuth2AuthenticationError(
                        f"Failed to acquire OAuth2 token for {config.provider_name} after {retry_count} attempts: {e}",
                        original_error=e
                    )
                wait_time = 2 ** attempt
                logging.warning(f"OAuth2 token request failed for {config.provider_name}, retrying in {wait_time}s (attempt {attempt + 1}/{retry_count}): {e}")
                time.sleep(wait_time)

        raise OAuth2AuthenticationError(f"Unexpected error in token acquisition for {config.provider_name}")

    def _perform_token_request(self, config: OAuth2Config) -> str:
        """Perform the actual token request"""
        payload = {
            "grant_type": "client_credentials",
            "client_id": config.client_id,
            "client_secret": config.client_secret,
        }

        if config.scope:
            payload["scope"] = config.scope

        try:
            response = requests.post(
                config.token_url,
                data=payload,
                timeout=30,
                headers={"Content-Type": "application/x-www-form-urlencoded"}
            )
            response.raise_for_status()

            token_data = response.json()
            access_token = token_data["access_token"]

            expires_in = token_data.get("expires_in", 3600)
            self._tokens[config.provider_name] = {
                "access_token": access_token,
                "expires_at": time.time() + expires_in,
                "token_type": token_data.get("token_type", "Bearer")
            }

            logging.info(f"Successfully acquired OAuth2 token for {config.provider_name}, expires in {expires_in}s")
            return access_token

        except requests.RequestException:
            raise
        except KeyError as e:
            raise OAuth2AuthenticationError(f"Invalid token response from {config.provider_name}: missing {e}")
@@ -1 +0,0 @@
"""Crew-specific utilities."""
@@ -1,16 +0,0 @@
"""Context management utilities for tracking crew and task execution context using OpenTelemetry baggage."""

from typing import Optional

from opentelemetry import baggage

from crewai.utilities.crew.models import CrewContext


def get_crew_context() -> Optional[CrewContext]:
    """Get the current crew context from OpenTelemetry baggage.

    Returns:
        CrewContext instance containing crew context information, or None if no context is set
    """
    return baggage.get_baggage("crew_context")
@@ -1,16 +0,0 @@
"""Models for crew-related data structures."""

from typing import Optional

from pydantic import BaseModel, Field


class CrewContext(BaseModel):
    """Model representing crew context information."""

    id: Optional[str] = Field(
        default=None, description="Unique identifier for the crew"
    )
    key: Optional[str] = Field(
        default=None, description="Optional crew key/name for identification"
    )
@@ -1,4 +1,3 @@
-from inspect import getsource
 from typing import Any, Callable, Optional, Union
 
 from crewai.utilities.events.base_events import BaseEvent
@@ -17,26 +16,23 @@ class LLMGuardrailStartedEvent(BaseEvent):
     retry_count: int
 
     def __init__(self, **data):
+        from inspect import getsource
+
         from crewai.tasks.llm_guardrail import LLMGuardrail
         from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
 
         super().__init__(**data)
 
-        if isinstance(self.guardrail, (LLMGuardrail, HallucinationGuardrail)):
+        if isinstance(self.guardrail, LLMGuardrail) or isinstance(
+            self.guardrail, HallucinationGuardrail
+        ):
             self.guardrail = self.guardrail.description.strip()
         elif isinstance(self.guardrail, Callable):
             self.guardrail = getsource(self.guardrail).strip()
 
 
 class LLMGuardrailCompletedEvent(BaseEvent):
-    """Event emitted when a guardrail task completes
-
-    Attributes:
-        success: Whether the guardrail validation passed
-        result: The validation result
-        error: Error message if validation failed
-        retry_count: The number of times the guardrail has been retried
-    """
+    """Event emitted when a guardrail task completes"""
 
     type: str = "llm_guardrail_completed"
     success: bool
@@ -1,226 +0,0 @@
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, Callable
from unittest.mock import patch

import pytest

from crewai import Agent, Crew, Task
from crewai.utilities.crew.crew_context import get_crew_context


@pytest.fixture
def simple_agent_factory():
    def create_agent(name: str) -> Agent:
        return Agent(
            role=f"{name} Agent",
            goal=f"Complete {name} task",
            backstory=f"I am agent for {name}",
        )

    return create_agent


@pytest.fixture
def simple_task_factory():
    def create_task(name: str, callback: Callable = None) -> Task:
        return Task(
            description=f"Task for {name}", expected_output="Done", callback=callback
        )

    return create_task


@pytest.fixture
def crew_factory(simple_agent_factory, simple_task_factory):
    def create_crew(name: str, task_callback: Callable = None) -> Crew:
        agent = simple_agent_factory(name)
        task = simple_task_factory(name, callback=task_callback)
        task.agent = agent

        return Crew(agents=[agent], tasks=[task], verbose=False)

    return create_crew
class TestCrewThreadSafety:
    @patch("crewai.Agent.execute_task")
    def test_parallel_crews_thread_safety(self, mock_execute_task, crew_factory):
        mock_execute_task.return_value = "Task completed"
        num_crews = 5

        def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]:
            results = {"crew_id": crew_id, "contexts": []}

            def check_context_task(output):
                context = get_crew_context()
                results["contexts"].append(
                    {
                        "stage": "task_callback",
                        "crew_id": context.id if context else None,
                        "crew_key": context.key if context else None,
                        "thread": threading.current_thread().name,
                    }
                )
                return output

            context_before = get_crew_context()
            results["contexts"].append(
                {
                    "stage": "before_kickoff",
                    "crew_id": context_before.id if context_before else None,
                    "thread": threading.current_thread().name,
                }
            )

            crew = crew_factory(crew_id, task_callback=check_context_task)
            output = crew.kickoff()

            context_after = get_crew_context()
            results["contexts"].append(
                {
                    "stage": "after_kickoff",
                    "crew_id": context_after.id if context_after else None,
                    "thread": threading.current_thread().name,
                }
            )

            results["crew_uuid"] = str(crew.id)
            results["output"] = output.raw

            return results

        with ThreadPoolExecutor(max_workers=num_crews) as executor:
            futures = []
            for i in range(num_crews):
                future = executor.submit(run_crew_with_context_check, f"crew_{i}")
                futures.append(future)

            results = [f.result() for f in futures]

        for result in results:
            crew_uuid = result["crew_uuid"]

            before_ctx = next(
                ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff"
            )
            assert (
                before_ctx["crew_id"] is None
            ), f"Context should be None before kickoff for {result['crew_id']}"

            task_ctx = next(
                ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback"
            )
            assert (
                task_ctx["crew_id"] == crew_uuid
            ), f"Context mismatch during task for {result['crew_id']}"

            after_ctx = next(
                ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff"
            )
            assert (
                after_ctx["crew_id"] is None
            ), f"Context should be None after kickoff for {result['crew_id']}"

            thread_name = before_ctx["thread"]
            assert (
                "ThreadPoolExecutor" in thread_name
            ), f"Should run in thread pool for {result['crew_id']}"
    @pytest.mark.asyncio
    @patch("crewai.Agent.execute_task")
    async def test_async_crews_thread_safety(self, mock_execute_task, crew_factory):
        mock_execute_task.return_value = "Task completed"
        num_crews = 5

        async def run_crew_async(crew_id: str) -> Dict[str, Any]:
            task_context = {"crew_id": crew_id, "context": None}

            def capture_context(output):
                ctx = get_crew_context()
                task_context["context"] = {
                    "crew_id": ctx.id if ctx else None,
                    "crew_key": ctx.key if ctx else None,
                }
                return output

            crew = crew_factory(crew_id, task_callback=capture_context)
            output = await crew.kickoff_async()

            return {
                "crew_id": crew_id,
                "crew_uuid": str(crew.id),
                "output": output.raw,
                "task_context": task_context,
            }

        tasks = [run_crew_async(f"async_crew_{i}") for i in range(num_crews)]
        results = await asyncio.gather(*tasks)

        for result in results:
            crew_uuid = result["crew_uuid"]
            task_ctx = result["task_context"]["context"]

            assert (
                task_ctx is not None
            ), f"Context should exist during task for {result['crew_id']}"
            assert (
                task_ctx["crew_id"] == crew_uuid
            ), f"Context mismatch for {result['crew_id']}"
    @patch("crewai.Agent.execute_task")
    def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory):
        mock_execute_task.return_value = "Task completed"
        contexts_captured = []

        def capture_context(output):
            ctx = get_crew_context()
            contexts_captured.append(
                {
                    "context_id": ctx.id if ctx else None,
                    "thread": threading.current_thread().name,
                }
            )
            return output

        crew = crew_factory("for_each_test", task_callback=capture_context)
        inputs = [{"item": f"input_{i}"} for i in range(3)]

        results = crew.kickoff_for_each(inputs=inputs)

        assert len(results) == len(inputs)
        assert len(contexts_captured) == len(inputs)

        context_ids = [ctx["context_id"] for ctx in contexts_captured]
        assert len(set(context_ids)) == len(
            inputs
        ), "Each execution should have unique context"
    @patch("crewai.Agent.execute_task")
    def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):
        mock_execute_task.return_value = "Task completed"
        contexts = []

        def check_context(output):
            ctx = get_crew_context()
            contexts.append(
                {
                    "context_id": ctx.id if ctx else None,
                    "context_key": ctx.key if ctx else None,
                }
            )
            return output

        def run_crew(name: str):
            crew = crew_factory(name, task_callback=check_context)
            crew.kickoff()
            return str(crew.id)

        crew1_id = run_crew("First")
        crew2_id = run_crew("Second")

        assert len(contexts) == 2
        assert contexts[0]["context_id"] == crew1_id
        assert contexts[1]["context_id"] == crew2_id
        assert contexts[0]["context_id"] != contexts[1]["context_id"]
tests/test_oauth2_llm.py (new file, 420 lines)

@@ -0,0 +1,420 @@
import json
import pytest
import requests
import tempfile
import time
import threading
from pathlib import Path
from unittest.mock import Mock, patch
from crewai.llm import LLM
from crewai.llms.oauth2_config import OAuth2Config, OAuth2ConfigLoader
from crewai.llms.oauth2_token_manager import OAuth2TokenManager
from crewai.llms.oauth2_errors import OAuth2Error, OAuth2ConfigurationError, OAuth2AuthenticationError, OAuth2ValidationError


class TestOAuth2Config:
    def test_oauth2_config_creation(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
            scope="read write"
        )
        assert config.client_id == "test_client"
        assert config.provider_name == "test_provider"
        assert config.scope == "read write"

    def test_oauth2_config_loader_missing_file(self):
        loader = OAuth2ConfigLoader("nonexistent.json")
        configs = loader.load_config()
        assert configs == {}

    def test_oauth2_config_loader_valid_config(self):
        config_data = {
            "oauth2_providers": {
                "custom_provider": {
                    "client_id": "test_client",
                    "client_secret": "test_secret",
                    "token_url": "https://example.com/token",
                    "scope": "read"
                }
            }
        }

        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(config_data, f)
            config_path = f.name

        try:
            loader = OAuth2ConfigLoader(config_path)
            configs = loader.load_config()

            assert "custom_provider" in configs
            config = configs["custom_provider"]
            assert config.client_id == "test_client"
            assert config.provider_name == "custom_provider"
        finally:
            Path(config_path).unlink()

    def test_oauth2_config_loader_invalid_json(self):
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            f.write("invalid json")
            config_path = f.name

        try:
            loader = OAuth2ConfigLoader(config_path)
            with pytest.raises(OAuth2ConfigurationError, match="Invalid OAuth2 configuration"):
                loader.load_config()
        finally:
            Path(config_path).unlink()
class TestOAuth2TokenManager:
    def test_token_acquisition_success(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )

        mock_response = Mock()
        mock_response.json.return_value = {
            "access_token": "test_token_123",
            "token_type": "Bearer",
            "expires_in": 3600
        }
        mock_response.raise_for_status.return_value = None

        with patch('requests.post', return_value=mock_response) as mock_post:
            manager = OAuth2TokenManager()
            token = manager.get_access_token(config)

            assert token == "test_token_123"
            mock_post.assert_called_once()

    def test_token_caching(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )

        mock_response = Mock()
        mock_response.json.return_value = {
            "access_token": "test_token_123",
            "token_type": "Bearer",
            "expires_in": 3600
        }
        mock_response.raise_for_status.return_value = None

        with patch('requests.post', return_value=mock_response) as mock_post:
            manager = OAuth2TokenManager()

            token1 = manager.get_access_token(config)
            assert token1 == "test_token_123"
            assert mock_post.call_count == 1

            token2 = manager.get_access_token(config)
            assert token2 == "test_token_123"
            assert mock_post.call_count == 1

    def test_token_refresh_on_expiry(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )

        mock_response = Mock()
        mock_response.json.return_value = {
            "access_token": "new_token_456",
            "token_type": "Bearer",
            "expires_in": 3600
        }
        mock_response.raise_for_status.return_value = None

        with patch('requests.post', return_value=mock_response):
            manager = OAuth2TokenManager()

            manager._tokens["test_provider"] = {
                "access_token": "old_token",
                "expires_at": time.time() - 100
            }

            token = manager.get_access_token(config)
            assert token == "new_token_456"

    def test_token_acquisition_failure(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )

        with patch('requests.post', side_effect=requests.RequestException("Network error")):
            with patch('time.sleep'):
                manager = OAuth2TokenManager()

                with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token"):
                    manager.get_access_token(config)

    def test_invalid_token_response(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )

        mock_response = Mock()
        mock_response.json.return_value = {"invalid": "response"}
        mock_response.raise_for_status.return_value = None

        with patch('requests.post', return_value=mock_response):
            manager = OAuth2TokenManager()

            with pytest.raises(OAuth2AuthenticationError, match="Invalid token response"):
                manager.get_access_token(config)
class TestLLMOAuth2Integration:
    def test_llm_with_oauth2_config(self):
        config_data = {
            "oauth2_providers": {
                "custom": {
                    "client_id": "test_client",
                    "client_secret": "test_secret",
                    "token_url": "https://example.com/token"
                }
            }
        }

        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(config_data, f)
            config_path = f.name

        try:
            llm = LLM(
                model="custom/test-model",
                oauth2_config_path=config_path
            )

            assert "custom" in llm.oauth2_configs
            assert llm.oauth2_configs["custom"].client_id == "test_client"
        finally:
            Path(config_path).unlink()

    def test_llm_without_oauth2_config(self):
        llm = LLM(model="openai/gpt-3.5-turbo")
        assert llm.oauth2_configs == {}

    @patch('crewai.llm.litellm.completion')
    def test_llm_oauth2_token_injection(self, mock_completion):
        config_data = {
            "oauth2_providers": {
                "custom": {
                    "client_id": "test_client",
                    "client_secret": "test_secret",
                    "token_url": "https://example.com/token"
                }
            }
        }

        mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))])

        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(config_data, f)
            config_path = f.name

        try:
            with patch.object(OAuth2TokenManager, 'get_access_token', return_value="oauth_token_123"):
                llm = LLM(
                    model="custom/test-model",
                    oauth2_config_path=config_path
                )

                llm.call("test message")

                call_args = mock_completion.call_args
                assert call_args[1]['api_key'] == "oauth_token_123"
        finally:
            Path(config_path).unlink()

    @patch('crewai.llm.litellm.completion')
    def test_llm_oauth2_authentication_failure(self, mock_completion):
        config_data = {
            "oauth2_providers": {
                "custom": {
                    "client_id": "test_client",
                    "client_secret": "test_secret",
                    "token_url": "https://example.com/token"
                }
            }
        }

        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
            json.dump(config_data, f)
            config_path = f.name

        try:
            with patch.object(OAuth2TokenManager, 'get_access_token', side_effect=OAuth2AuthenticationError("Auth failed")):
                llm = LLM(
                    model="custom/test-model",
                    oauth2_config_path=config_path
                )

                with pytest.raises(OAuth2AuthenticationError, match="Auth failed"):
                    llm.call("test message")
        finally:
            Path(config_path).unlink()

    @patch('crewai.llm.litellm.completion')
    def test_llm_non_oauth2_provider_unchanged(self, mock_completion):
        mock_completion.return_value = Mock(choices=[Mock(message=Mock(content="test response"))])

        llm = LLM(
            model="openai/gpt-3.5-turbo",
            api_key="original_key"
        )

        llm.call("test message")

        call_args = mock_completion.call_args
        assert call_args[1]['api_key'] == "original_key"
class TestOAuth2Validation:
    def test_valid_url_validation(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )
        assert config.token_url == "https://example.com/token"

    def test_invalid_url_validation(self):
        with pytest.raises(OAuth2ValidationError, match="Invalid token URL format"):
            OAuth2Config(
                client_id="test_client",
                client_secret="test_secret",
                token_url="not-a-url",
                provider_name="test_provider"
            )

    def test_valid_scope_validation(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider",
            scope="read write"
        )
        assert config.scope == "read write"

    def test_invalid_scope_validation(self):
        with pytest.raises(OAuth2ValidationError, match="Invalid scope format"):
            OAuth2Config(
                client_id="test_client",
                client_secret="test_secret",
                token_url="https://example.com/token",
                provider_name="test_provider",
                scope="read  write"  # consecutive spaces yield an empty scope value
            )
class TestOAuth2RetryMechanism:
    def test_retry_mechanism_success_on_second_attempt(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )

        responses = [
            requests.RequestException("Network error"),
            Mock(json=lambda: {"access_token": "token_123", "expires_in": 3600}, raise_for_status=lambda: None)
        ]

        with patch('requests.post', side_effect=responses):
            with patch('time.sleep'):
                manager = OAuth2TokenManager()
                token = manager.get_access_token(config)
                assert token == "token_123"

    def test_retry_mechanism_exhausted(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )

        with patch('requests.post', side_effect=requests.RequestException("Network error")):
            with patch('time.sleep'):
                manager = OAuth2TokenManager()

                with pytest.raises(OAuth2AuthenticationError, match="Failed to acquire OAuth2 token.*after 3 attempts"):
                    manager.get_access_token(config)
class TestOAuth2ThreadSafety:
    def test_concurrent_token_requests(self):
        config = OAuth2Config(
            client_id="test_client",
            client_secret="test_secret",
            token_url="https://example.com/token",
            provider_name="test_provider"
        )

        mock_response = Mock()
        mock_response.json.return_value = {"access_token": "token_123", "expires_in": 3600}
        mock_response.raise_for_status.return_value = None

        call_count = 0

        def mock_post(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            time.sleep(0.1)
            return mock_response

        with patch('requests.post', side_effect=mock_post):
            manager = OAuth2TokenManager()

            results = []

            def get_token():
                token = manager.get_access_token(config)
                results.append(token)

            threads = [threading.Thread(target=get_token) for _ in range(5)]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        assert all(token == "token_123" for token in results)
        assert call_count == 1
class TestOAuth2ErrorClasses:
    def test_oauth2_configuration_error(self):
        error = OAuth2ConfigurationError("Config error")
        assert str(error) == "Config error"
        assert isinstance(error, OAuth2Error)

    def test_oauth2_authentication_error_with_original(self):
        original = requests.RequestException("Network error")
        error = OAuth2AuthenticationError("Auth failed", original_error=original)
        assert str(error) == "Auth failed"
        assert error.original_error == original

    def test_oauth2_validation_error(self):
        error = OAuth2ValidationError("Validation failed")
        assert str(error) == "Validation failed"
        assert isinstance(error, OAuth2Error)
@@ -1,88 +0,0 @@
import uuid

import pytest
from opentelemetry import baggage
from opentelemetry.context import attach, detach

from crewai.utilities.crew.crew_context import get_crew_context
from crewai.utilities.crew.models import CrewContext


def test_crew_context_creation():
    crew_id = str(uuid.uuid4())
    context = CrewContext(id=crew_id, key="test-crew")
    assert context.id == crew_id
    assert context.key == "test-crew"


def test_get_crew_context_with_baggage():
    crew_id = str(uuid.uuid4())
    assert get_crew_context() is None

    crew_ctx = CrewContext(id=crew_id, key="test-key")
    ctx = baggage.set_baggage("crew_context", crew_ctx)
    token = attach(ctx)

    try:
        context = get_crew_context()
        assert context is not None
        assert context.id == crew_id
        assert context.key == "test-key"
    finally:
        detach(token)

    assert get_crew_context() is None


def test_get_crew_context_empty():
    assert get_crew_context() is None


def test_baggage_nested_contexts():
    crew_id1 = str(uuid.uuid4())
    crew_id2 = str(uuid.uuid4())

    crew_ctx1 = CrewContext(id=crew_id1, key="outer")
    ctx1 = baggage.set_baggage("crew_context", crew_ctx1)
    token1 = attach(ctx1)

    try:
        outer_context = get_crew_context()
        assert outer_context.id == crew_id1
        assert outer_context.key == "outer"

        crew_ctx2 = CrewContext(id=crew_id2, key="inner")
        ctx2 = baggage.set_baggage("crew_context", crew_ctx2)
        token2 = attach(ctx2)

        try:
            inner_context = get_crew_context()
            assert inner_context.id == crew_id2
            assert inner_context.key == "inner"
        finally:
            detach(token2)

        restored_context = get_crew_context()
        assert restored_context.id == crew_id1
        assert restored_context.key == "outer"
    finally:
        detach(token1)

    assert get_crew_context() is None


def test_baggage_exception_handling():
    crew_id = str(uuid.uuid4())

    crew_ctx = CrewContext(id=crew_id, key="test")
    ctx = baggage.set_baggage("crew_context", crew_ctx)
    token = attach(ctx)

    with pytest.raises(ValueError):
        try:
            assert get_crew_context() is not None
            raise ValueError("Test exception")
        finally:
            detach(token)

    assert get_crew_context() is None