Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-16 04:18:35 +00:00)

Compare commits: 1.6.1...devin/1741 (8 commits: 73880d407b, 2a573d8df9, 1b8c07760e, 963ed23b63, 22aeeaadbe, 7201161207, 687303ad63, ec8e705bbc)

docs/custom_llm.md (new file, +681 lines)

# Custom LLM Implementations

CrewAI supports custom LLM implementations through the `LLM` base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.

## Using Custom LLM Implementations

To create a custom LLM implementation, you need to:

1. Inherit from the `LLM` base class
2. Implement the required methods:
   - `call()`: The main method to call the LLM with messages
   - `supports_function_calling()`: Whether the LLM supports function calling
   - `supports_stop_words()`: Whether the LLM supports stop words
   - `get_context_window_size()`: The context window size of the LLM
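
Before the detailed examples below, here is a minimal skeleton of that interface. The class name and return values are placeholders; the later sections flesh each method out.

```python
from typing import Any, Dict, List, Optional, Union

from crewai import LLM


class MyLLM(LLM):  # illustrative name
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        # Send `messages` to your provider and return its text response.
        raise NotImplementedError

    def supports_function_calling(self) -> bool:
        return False  # Flip to True once call() handles the `tools` parameter.

    def supports_stop_words(self) -> bool:
        return False

    def get_context_window_size(self) -> int:
        return 8192  # Report your model's real context window.
```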

## Using the Default LLM Implementation

If you don't need a custom LLM implementation, you can use the default implementation provided by CrewAI:

```python
from crewai import LLM

# Create a default LLM instance
llm = LLM.create(model="gpt-4")

# Or with more parameters
llm = LLM.create(
    model="gpt-4",
    temperature=0.7,
    max_tokens=1000,
    api_key="your-api-key"
)
```
## Example: Basic Custom LLM

```python
from crewai import LLM
from typing import Any, Dict, List, Optional, Union

class CustomLLM(LLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not api_key or not isinstance(api_key, str):
            raise ValueError("Invalid API key: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.api_key = api_key
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with the given messages.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            Either a text response from the LLM or the result of a tool function call.

        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid.
        """
        # Implement your own logic to call the LLM
        # For example, using requests:
        import requests

        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]

            data = {
                "messages": messages,
                "tools": tools
            }

            response = requests.post(
                self.endpoint,
                headers=headers,
                json=data,
                timeout=30  # Set a reasonable timeout
            )
            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError, ValueError) as e:
            raise ValueError(f"Invalid response format: {str(e)}")

    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.

        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        # Return True if your LLM supports function calling
        return True

    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.

        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        # Return True if your LLM supports stop words
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.

        Returns:
            The context window size as an integer.
        """
        # Return the context window size of your LLM
        return 8192
```

## Error Handling Best Practices

When implementing custom LLMs, it's important to handle errors properly to ensure robustness and reliability. Here are some best practices:

### 1. Implement Try-Except Blocks for API Calls

Always wrap API calls in try-except blocks to handle different types of errors:

```python
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        # API call implementation
        response = requests.post(
            self.endpoint,
            headers=self.headers,
            json=self.prepare_payload(messages),
            timeout=30  # Set a reasonable timeout
        )
        response.raise_for_status()  # Raise an exception for HTTP errors
        return response.json()["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")
```

### 2. Implement Retry Logic for Transient Failures

For transient failures like network issues or rate limiting, implement retry logic with exponential backoff:

```python
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import time

    max_retries = 3
    retry_delay = 1  # seconds

    for attempt in range(max_retries):
        try:
            response = requests.post(
                self.endpoint,
                headers=self.headers,
                json=self.prepare_payload(messages),
                timeout=30
            )
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]
        except (requests.Timeout, requests.ConnectionError) as e:
            if attempt < max_retries - 1:
                time.sleep(retry_delay * (2 ** attempt))  # Exponential backoff
                continue
            raise TimeoutError(f"LLM request failed after {max_retries} attempts: {str(e)}")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
```

### 3. Validate Input Parameters

Always validate input parameters to prevent runtime errors:

```python
def __init__(self, api_key: str, endpoint: str):
    super().__init__()
    if not api_key or not isinstance(api_key, str):
        raise ValueError("Invalid API key: must be a non-empty string")
    if not endpoint or not isinstance(endpoint, str):
        raise ValueError("Invalid endpoint URL: must be a non-empty string")
    self.api_key = api_key
    self.endpoint = endpoint
```

### 4. Handle Authentication Errors Gracefully

Provide clear error messages for authentication failures:

```python
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        response = requests.post(self.endpoint, headers=self.headers, json=data)
        if response.status_code == 401:
            raise ValueError("Authentication failed: Invalid API key or token")
        elif response.status_code == 403:
            raise ValueError("Authorization failed: Insufficient permissions")
        response.raise_for_status()
        # Process response
    except Exception as e:
        # Handle error
        raise
```

## Example: JWT-based Authentication

For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:

```python
from crewai import LLM, Agent, Task
from typing import Any, Dict, List, Optional, Union

class JWTAuthLLM(LLM):
    def __init__(self, jwt_token: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not jwt_token or not isinstance(jwt_token, str):
            raise ValueError("Invalid JWT token: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.jwt_token = jwt_token
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with JWT authentication.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            Either a text response from the LLM or the result of a tool function call.

        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid.
        """
        # Implement your own logic to call the LLM with JWT authentication
        import requests

        try:
            headers = {
                "Authorization": f"Bearer {self.jwt_token}",
                "Content-Type": "application/json"
            }

            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]

            data = {
                "messages": messages,
                "tools": tools
            }

            response = requests.post(
                self.endpoint,
                headers=headers,
                json=data,
                timeout=30  # Set a reasonable timeout
            )

            if response.status_code == 401:
                raise ValueError("Authentication failed: Invalid JWT token")
            elif response.status_code == 403:
                raise ValueError("Authorization failed: Insufficient permissions")

            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError, ValueError) as e:
            raise ValueError(f"Invalid response format: {str(e)}")

    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.

        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        return True

    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.

        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.

        Returns:
            The context window size as an integer.
        """
        return 8192
```

## Troubleshooting

Here are some common issues you might encounter when implementing custom LLMs and how to resolve them:

### 1. Authentication Failures

**Symptoms**: 401 Unauthorized or 403 Forbidden errors

**Solutions**:
- Verify that your API key or JWT token is valid and not expired
- Check that you're using the correct authentication header format
- Ensure that your token has the necessary permissions
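
If you use JWT tokens, one way to catch expiration early is to inspect the token's `exp` claim before making the request. The sketch below assumes the PyJWT package (imported as `jwt`) and deliberately skips signature verification, since it only reads the claim:

```python
import time

import jwt  # PyJWT


def token_expires_soon(token: str, buffer_seconds: int = 60) -> bool:
    """Return True if the token's exp claim falls within buffer_seconds of now."""
    # Signature verification is intentionally skipped; we only inspect the claim.
    claims = jwt.decode(token, options={"verify_signature": False})
    exp = claims.get("exp")
    return exp is not None and exp <= time.time() + buffer_seconds
```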

### 2. Timeout Issues

**Symptoms**: Requests taking too long or timing out

**Solutions**:
- Implement timeout handling as shown in the examples
- Use retry logic with exponential backoff
- Consider using a more reliable network connection

### 3. Response Parsing Errors

**Symptoms**: KeyError, IndexError, or ValueError when processing responses

**Solutions**:
- Validate the response format before accessing nested fields
- Implement proper error handling for malformed responses
- Check the API documentation for the expected response format
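
For example, a small helper can validate the response shape before indexing into it. The `choices[0]["message"]["content"]` path below assumes an OpenAI-style response body, as in the earlier examples; adjust it to your provider:

```python
from typing import Any, Dict


def extract_content(payload: Dict[str, Any]) -> str:
    """Validate an OpenAI-style response body and return the message content."""
    try:
        content = payload["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as e:
        raise ValueError(f"Invalid response format: {e}") from e
    if not isinstance(content, str):
        raise ValueError("Response content is not a string")
    return content
```

Inside `call()`, you would then `return extract_content(response.json())` instead of chaining the lookups inline.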

### 4. Rate Limiting

**Symptoms**: 429 Too Many Requests errors

**Solutions**:
- Implement rate limiting in your custom LLM
- Add exponential backoff for retries
- Consider using a token bucket algorithm for more precise rate control
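
The sliding-window limiter shown later under Advanced Features is one option; a token bucket, sketched below with illustrative names, allows short bursts while still capping the sustained rate:

```python
import time


class TokenBucket:
    """Token bucket: allows short bursts while capping the sustained request rate."""

    def __init__(self, rate_per_sec: float, capacity: int):
        self.rate = rate_per_sec    # tokens added per second
        self.capacity = capacity    # maximum burst size
        self.tokens = float(capacity)
        self.updated = time.monotonic()

    def acquire(self) -> None:
        """Block until a token is available, then consume it."""
        while True:
            now = time.monotonic()
            self.tokens = min(self.capacity, self.tokens + (now - self.updated) * self.rate)
            self.updated = now
            if self.tokens >= 1:
                self.tokens -= 1
                return
            time.sleep((1 - self.tokens) / self.rate)
```

A custom LLM would create one bucket in `__init__` and call `bucket.acquire()` at the top of `call()`.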

## Advanced Features

### Logging

Adding logging to your custom LLM can help with debugging and monitoring:

```python
import logging
from typing import Any, Dict, List, Optional, Union

from crewai import LLM

class LoggingLLM(LLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.logger = logging.getLogger("crewai.llm.custom")

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self.logger.info(f"Calling LLM with {len(messages) if isinstance(messages, list) else 1} messages")
        try:
            # API call implementation (_make_api_call stands in for your request logic)
            response = self._make_api_call(messages, tools)
            self.logger.debug(f"LLM response received: {response[:100]}...")
            return response
        except Exception as e:
            self.logger.error(f"LLM call failed: {str(e)}")
            raise
```

### Rate Limiting

Implementing rate limiting can help avoid overwhelming the LLM API:

```python
import time
from typing import Any, Dict, List, Optional, Union

from crewai import LLM

class RateLimitedLLM(LLM):
    def __init__(
        self,
        api_key: str,
        endpoint: str,
        requests_per_minute: int = 60
    ):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.requests_per_minute = requests_per_minute
        self.request_times: List[float] = []

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self._enforce_rate_limit()
        # Record this request time
        self.request_times.append(time.time())
        # Make the actual API call (_make_api_call stands in for your request logic)
        return self._make_api_call(messages, tools)

    def _enforce_rate_limit(self) -> None:
        """Enforce the rate limit by waiting if necessary."""
        now = time.time()
        # Remove request times older than 1 minute
        self.request_times = [t for t in self.request_times if now - t < 60]

        if len(self.request_times) >= self.requests_per_minute:
            # Calculate how long to wait
            oldest_request = min(self.request_times)
            wait_time = 60 - (now - oldest_request)
            if wait_time > 0:
                time.sleep(wait_time)
```

### Metrics Collection

Collecting metrics can help you monitor your LLM usage:

```python
import time
from typing import Any, Dict, List, Optional, Union

from crewai import LLM

class MetricsCollectingLLM(LLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.metrics: Dict[str, Any] = {
            "total_calls": 0,
            "total_tokens": 0,
            "errors": 0,
            "latency": []
        }

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        start_time = time.time()
        self.metrics["total_calls"] += 1

        try:
            response = self._make_api_call(messages, tools)
            # Estimate tokens (simplified)
            if isinstance(messages, str):
                token_estimate = len(messages) // 4
            else:
                token_estimate = sum(len(m.get("content", "")) // 4 for m in messages)
            self.metrics["total_tokens"] += token_estimate
            return response
        except Exception as e:
            self.metrics["errors"] += 1
            raise
        finally:
            latency = time.time() - start_time
            self.metrics["latency"].append(latency)

    def get_metrics(self) -> Dict[str, Any]:
        """Return the collected metrics."""
        avg_latency = sum(self.metrics["latency"]) / len(self.metrics["latency"]) if self.metrics["latency"] else 0
        return {
            **self.metrics,
            "avg_latency": avg_latency
        }
```

## Advanced Usage: Function Calling

If your LLM supports function calling, you can implement the function calling logic in your custom LLM. The `call()` method below extends the `JWTAuthLLM` example shown earlier:

```python
import json
from typing import Any, Dict, List, Optional, Union

def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import requests

    try:
        headers = {
            "Authorization": f"Bearer {self.jwt_token}",
            "Content-Type": "application/json"
        }

        # Convert string message to proper format if needed
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        data = {
            "messages": messages,
            "tools": tools
        }

        response = requests.post(
            self.endpoint,
            headers=headers,
            json=data,
            timeout=30
        )
        response.raise_for_status()
        response_data = response.json()

        # Check if the LLM wants to call a function
        if response_data["choices"][0]["message"].get("tool_calls"):
            tool_calls = response_data["choices"][0]["message"]["tool_calls"]

            # Process each tool call
            for tool_call in tool_calls:
                function_name = tool_call["function"]["name"]
                function_args = json.loads(tool_call["function"]["arguments"])

                if available_functions and function_name in available_functions:
                    function_to_call = available_functions[function_name]
                    function_response = function_to_call(**function_args)

                    # Add the function response to the messages
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "name": function_name,
                        "content": str(function_response)
                    })

            # Call the LLM again with the updated messages
            return self.call(messages, tools, callbacks, available_functions)

        # Return the text response if no function call
        return response_data["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")
```

## Using Your Custom LLM with CrewAI

Once you've implemented your custom LLM, you can use it with CrewAI agents and crews:

```python
from crewai import Agent, Task, Crew

# Create your custom LLM instance
jwt_llm = JWTAuthLLM(
    jwt_token="your.jwt.token",
    endpoint="https://your-llm-endpoint.com/v1/chat/completions"
)

# Use it with an agent
agent = Agent(
    role="Research Assistant",
    goal="Find information on a topic",
    backstory="You are a research assistant tasked with finding information.",
    llm=jwt_llm,
)

# Create a task for the agent
task = Task(
    description="Research the benefits of exercise",
    agent=agent,
    expected_output="A summary of the benefits of exercise",
)

# Execute the task
result = agent.execute_task(task)
print(result)

# Or use it with a crew
crew = Crew(
    agents=[agent],
    tasks=[task],
    manager_llm=jwt_llm,  # Use your custom LLM for the manager
)

# Run the crew
result = crew.kickoff()
print(result)
```

## Implementing Your Own Authentication Mechanism

The `LLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:

- OAuth tokens
- Client certificates
- Custom headers
- Session-based authentication
- Any other authentication method required by your LLM provider

Simply implement the appropriate authentication logic in your custom LLM class.
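
For example, a provider that authenticates through custom headers could be wrapped as in the sketch below. The header names and credential parameters are purely illustrative, not part of CrewAI or any particular provider:

```python
from typing import Any, Dict, List, Optional, Union

import requests

from crewai import LLM


class CustomHeaderLLM(LLM):  # illustrative name
    def __init__(self, endpoint: str, client_id: str, client_secret: str):
        super().__init__()
        self.endpoint = endpoint
        # Whatever credentials your provider expects, stored for use in call()
        self.auth_headers = {
            "X-Client-Id": client_id,
            "X-Client-Secret": client_secret,
        }

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        response = requests.post(
            self.endpoint,
            headers={**self.auth_headers, "Content-Type": "application/json"},
            json={"messages": messages, "tools": tools},
            timeout=30,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def supports_function_calling(self) -> bool:
        return False

    def supports_stop_words(self) -> bool:
        return False

    def get_context_window_size(self) -> int:
        return 8192
```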

## Migrating from BaseLLM to LLM

If you were previously using `BaseLLM`, you can simply replace it with `LLM`:

```python
# Old code
from crewai import BaseLLM

class CustomLLM(BaseLLM):
    # ...

# New code
from crewai import LLM

class CustomLLM(LLM):
    # ...
```

The `BaseLLM` class is still available for backward compatibility but will be removed in a future release. It now inherits from `LLM` and emits a deprecation warning when instantiated.
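
During the migration you may want lingering `BaseLLM` usages to fail loudly rather than warn. A minimal sketch using the standard `warnings` module, keyed on the warning text quoted above:

```python
import warnings

# Promote the BaseLLM deprecation warning to an error so remaining usages fail fast.
warnings.filterwarnings(
    "error",
    message="BaseLLM is deprecated",
    category=DeprecationWarning,
)
```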
|
||||
@@ -4,7 +4,7 @@ from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.llm import LLM
|
||||
from crewai.llm import LLM, BaseLLM, DefaultLLM
|
||||
from crewai.process import Process
|
||||
from crewai.task import Task
|
||||
|
||||
@@ -21,6 +21,8 @@ __all__ = [
|
||||
"Process",
|
||||
"Task",
|
||||
"LLM",
|
||||
"BaseLLM",
|
||||
"DefaultLLM",
|
||||
"Flow",
|
||||
"Knowledge",
|
||||
]
|
||||
|
||||
@@ -11,7 +11,7 @@ from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.llm import LLM
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
from crewai.task import Task
|
||||
from crewai.tools import BaseTool
|
||||
@@ -70,10 +70,10 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: Union[str, InstanceOf[LLM], Any] = Field(
|
||||
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
system_template: Optional[str] = Field(
|
||||
@@ -116,9 +116,16 @@ class Agent(BaseAgent):
|
||||
def post_init_setup(self):
|
||||
self.agent_ops_agent_name = self.role
|
||||
|
||||
self.llm = create_llm(self.llm)
|
||||
if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||
try:
|
||||
self.llm = create_llm(self.llm)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize LLM for agent '{self.role}': {str(e)}")
|
||||
|
||||
if self.function_calling_llm and not isinstance(self.function_calling_llm, BaseLLM):
|
||||
try:
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize function calling LLM for agent '{self.role}': {str(e)}")
|
||||
|
||||
if not self.agent_executor:
|
||||
self._setup_agent_executor()
|
||||
|
||||
@@ -14,7 +14,7 @@ from packaging import version
|
||||
from crewai.cli.utils import read_toml
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.crew import Crew
|
||||
from crewai.llm import LLM
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.types.crew_chat import ChatInputField, ChatInputs
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
|
||||
@@ -116,7 +116,7 @@ def show_loading(event: threading.Event):
|
||||
print()
|
||||
|
||||
|
||||
def initialize_chat_llm(crew: Crew) -> Optional[LLM]:
|
||||
def initialize_chat_llm(crew: Crew) -> Optional[BaseLLM]:
|
||||
"""Initializes the chat LLM and handles exceptions."""
|
||||
try:
|
||||
return create_llm(crew.chat_llm)
|
||||
@@ -220,7 +220,7 @@ def get_user_input() -> str:
|
||||
|
||||
def handle_user_input(
|
||||
user_input: str,
|
||||
chat_llm: LLM,
|
||||
chat_llm: BaseLLM,
|
||||
messages: List[Dict[str, str]],
|
||||
crew_tool_schema: Dict[str, Any],
|
||||
available_functions: Dict[str, Any],
|
||||
|
||||
@@ -6,8 +6,9 @@ import warnings
|
||||
from concurrent.futures import Future
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast
|
||||
|
||||
from langchain_core.tools import BaseTool as LangchainBaseTool
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
@@ -26,7 +27,7 @@ from crewai.agents.cache import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||
@@ -36,7 +37,7 @@ from crewai.task import Task
|
||||
from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import Tool
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
@@ -150,14 +151,14 @@ class Crew(BaseModel):
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
manager_llm: Optional[Any] = Field(
|
||||
manager_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_agent: Optional[BaseAgent] = Field(
|
||||
description="Custom agent that will be used as manager.", default=None
|
||||
)
|
||||
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
description="Language model that will be used for function calling.", default=None
|
||||
)
|
||||
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
@@ -184,7 +185,7 @@ class Crew(BaseModel):
|
||||
default=None,
|
||||
description="Maximum number of requests per minute for the crew execution to be respected.",
|
||||
)
|
||||
prompt_file: str = Field(
|
||||
prompt_file: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to the prompt json file to be used for the crew.",
|
||||
)
|
||||
@@ -196,7 +197,7 @@ class Crew(BaseModel):
|
||||
default=False,
|
||||
description="Plan the crew execution and add the plan to the crew.",
|
||||
)
|
||||
planning_llm: Optional[Any] = Field(
|
||||
planning_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
default=None,
|
||||
description="Language model that will run the AgentPlanner if planning is True.",
|
||||
)
|
||||
@@ -212,7 +213,7 @@ class Crew(BaseModel):
|
||||
default=None,
|
||||
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
|
||||
)
|
||||
chat_llm: Optional[Any] = Field(
|
||||
chat_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||
default=None,
|
||||
description="LLM used to handle chatting with the crew.",
|
||||
)
|
||||
@@ -798,7 +799,8 @@ class Crew(BaseModel):
|
||||
|
||||
# Determine which tools to use - task tools take precedence over agent tools
|
||||
tools_for_task = task.tools or agent_to_use.tools or []
|
||||
tools_for_task = self._prepare_tools(agent_to_use, task, tools_for_task)
|
||||
# Prepare tools and ensure they're compatible with task execution
|
||||
tools_for_task = self._prepare_tools(agent_to_use, task, cast(Union[List[Tool], List[BaseTool]], tools_for_task))
|
||||
|
||||
self._log_task_start(task, agent_to_use.role)
|
||||
|
||||
@@ -817,7 +819,7 @@ class Crew(BaseModel):
|
||||
future = task.execute_async(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=tools_for_task,
|
||||
tools=cast(List[BaseTool], tools_for_task),
|
||||
)
|
||||
futures.append((task, future, task_index))
|
||||
else:
|
||||
@@ -829,7 +831,7 @@ class Crew(BaseModel):
|
||||
task_output = task.execute_sync(
|
||||
agent=agent_to_use,
|
||||
context=context,
|
||||
tools=tools_for_task,
|
||||
tools=cast(List[BaseTool], tools_for_task),
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
@@ -867,10 +869,10 @@ class Crew(BaseModel):
|
||||
return None
|
||||
|
||||
def _prepare_tools(
|
||||
self, agent: BaseAgent, task: Task, tools: List[Tool]
|
||||
) -> List[Tool]:
|
||||
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
# Add delegation tools if agent allows delegation
|
||||
if agent.allow_delegation:
|
||||
if hasattr(agent, "allow_delegation") and getattr(agent, "allow_delegation", False):
|
||||
if self.process == Process.hierarchical:
|
||||
if self.manager_agent:
|
||||
tools = self._update_manager_tools(task, tools)
|
||||
@@ -879,17 +881,18 @@ class Crew(BaseModel):
|
||||
"Manager agent is required for hierarchical process."
|
||||
)
|
||||
|
||||
elif agent and agent.allow_delegation:
|
||||
elif agent:
|
||||
tools = self._add_delegation_tools(task, tools)
|
||||
|
||||
# Add code execution tools if agent allows code execution
|
||||
if agent.allow_code_execution:
|
||||
if hasattr(agent, "allow_code_execution") and getattr(agent, "allow_code_execution", False):
|
||||
tools = self._add_code_execution_tools(agent, tools)
|
||||
|
||||
if agent and agent.multimodal:
|
||||
if agent and hasattr(agent, "multimodal") and getattr(agent, "multimodal", False):
|
||||
tools = self._add_multimodal_tools(agent, tools)
|
||||
|
||||
return tools
|
||||
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
|
||||
if self.process == Process.hierarchical:
|
||||
@@ -897,11 +900,11 @@ class Crew(BaseModel):
|
||||
return task.agent
|
||||
|
||||
def _merge_tools(
|
||||
self, existing_tools: List[Tool], new_tools: List[Tool]
|
||||
) -> List[Tool]:
|
||||
self, existing_tools: Union[List[Tool], List[BaseTool]], new_tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
|
||||
if not new_tools:
|
||||
return existing_tools
|
||||
return cast(List[BaseTool], existing_tools)
|
||||
|
||||
# Create mapping of tool names to new tools
|
||||
new_tool_map = {tool.name: tool for tool in new_tools}
|
||||
@@ -912,23 +915,32 @@ class Crew(BaseModel):
|
||||
# Add all new tools
|
||||
tools.extend(new_tools)
|
||||
|
||||
return tools
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _inject_delegation_tools(
|
||||
self, tools: List[Tool], task_agent: BaseAgent, agents: List[BaseAgent]
|
||||
):
|
||||
delegation_tools = task_agent.get_delegation_tools(agents)
|
||||
return self._merge_tools(tools, delegation_tools)
|
||||
self, tools: Union[List[Tool], List[BaseTool]], task_agent: BaseAgent, agents: List[BaseAgent]
|
||||
) -> List[BaseTool]:
|
||||
if hasattr(task_agent, "get_delegation_tools"):
|
||||
delegation_tools = task_agent.get_delegation_tools(agents)
|
||||
# Cast delegation_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _add_multimodal_tools(self, agent: BaseAgent, tools: List[Tool]):
|
||||
multimodal_tools = agent.get_multimodal_tools()
|
||||
return self._merge_tools(tools, multimodal_tools)
|
||||
def _add_multimodal_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
|
||||
if hasattr(agent, "get_multimodal_tools"):
|
||||
multimodal_tools = agent.get_multimodal_tools()
|
||||
# Cast multimodal_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _add_code_execution_tools(self, agent: BaseAgent, tools: List[Tool]):
|
||||
code_tools = agent.get_code_execution_tools()
|
||||
return self._merge_tools(tools, code_tools)
|
||||
def _add_code_execution_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
|
||||
if hasattr(agent, "get_code_execution_tools"):
|
||||
code_tools = agent.get_code_execution_tools()
|
||||
# Cast code_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _add_delegation_tools(self, task: Task, tools: List[Tool]):
|
||||
def _add_delegation_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
|
||||
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
|
||||
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
|
||||
if not tools:
|
||||
@@ -936,7 +948,7 @@ class Crew(BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, task.agent, agents_for_delegation
|
||||
)
|
||||
return tools
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _log_task_start(self, task: Task, role: str = "None"):
|
||||
if self.output_log_file:
|
||||
@@ -944,7 +956,7 @@ class Crew(BaseModel):
|
||||
task_name=task.name, task=task.description, agent=role, status="started"
|
||||
)
|
||||
|
||||
def _update_manager_tools(self, task: Task, tools: List[Tool]):
|
||||
def _update_manager_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
|
||||
if self.manager_agent:
|
||||
if task.agent:
|
||||
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
|
||||
@@ -952,7 +964,7 @@ class Crew(BaseModel):
|
||||
tools = self._inject_delegation_tools(
|
||||
tools, self.manager_agent, self.agents
|
||||
)
|
||||
return tools
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _get_context(self, task: Task, task_outputs: List[TaskOutput]):
|
||||
context = (
|
||||
@@ -1198,21 +1210,27 @@ class Crew(BaseModel):
|
||||
) -> None:
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
|
||||
try:
|
||||
eval_llm = create_llm(eval_llm)
|
||||
if not eval_llm:
|
||||
# Create LLM instance and ensure it's of type LLM for CrewEvaluator
|
||||
llm_instance = create_llm(eval_llm)
|
||||
if not llm_instance:
|
||||
raise ValueError("Failed to create LLM instance.")
|
||||
|
||||
# Ensure we have an LLM instance (not just BaseLLM) for CrewEvaluator
|
||||
from crewai.llm import LLM
|
||||
if not isinstance(llm_instance, LLM):
|
||||
raise TypeError("CrewEvaluator requires an LLM instance, not a BaseLLM instance.")
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CrewTestStartedEvent(
|
||||
crew_name=self.name or "crew",
|
||||
n_iterations=n_iterations,
|
||||
eval_llm=eval_llm,
|
||||
eval_llm=llm_instance,
|
||||
inputs=inputs,
|
||||
),
|
||||
)
|
||||
test_crew = self.copy()
|
||||
evaluator = CrewEvaluator(test_crew, eval_llm) # type: ignore[arg-type]
|
||||
evaluator = CrewEvaluator(test_crew, llm_instance)
|
||||
|
||||
for i in range(1, n_iterations + 1):
|
||||
evaluator.set_iteration(i)
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
|
||||
|
||||
@@ -34,6 +35,223 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class LLM(ABC):
|
||||
"""Base class for LLM implementations.
|
||||
|
||||
This class defines the interface that all LLM implementations must follow.
|
||||
Users can extend this class to create custom LLM implementations that don't
|
||||
rely on litellm's authentication mechanism.
|
||||
|
||||
Custom LLM implementations should handle error cases gracefully, including
|
||||
timeouts, authentication failures, and malformed responses. They should also
|
||||
implement proper validation for input parameters and provide clear error
|
||||
messages when things go wrong.
|
||||
|
||||
Attributes:
|
||||
stop (list): A list of stop sequences that the LLM should use to stop generation.
|
||||
This is used by the CrewAgentExecutor and other components.
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""Create a new LLM instance.
|
||||
|
||||
This method handles backward compatibility by creating a DefaultLLM instance
|
||||
when the LLM class is instantiated directly with parameters.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments.
|
||||
**kwargs: Keyword arguments.
|
||||
|
||||
Returns:
|
||||
Either a new LLM instance or a DefaultLLM instance for backward compatibility.
|
||||
"""
|
||||
if cls is LLM and (args or kwargs.get('model') is not None):
|
||||
# Import locally to avoid circular imports
|
||||
# This is safe because DefaultLLM is defined later in this file
|
||||
DefaultLLM = globals().get('DefaultLLM')
|
||||
if DefaultLLM is None:
|
||||
# If DefaultLLM is not yet defined, return a placeholder
|
||||
# that will be replaced with a real DefaultLLM instance later
|
||||
return object.__new__(cls)
|
||||
return DefaultLLM(*args, **kwargs)
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the LLM with default attributes.
|
||||
|
||||
This constructor sets default values for attributes that are expected
|
||||
by the CrewAgentExecutor and other components.
|
||||
|
||||
All custom LLM implementations should call super().__init__() to ensure
|
||||
that these default attributes are properly initialized.
|
||||
"""
|
||||
self.stop = []
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
model: str,
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
response_format: Optional[Type[BaseModel]] = None,
|
||||
seed: Optional[int] = None,
|
||||
logprobs: Optional[int] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||
**kwargs,
|
||||
) -> 'DefaultLLM':
|
||||
"""Create a default LLM instance using litellm.
|
||||
|
||||
This factory method creates a default LLM instance using litellm as the backend.
|
||||
It's the recommended way to create LLM instances for most users.
|
||||
|
||||
Args:
|
||||
model: The model name (e.g., "gpt-4").
|
||||
timeout: Optional timeout for the LLM call.
|
||||
temperature: Optional temperature for the LLM call.
|
||||
top_p: Optional top_p for the LLM call.
|
||||
n: Optional n for the LLM call.
|
||||
stop: Optional stop sequences for the LLM call.
|
||||
max_completion_tokens: Optional max_completion_tokens for the LLM call.
|
||||
max_tokens: Optional max_tokens for the LLM call.
|
||||
presence_penalty: Optional presence_penalty for the LLM call.
|
||||
frequency_penalty: Optional frequency_penalty for the LLM call.
|
||||
logit_bias: Optional logit_bias for the LLM call.
|
||||
response_format: Optional response_format for the LLM call.
|
||||
seed: Optional seed for the LLM call.
|
||||
logprobs: Optional logprobs for the LLM call.
|
||||
top_logprobs: Optional top_logprobs for the LLM call.
|
||||
base_url: Optional base_url for the LLM call.
|
||||
api_base: Optional api_base for the LLM call.
|
||||
api_version: Optional api_version for the LLM call.
|
||||
api_key: Optional api_key for the LLM call.
|
||||
callbacks: Optional callbacks for the LLM call.
|
||||
reasoning_effort: Optional reasoning_effort for the LLM call.
|
||||
**kwargs: Additional keyword arguments for the LLM call.
|
||||
|
||||
Returns:
|
||||
A DefaultLLM instance configured with the provided parameters.
|
||||
"""
|
||||
from crewai.llm import DefaultLLM
|
||||
|
||||
return DefaultLLM(
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
stop=stop,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
base_url=base_url,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
api_key=api_key,
|
||||
callbacks=callbacks,
|
||||
reasoning_effort=reasoning_effort,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Call the LLM with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
Can be a string or list of message dictionaries.
|
||||
If string, it will be converted to a single user message.
|
||||
If list, each dict must have 'role' and 'content' keys.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
Each tool should define its name, description, and parameters.
|
||||
callbacks: Optional list of callback functions to be executed
|
||||
during and after the LLM call.
|
||||
available_functions: Optional dict mapping function names to callables
|
||||
that can be invoked by the LLM.
|
||||
|
||||
Returns:
|
||||
Either a text response from the LLM (str) or
|
||||
the result of a tool function call (Any).
|
||||
|
||||
Raises:
|
||||
ValueError: If the messages format is invalid.
|
||||
TimeoutError: If the LLM request times out.
|
||||
RuntimeError: If the LLM request fails for other reasons.
|
||||
NotImplementedError: If this method is not implemented by a subclass.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement call()")
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the LLM supports function calling.
|
||||
|
||||
This method should return True if the LLM implementation supports
|
||||
function calling (tools), and False otherwise. If this method returns
|
||||
True, the LLM should be able to handle the 'tools' parameter in the
|
||||
call() method.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports function calling, False otherwise.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If this method is not implemented by a subclass.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement supports_function_calling()")
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the LLM supports stop words.
|
||||
|
||||
This method should return True if the LLM implementation supports
|
||||
stop words, and False otherwise. If this method returns True, the
|
||||
LLM should respect the 'stop' attribute when generating responses.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports stop words, False otherwise.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If this method is not implemented by a subclass.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement supports_stop_words()")
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size of the LLM.
|
||||
|
||||
This method should return the maximum number of tokens that the LLM
|
||||
can process in a single request. This is used by CrewAI to ensure
|
||||
that messages don't exceed the LLM's context window.
|
||||
|
||||
Returns:
|
||||
The context window size as an integer.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If this method is not implemented by a subclass.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement get_context_window_size()")
|
||||
|
||||
|
||||
class FilteredStream:
|
||||
def __init__(self, original_stream):
|
||||
self._original_stream = original_stream
|
||||
@@ -126,7 +344,14 @@ def suppress_warnings():
|
||||
sys.stderr = old_stderr
|
||||
|
||||
|
||||
class LLM:
|
||||
class DefaultLLM(LLM):
|
||||
"""Default LLM implementation using litellm.
|
||||
|
||||
This class provides a concrete implementation of the LLM interface
|
||||
using litellm as the backend. It's the default implementation used
|
||||
by CrewAI when no custom LLM is provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
@@ -152,6 +377,8 @@ class LLM:
|
||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__() # Initialize the base class
|
||||
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
self.temperature = temperature
|
||||
@@ -180,7 +407,7 @@ class LLM:
|
||||
|
||||
# Normalize self.stop to always be a List[str]
|
||||
if stop is None:
|
||||
self.stop: List[str] = []
|
||||
self.stop = [] # Already initialized in base class
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
@@ -564,3 +791,27 @@ class LLM:
|
||||
|
||||
litellm.success_callback = success_callbacks
|
||||
litellm.failure_callback = failure_callbacks
|
||||
|
||||
|
||||
class BaseLLM(LLM):
|
||||
"""Deprecated: Use LLM instead.
|
||||
|
||||
This class is kept for backward compatibility and will be removed in a future release.
|
||||
It inherits from LLM and provides the same interface, but emits a deprecation warning
|
||||
when instantiated.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the BaseLLM with a deprecation warning.
|
||||
|
||||
This constructor emits a deprecation warning and then calls the parent class's
|
||||
constructor to initialize the LLM.
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"BaseLLM is deprecated and will be removed in a future release. "
|
||||
"Use LLM instead for custom implementations.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
|
||||
from crewai.llm import LLM
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
|
||||
|
||||
def create_llm(
|
||||
@@ -19,17 +19,17 @@ def create_llm(
|
||||
- None: Use environment-based or fallback default model.
|
||||
|
||||
Returns:
|
||||
An LLM instance if successful, or None if something fails.
|
||||
A LLM instance if successful, or None if something fails.
|
||||
"""
|
||||
|
||||
# 1) If llm_value is already an LLM object, return it directly
|
||||
# 1) If llm_value is already a LLM object, return it directly
|
||||
if isinstance(llm_value, LLM):
|
||||
return llm_value
|
||||
|
||||
# 2) If llm_value is a string (model name)
|
||||
if isinstance(llm_value, str):
|
||||
try:
|
||||
created_llm = LLM(model=llm_value)
|
||||
created_llm = LLM.create(model=llm_value)
|
||||
return created_llm
|
||||
except Exception as e:
|
||||
print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
|
||||
@@ -56,7 +56,7 @@ def create_llm(
|
||||
base_url: Optional[str] = getattr(llm_value, "base_url", None)
|
||||
api_base: Optional[str] = getattr(llm_value, "api_base", None)
|
||||
|
||||
created_llm = LLM(
|
||||
created_llm = LLM.create(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
@@ -175,7 +175,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
|
||||
|
||||
# Try creating the LLM
|
||||
try:
|
||||
new_llm = LLM(**llm_params)
|
||||
new_llm = LLM.create(**llm_params)
|
||||
return new_llm
|
||||
except Exception as e:
|
||||
print(
|
||||
|
||||
tests/cassettes/test_litellm_auth_error_handling.yaml (new file, +89 lines)
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages": [{"role": "system", "content": "You are test role. test backstory\nYour
|
||||
personal goal is: test goal\nTo give my best complete final answer to the task
|
||||
respond using the exact following format:\n\nThought: I now can give a great
|
||||
answer\nFinal Answer: Your final answer must be the great and the most complete
|
||||
as possible, it must be outcome described.\n\nI MUST use these formats, my job
|
||||
depends on it!"}, {"role": "user", "content": "\nCurrent Task: Test task\n\nThis
|
||||
is the expected criteria for your final answer: Test output\nyou MUST return
|
||||
the actual complete content as the final answer, not a summary.\n\nBegin! This
|
||||
is VERY important to you, use the tools available and give your best Final Answer,
|
||||
your job depends on it!\n\nThought:"}], "model": "gpt-4", "stop": ["\nObservation:"]}'
|
||||
headers:
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- gzip, deflate
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '805'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
- _cfuvid=xecEkmr_qTiKn7EKC7aeGN5bpsbPM9ofyIsipL4VCYM-1734033219265-0.0.1.1-604800000
|
||||
host:
|
||||
- api.openai.com
|
||||
user-agent:
|
||||
- OpenAI/Python 1.61.0
|
||||
x-stainless-arch:
|
||||
- x64
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- Linux
|
||||
x-stainless-package-version:
|
||||
- 1.61.0
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.12.7
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
content: "{\n \"error\": {\n \"message\": \"Incorrect API key provided:
|
||||
sk-proj-********************************************************************************************************************************************************sLcA.
|
||||
You can find your API key at https://platform.openai.com/account/api-keys.\",\n
|
||||
\ \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\":
|
||||
\"invalid_api_key\"\n }\n}\n"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- 9201beec18a0762e-SEA
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Length:
|
||||
- '414'
|
||||
Content-Type:
|
||||
- application/json; charset=utf-8
|
||||
Date:
|
||||
- Fri, 14 Mar 2025 06:34:31 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Set-Cookie:
|
||||
- __cf_bm=wF6OyTyATDK7A9tGqAdaSB3QZfmd34JWPicYlDC1hug-1741934071-1.0.1.1-nZThPWX_7A9FsU7Z14PyrVhl6mCD99iuk9ujCFkNCCdepMHEwK9EXoDrP4IBBCXxkXmKjrVTSaQ63zpcociXuMHR8JKhth2fRUV2H4hMldY;
|
||||
path=/; expires=Fri, 14-Mar-25 07:04:31 GMT; domain=.api.openai.com; HttpOnly;
|
||||
Secure; SameSite=None
|
||||
- _cfuvid=rn5IWZdYMRmbyCa2_84MkWO46MIaP6soWc8npaLc9iQ-1741934071787-0.0.1.1-604800000;
|
||||
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
|
||||
X-Content-Type-Options:
|
||||
- nosniff
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
- DYNAMIC
|
||||
strict-transport-security:
|
||||
- max-age=31536000; includeSubDomains; preload
|
||||
vary:
|
||||
- Origin
|
||||
x-request-id:
|
||||
- req_f55471c8eb5755daaef3d63eab5a95de
|
||||
http_version: HTTP/1.1
|
||||
status_code: 401
|
||||
version: 1
|
||||
tests/custom_llm_test.py (new file, +570 lines)
from collections import deque
from typing import Any, Dict, List, Optional, Union
import time

import jwt
import pytest

from crewai.llm import LLM
from crewai.utilities.llm_utils import create_llm


class CustomLLM(LLM):
    """Custom LLM implementation for testing.

    This is a simple implementation of the LLM abstract base class
    that returns a predefined response for testing purposes.
    """

    def __init__(self, response: str = "Custom LLM response"):
        """Initialize the CustomLLM with a predefined response.

        Args:
            response: The predefined response to return from call().
        """
        super().__init__()
        self.response = response
        self.calls = []
        self.stop = []

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Record the call and return the predefined response.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            The predefined response string.
        """
        self.calls.append({
            "messages": messages,
            "tools": tools,
            "callbacks": callbacks,
            "available_functions": available_functions
        })
        return self.response

    def supports_function_calling(self) -> bool:
        """Return True to indicate that function calling is supported.

        Returns:
            True, indicating that this LLM supports function calling.
        """
        return True

    def supports_stop_words(self) -> bool:
        """Return True to indicate that stop words are supported.

        Returns:
            True, indicating that this LLM supports stop words.
        """
        return True

    def get_context_window_size(self) -> int:
        """Return a default context window size.

        Returns:
            8192, a typical context window size for modern LLMs.
        """
        return 8192


def test_custom_llm_implementation():
    """Test that a custom LLM implementation works with create_llm."""
    custom_llm = CustomLLM(response="The answer is 42")

    # Test that create_llm returns the custom LLM instance directly
    result_llm = create_llm(custom_llm)

    assert result_llm is custom_llm

    # Test calling the custom LLM
    response = result_llm.call("What is the answer to life, the universe, and everything?")

    # Verify that the custom LLM was called
    assert len(custom_llm.calls) > 0
    # Verify that the response from the custom LLM was used
    assert response == "The answer is 42"


class JWTAuthLLM(LLM):
    """Custom LLM implementation with JWT authentication.

    This class demonstrates how to implement a custom LLM that uses JWT
    authentication instead of API key-based authentication. It validates
    the JWT token before each call and checks for token expiration.
    """

    def __init__(self, jwt_token: str, expiration_buffer: int = 60):
        """Initialize the JWTAuthLLM with a JWT token.

        Args:
            jwt_token: The JWT token to use for authentication.
            expiration_buffer: Buffer time in seconds to warn about token expiration.
                Default is 60 seconds.

        Raises:
            ValueError: If the JWT token is invalid or missing.
        """
        super().__init__()
        if not jwt_token or not isinstance(jwt_token, str):
            raise ValueError("Invalid JWT token")

        self.jwt_token = jwt_token
        self.expiration_buffer = expiration_buffer
        self.calls = []
        self.stop = []

        # Validate the token immediately
        self._validate_token()

    def _validate_token(self) -> None:
        """Validate the JWT token.

        Checks if the token is valid and not expired. Also warns if the token
        is about to expire within the expiration_buffer time.

        Raises:
            ValueError: If the token is invalid, expired, or malformed.
        """
        try:
            # Decode without verification to check expiration
            # In a real implementation, you would verify the signature
            decoded = jwt.decode(self.jwt_token, options={"verify_signature": False})

            # Check if token is expired or about to expire
            if 'exp' in decoded:
                expiration_time = decoded['exp']
                current_time = time.time()

                if expiration_time < current_time:
                    raise ValueError("JWT token has expired")

                if expiration_time < current_time + self.expiration_buffer:
                    # Token will expire soon, log a warning
                    import logging
                    logging.warning(f"JWT token will expire in {expiration_time - current_time} seconds")
        except jwt.PyJWTError as e:
            raise ValueError(f"Invalid JWT token format: {str(e)}")

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with JWT authentication.

        Validates the JWT token before making the call to ensure it's still valid.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            The LLM response.

        Raises:
            ValueError: If the JWT token is invalid or expired.
            TimeoutError: If the request times out.
            ConnectionError: If there's a connection issue.
        """
        # Validate token before making the call
        self._validate_token()

        self.calls.append({
            "messages": messages,
            "tools": tools,
            "callbacks": callbacks,
            "available_functions": available_functions
        })

        # In a real implementation, this would use the JWT token to authenticate
        # with an external service
        return "Response from JWT-authenticated LLM"

    def supports_function_calling(self) -> bool:
        """Return True to indicate that function calling is supported."""
        return True

    def supports_stop_words(self) -> bool:
        """Return True to indicate that stop words are supported."""
        return True

    def get_context_window_size(self) -> int:
        """Return a default context window size."""
        return 8192


def test_custom_llm_with_jwt_auth():
    """Test a custom LLM implementation with JWT authentication."""
    # Create a valid JWT token that expires 1 hour from now
    valid_token = jwt.encode(
        {"exp": int(time.time()) + 3600},
        "secret",
        algorithm="HS256"
    )

    jwt_llm = JWTAuthLLM(jwt_token=valid_token)

    # Test that create_llm returns the JWT-authenticated LLM instance directly
    result_llm = create_llm(jwt_llm)

    assert result_llm is jwt_llm

    # Test calling the JWT-authenticated LLM
    response = result_llm.call("Test message")

    # Verify that the JWT-authenticated LLM was called
    assert len(jwt_llm.calls) > 0
    # Verify that the response from the JWT-authenticated LLM was used
    assert response == "Response from JWT-authenticated LLM"


def test_jwt_auth_llm_validation():
    """Test that JWT token validation works correctly."""
    # Test with invalid JWT token (empty string)
    with pytest.raises(ValueError, match="Invalid JWT token"):
        JWTAuthLLM(jwt_token="")

    # Test with invalid JWT token (non-string)
    with pytest.raises(ValueError, match="Invalid JWT token"):
        JWTAuthLLM(jwt_token=None)

    # Test with expired token
    # Create a token that expired 1 hour ago
    expired_token = jwt.encode(
        {"exp": int(time.time()) - 3600},
        "secret",
        algorithm="HS256"
    )
    with pytest.raises(ValueError, match="JWT token has expired"):
        JWTAuthLLM(jwt_token=expired_token)

    # Test with malformed token
    with pytest.raises(ValueError, match="Invalid JWT token format"):
        JWTAuthLLM(jwt_token="not.a.valid.jwt.token")

    # Test with valid token
    # Create a token that expires 1 hour from now
    valid_token = jwt.encode(
        {"exp": int(time.time()) + 3600},
        "secret",
        algorithm="HS256"
    )
    # This should not raise an exception
    jwt_llm = JWTAuthLLM(jwt_token=valid_token)
    assert jwt_llm.jwt_token == valid_token


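# Illustrative sketch, not part of the test suite above: JWTAuthLLM deliberately
# decodes with signature verification disabled. A production variant would verify
# the signature as well. Assuming a shared HS256 secret (the key, algorithm, and
# helper name here are placeholders), PyJWT can check both the signature and the
# exp claim in a single call.
def _verify_jwt_strictly(token: str, secret: str = "secret") -> dict:
    """Decode and verify a JWT, raising ValueError on any problem."""
    try:
        # Verifies the signature and rejects expired tokens in one step.
        return jwt.decode(token, secret, algorithms=["HS256"])
    except jwt.ExpiredSignatureError as exc:
        raise ValueError("JWT token has expired") from exc
    except jwt.PyJWTError as exc:
        raise ValueError(f"Invalid JWT token format: {exc}") from exc

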
class TimeoutHandlingLLM(LLM):
    """Custom LLM implementation with timeout handling and retry logic."""

    def __init__(self, max_retries: int = 3, timeout: int = 30):
        """Initialize the TimeoutHandlingLLM with retry and timeout settings.

        Args:
            max_retries: Maximum number of retry attempts.
            timeout: Timeout in seconds for each API call.
        """
        super().__init__()
        self.max_retries = max_retries
        self.timeout = timeout
        self.calls = []
        self.stop = []
        self.fail_count = 0  # Number of times to simulate failure

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Simulate API calls with timeout handling and retry logic.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            A response string based on whether this is the first attempt or a retry.

        Raises:
            TimeoutError: If all retry attempts fail.
        """
        # Record the initial call
        self.calls.append({
            "messages": messages,
            "tools": tools,
            "callbacks": callbacks,
            "available_functions": available_functions,
            "attempt": 0
        })

        # Simulate retry logic
        for attempt in range(self.max_retries):
            # Skip the first attempt recording since we already did that above
            if attempt == 0:
                # Simulate a failure if fail_count > 0
                if self.fail_count > 0:
                    self.fail_count -= 1
                    # If we've used all retries, raise an error
                    if attempt == self.max_retries - 1:
                        raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
                    # Otherwise, continue to the next attempt (simulating backoff)
                    continue
                else:
                    # Success on first attempt
                    return "First attempt response"
            else:
                # This is a retry attempt (attempt > 0)
                # Always record retry attempts
                self.calls.append({
                    "retry_attempt": attempt,
                    "messages": messages,
                    "tools": tools,
                    "callbacks": callbacks,
                    "available_functions": available_functions
                })

                # Simulate a failure if fail_count > 0
                if self.fail_count > 0:
                    self.fail_count -= 1
                    # If we've used all retries, raise an error
                    if attempt == self.max_retries - 1:
                        raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
                    # Otherwise, continue to the next attempt (simulating backoff)
                    continue
                else:
                    # Success on retry
                    return "Response after retry"

    def supports_function_calling(self) -> bool:
        """Return True to indicate that function calling is supported.

        Returns:
            True, indicating that this LLM supports function calling.
        """
        return True

    def supports_stop_words(self) -> bool:
        """Return True to indicate that stop words are supported.

        Returns:
            True, indicating that this LLM supports stop words.
        """
        return True

    def get_context_window_size(self) -> int:
        """Return a default context window size.

        Returns:
            8192, a typical context window size for modern LLMs.
        """
        return 8192


def test_timeout_handling_llm():
    """Test a custom LLM implementation with timeout handling and retry logic."""
    # Test successful first attempt
    llm = TimeoutHandlingLLM()
    response = llm.call("Test message")
    assert response == "First attempt response"
    assert len(llm.calls) == 1

    # Test successful retry
    llm = TimeoutHandlingLLM()
    llm.fail_count = 1  # Fail once, then succeed
    response = llm.call("Test message")
    assert response == "Response after retry"
    assert len(llm.calls) == 2  # Initial call + successful retry call

    # Test failure after all retries
    llm = TimeoutHandlingLLM(max_retries=2)
    llm.fail_count = 2  # Fail twice, which is all retries
    with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
        llm.call("Test message")
    assert len(llm.calls) == 2  # Initial call + failed retry attempt


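# Illustrative sketch, not used by the tests above: TimeoutHandlingLLM only
# simulates retries, so it never waits between attempts. A real client would
# typically wrap the underlying request in exponential backoff; the helper
# name and delays below are placeholders for such a wrapper.
def _call_with_backoff(request_fn, max_retries: int = 3, base_delay: float = 1.0):
    """Invoke request_fn, retrying on TimeoutError with exponential backoff."""
    for attempt in range(max_retries):
        try:
            return request_fn()
        except TimeoutError:
            if attempt == max_retries - 1:
                raise
            # Sleep 1s, 2s, 4s, ... before the next attempt.
            time.sleep(base_delay * (2 ** attempt))

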
def test_rate_limited_llm():
    """Test that rate limiting works correctly."""
    # Create a rate limited LLM with a very low limit (2 requests per minute)
    llm = RateLimitedLLM(requests_per_minute=2)

    # First request should succeed
    response1 = llm.call("Test message 1")
    assert response1 == "Rate limited response"
    assert len(llm.calls) == 1

    # Second request should succeed
    response2 = llm.call("Test message 2")
    assert response2 == "Rate limited response"
    assert len(llm.calls) == 2

    # Third request should fail due to rate limiting
    with pytest.raises(ValueError, match="Rate limit exceeded"):
        llm.call("Test message 3")

    # Test with invalid requests_per_minute
    with pytest.raises(ValueError, match="requests_per_minute must be a positive integer"):
        RateLimitedLLM(requests_per_minute=0)

    with pytest.raises(ValueError, match="requests_per_minute must be a positive integer"):
        RateLimitedLLM(requests_per_minute=-1)


def test_rate_limit_reset():
    """Test that rate limits reset after the time window passes."""
    # Create a rate limited LLM with a very low limit (1 request per minute)
    # and a short time window for testing (1 second instead of 60 seconds)
    time_window = 1  # 1 second instead of 60 seconds
    llm = RateLimitedLLM(requests_per_minute=1, time_window=time_window)

    # First request should succeed
    response1 = llm.call("Test message 1")
    assert response1 == "Rate limited response"

    # Second request should fail due to rate limiting
    with pytest.raises(ValueError, match="Rate limit exceeded"):
        llm.call("Test message 2")

    # Wait for the rate limit to reset
    time.sleep(time_window + 0.1)  # Add a small buffer

    # After waiting, we should be able to make another request
    response3 = llm.call("Test message 3")
    assert response3 == "Rate limited response"
    assert len(llm.calls) == 2  # First and third requests


class RateLimitedLLM(LLM):
    """Custom LLM implementation with rate limiting.

    This class demonstrates how to implement a custom LLM with rate limiting
    capabilities. It uses a sliding window algorithm to ensure that no more
    than a specified number of requests are made within a given time period.
    """

    def __init__(self, requests_per_minute: int = 60, base_response: str = "Rate limited response", time_window: int = 60):
        """Initialize the RateLimitedLLM with rate limiting parameters.

        Args:
            requests_per_minute: Maximum number of requests allowed per minute.
            base_response: Default response to return.
            time_window: Time window in seconds for rate limiting (default: 60).
                This is configurable for testing purposes.

        Raises:
            ValueError: If requests_per_minute is not a positive integer.
        """
        super().__init__()
        if not isinstance(requests_per_minute, int) or requests_per_minute <= 0:
            raise ValueError("requests_per_minute must be a positive integer")

        self.requests_per_minute = requests_per_minute
        self.base_response = base_response
        self.time_window = time_window
        self.request_times = deque()
        self.calls = []
        self.stop = []

    def _check_rate_limit(self) -> None:
        """Check if the current request exceeds the rate limit.

        This method implements a sliding window rate limiting algorithm.
        It keeps track of request timestamps and ensures that no more than
        `requests_per_minute` requests are made within the configured time window.

        Raises:
            ValueError: If the rate limit is exceeded.
        """
        current_time = time.time()

        # Remove requests older than the time window
        while self.request_times and current_time - self.request_times[0] > self.time_window:
            self.request_times.popleft()

        # Check if we've exceeded the rate limit
        if len(self.request_times) >= self.requests_per_minute:
            wait_time = self.time_window - (current_time - self.request_times[0])
            raise ValueError(
                f"Rate limit exceeded. Maximum {self.requests_per_minute} "
                f"requests per {self.time_window} seconds. Try again in {wait_time:.2f} seconds."
            )

        # Record this request
        self.request_times.append(current_time)

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with rate limiting.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            The LLM response.

        Raises:
            ValueError: If the rate limit is exceeded.
        """
        # Check rate limit before making the call
        self._check_rate_limit()

        self.calls.append({
            "messages": messages,
            "tools": tools,
            "callbacks": callbacks,
            "available_functions": available_functions
        })

        return self.base_response

    def supports_function_calling(self) -> bool:
        """Return True to indicate that function calling is supported.

        Returns:
            True, indicating that this LLM supports function calling.
        """
        return True

    def supports_stop_words(self) -> bool:
        """Return True to indicate that stop words are supported.

        Returns:
            True, indicating that this LLM supports stop words.
        """
        return True

    def get_context_window_size(self) -> int:
        """Return a default context window size.

        Returns:
            8192, a typical context window size for modern LLMs.
        """
        return 8192
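

# Illustrative sketch, not exercised by the tests above: a caller can treat
# the ValueError raised by RateLimitedLLM as a signal to pause until the
# sliding window clears and then retry once, instead of failing outright.
# The helper name and single-retry policy are assumptions for this sketch.
def _call_when_allowed(llm: RateLimitedLLM, message: str) -> str:
    """Call llm, sleeping through one rate-limit window if necessary."""
    try:
        return llm.call(message)
    except ValueError:
        # Crude but sufficient here: wait out the full window, then retry once.
        time.sleep(llm.time_window)
        return llm.call(message)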