---
title: Custom LLM Implementation
description: Learn how to create custom LLM implementations in CrewAI.
icon: code
---

## Custom LLM Implementations

CrewAI now supports custom LLM implementations through the `BaseLLM` abstract base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.

To create a custom LLM implementation, you need to:

1. Inherit from the `BaseLLM` abstract base class
2. Implement the required methods:
   - `call()`: The main method to call the LLM with messages
   - `supports_function_calling()`: Whether the LLM supports function calling
   - `supports_stop_words()`: Whether the LLM supports stop words
   - `get_context_window_size()`: The context window size of the LLM

## Example: Basic Custom LLM

```python
from crewai import BaseLLM
from typing import Any, Dict, List, Optional, Union

class CustomLLM(BaseLLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not api_key or not isinstance(api_key, str):
            raise ValueError("Invalid API key: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.api_key = api_key
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with the given messages.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            Either a text response from the LLM or the result of a tool function call.

        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid.
        """
        # Implement your own logic to call the LLM
        # For example, using requests:
        import requests

        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]

            data = {
                "messages": messages,
                "tools": tools
            }

            response = requests.post(
                self.endpoint,
                headers=headers,
                json=data,
                timeout=30  # Set a reasonable timeout
            )
            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError, ValueError) as e:
            raise ValueError(f"Invalid response format: {str(e)}")

    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.

        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        # Return True if your LLM supports function calling
        return True

    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.

        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        # Return True if your LLM supports stop words
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.

        Returns:
            The context window size as an integer.
        """
        # Return the context window size of your LLM
        return 8192
```

## Error Handling Best Practices

When implementing custom LLMs, it's important to handle errors properly to ensure robustness and reliability. Here are some best practices:

### 1. Implement Try-Except Blocks for API Calls

Always wrap API calls in try-except blocks to handle different types of errors:

```python
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        # API call implementation
        # (self.headers and self.prepare_payload() are assumed to be set up elsewhere in your class)
        response = requests.post(
            self.endpoint,
            headers=self.headers,
            json=self.prepare_payload(messages),
            timeout=30  # Set a reasonable timeout
        )
        response.raise_for_status()  # Raise an exception for HTTP errors
        return response.json()["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")
```

### 2. Implement Retry Logic for Transient Failures

For transient failures like network issues or rate limiting, implement retry logic with exponential backoff:

```python
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import time

    max_retries = 3
    retry_delay = 1  # seconds

    for attempt in range(max_retries):
        try:
            response = requests.post(
                self.endpoint,
                headers=self.headers,
                json=self.prepare_payload(messages),
                timeout=30
            )
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]
        except (requests.Timeout, requests.ConnectionError) as e:
            if attempt < max_retries - 1:
                time.sleep(retry_delay * (2 ** attempt))  # Exponential backoff
                continue
            raise TimeoutError(f"LLM request failed after {max_retries} attempts: {str(e)}")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
```

### 3. Validate Input Parameters

Always validate input parameters to prevent runtime errors:

```python
def __init__(self, api_key: str, endpoint: str):
    super().__init__()
    if not api_key or not isinstance(api_key, str):
        raise ValueError("Invalid API key: must be a non-empty string")
    if not endpoint or not isinstance(endpoint, str):
        raise ValueError("Invalid endpoint URL: must be a non-empty string")
    self.api_key = api_key
    self.endpoint = endpoint
```

### 4. Handle Authentication Errors Gracefully

Provide clear error messages for authentication failures:

```python
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        response = requests.post(
            self.endpoint,
            headers=self.headers,
            json=self.prepare_payload(messages),
            timeout=30
        )
        if response.status_code == 401:
            raise ValueError("Authentication failed: Invalid API key or token")
        elif response.status_code == 403:
            raise ValueError("Authorization failed: Insufficient permissions")
        response.raise_for_status()
        # Process response
    except Exception:
        # Handle error
        raise
```

## Example: JWT-based Authentication

For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:

```python
from crewai import BaseLLM
from typing import Any, Dict, List, Optional, Union

class JWTAuthLLM(BaseLLM):
    def __init__(self, jwt_token: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not jwt_token or not isinstance(jwt_token, str):
            raise ValueError("Invalid JWT token: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.jwt_token = jwt_token
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with JWT authentication.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            Either a text response from the LLM or the result of a tool function call.

        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid.
        """
        # Implement your own logic to call the LLM with JWT authentication
        import requests

        try:
            headers = {
                "Authorization": f"Bearer {self.jwt_token}",
                "Content-Type": "application/json"
            }

            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]

            data = {
                "messages": messages,
                "tools": tools
            }

            response = requests.post(
                self.endpoint,
                headers=headers,
                json=data,
                timeout=30  # Set a reasonable timeout
            )

            if response.status_code == 401:
                raise ValueError("Authentication failed: Invalid JWT token")
            elif response.status_code == 403:
                raise ValueError("Authorization failed: Insufficient permissions")

            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError, ValueError) as e:
            raise ValueError(f"Invalid response format: {str(e)}")

    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.

        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        return True

    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.

        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.

        Returns:
            The context window size as an integer.
        """
        return 8192
```

## Troubleshooting

Here are some common issues you might encounter when implementing custom LLMs and how to resolve them:

### 1. Authentication Failures

**Symptoms**: 401 Unauthorized or 403 Forbidden errors

**Solutions**:
- Verify that your API key or JWT token is valid and not expired
- Check that you're using the correct authentication header format
- Ensure that your token has the necessary permissions

### 2. Timeout Issues

**Symptoms**: Requests taking too long or timing out

**Solutions**:
- Implement timeout handling as shown in the examples
- Use retry logic with exponential backoff
- Consider using a more reliable network connection

### 3. Response Parsing Errors

**Symptoms**: KeyError, IndexError, or ValueError when processing responses

**Solutions**:
- Validate the response format before accessing nested fields (see the sketch below)
- Implement proper error handling for malformed responses
- Check the API documentation for the expected response format

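To make the first solution concrete, here is a minimal sketch of defensive parsing for an OpenAI-style `choices` payload. The helper name `extract_content` and the exact response shape are assumptions for illustration, not part of the CrewAI API:

```python
from typing import Any, Dict


def extract_content(response_data: Dict[str, Any]) -> str:
    """Return the first message's content, raising a clear error if the shape is wrong."""
    choices = response_data.get("choices")
    if not isinstance(choices, list) or not choices:
        raise ValueError(f"Invalid response format: missing 'choices' list: {response_data}")

    message = choices[0].get("message")
    if not isinstance(message, dict) or "content" not in message:
        raise ValueError(f"Invalid response format: missing 'message.content': {choices[0]}")

    return message["content"]
```

A custom `call()` implementation can use a helper like this instead of chained indexing, so malformed responses fail with an error message that names the missing field.
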
### 4. Rate Limiting

**Symptoms**: 429 Too Many Requests errors

**Solutions**:
- Implement rate limiting in your custom LLM
- Add exponential backoff for retries
- Consider using a token bucket algorithm for more precise rate control (see the sketch below)

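To illustrate the token bucket approach mentioned above, here is a minimal sketch; the class name and parameters are illustrative and not part of CrewAI:

```python
import time


class TokenBucket:
    """Minimal token bucket sketch for client-side rate control (illustrative)."""

    def __init__(self, rate_per_second: float, capacity: int):
        self.rate = rate_per_second   # tokens added per second
        self.capacity = capacity      # maximum burst size
        self.tokens = float(capacity)
        self.last_refill = time.monotonic()

    def acquire(self) -> None:
        """Block until one token is available, then consume it."""
        while True:
            now = time.monotonic()
            # Refill based on elapsed time, capped at the bucket capacity
            self.tokens = min(self.capacity, self.tokens + (now - self.last_refill) * self.rate)
            self.last_refill = now
            if self.tokens >= 1:
                self.tokens -= 1
                return
            time.sleep((1 - self.tokens) / self.rate)
```

A custom LLM could create one bucket per instance and call `bucket.acquire()` at the top of `call()` before issuing the HTTP request.
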
## Advanced Features

### Logging

Adding logging to your custom LLM can help with debugging and monitoring:

```python
import logging
from typing import Any, Dict, List, Optional, Union

from crewai import BaseLLM

class LoggingLLM(BaseLLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.logger = logging.getLogger("crewai.llm.custom")

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self.logger.info(f"Calling LLM with {len(messages) if isinstance(messages, list) else 1} messages")
        try:
            # API call implementation (self._make_api_call is your own request helper)
            response = self._make_api_call(messages, tools)
            self.logger.debug(f"LLM response received: {response[:100]}...")
            return response
        except Exception as e:
            self.logger.error(f"LLM call failed: {str(e)}")
            raise
```

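Note that the `crewai.llm.custom` logger above only produces output once logging is configured somewhere in your application. A minimal setup (the level and format shown are just one possible choice) could be:

```python
import logging

# Configure root logging once at application startup so the custom LLM's
# "crewai.llm.custom" logger has a handler to write to.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)
```
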
### Rate Limiting

Implementing rate limiting can help avoid overwhelming the LLM API:

```python
import time
from typing import Any, Dict, List, Optional, Union

from crewai import BaseLLM

class RateLimitedLLM(BaseLLM):
    def __init__(
        self,
        api_key: str,
        endpoint: str,
        requests_per_minute: int = 60
    ):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.requests_per_minute = requests_per_minute
        self.request_times: List[float] = []

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self._enforce_rate_limit()
        # Record this request time
        self.request_times.append(time.time())
        # Make the actual API call (self._make_api_call is your own request helper)
        return self._make_api_call(messages, tools)

    def _enforce_rate_limit(self) -> None:
        """Enforce the rate limit by waiting if necessary."""
        now = time.time()
        # Remove request times older than 1 minute
        self.request_times = [t for t in self.request_times if now - t < 60]

        if len(self.request_times) >= self.requests_per_minute:
            # Calculate how long to wait
            oldest_request = min(self.request_times)
            wait_time = 60 - (now - oldest_request)
            if wait_time > 0:
                time.sleep(wait_time)
```

### Metrics Collection

Collecting metrics can help you monitor your LLM usage:

```python
import time
from typing import Any, Dict, List, Optional, Union

from crewai import BaseLLM

class MetricsCollectingLLM(BaseLLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.metrics: Dict[str, Any] = {
            "total_calls": 0,
            "total_tokens": 0,
            "errors": 0,
            "latency": []
        }

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        start_time = time.time()
        self.metrics["total_calls"] += 1

        try:
            # self._make_api_call is your own request helper
            response = self._make_api_call(messages, tools)
            # Estimate tokens (simplified)
            if isinstance(messages, str):
                token_estimate = len(messages) // 4
            else:
                token_estimate = sum(len(m.get("content", "")) // 4 for m in messages)
            self.metrics["total_tokens"] += token_estimate
            return response
        except Exception:
            self.metrics["errors"] += 1
            raise
        finally:
            latency = time.time() - start_time
            self.metrics["latency"].append(latency)

    def get_metrics(self) -> Dict[str, Any]:
        """Return the collected metrics."""
        avg_latency = sum(self.metrics["latency"]) / len(self.metrics["latency"]) if self.metrics["latency"] else 0
        return {
            **self.metrics,
            "avg_latency": avg_latency
        }
```

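As a brief, illustrative usage sketch (the endpoint URL and key are placeholders), you can read the collected metrics back after a run:

```python
# Illustrative usage: inspect metrics after the LLM has handled some calls.
metrics_llm = MetricsCollectingLLM(
    api_key="your-api-key",
    endpoint="https://your-llm-endpoint.com/v1/chat/completions",
)

# ... run your agent or crew with llm=metrics_llm ...

print(metrics_llm.get_metrics())
# e.g. {'total_calls': 3, 'total_tokens': 1250, 'errors': 0, 'latency': [...], 'avg_latency': 0.42}
```
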
## Advanced Usage: Function Calling

If your LLM supports function calling, you can implement the function calling logic in your custom LLM:

```python
import json
from typing import Any, Dict, List, Optional, Union

def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import requests

    try:
        headers = {
            "Authorization": f"Bearer {self.jwt_token}",
            "Content-Type": "application/json"
        }

        # Convert string message to proper format if needed
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        data = {
            "messages": messages,
            "tools": tools
        }

        response = requests.post(
            self.endpoint,
            headers=headers,
            json=data,
            timeout=30
        )
        response.raise_for_status()
        response_data = response.json()

        # Check if the LLM wants to call a function
        if response_data["choices"][0]["message"].get("tool_calls"):
            tool_calls = response_data["choices"][0]["message"]["tool_calls"]

            # Append the assistant message that requested the tool calls,
            # so the follow-up request has the full conversation context
            messages.append(response_data["choices"][0]["message"])

            # Process each tool call
            for tool_call in tool_calls:
                function_name = tool_call["function"]["name"]
                function_args = json.loads(tool_call["function"]["arguments"])

                if available_functions and function_name in available_functions:
                    function_to_call = available_functions[function_name]
                    function_response = function_to_call(**function_args)

                    # Add the function response to the messages
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "name": function_name,
                        "content": str(function_response)
                    })

            # Call the LLM again with the updated messages
            return self.call(messages, tools, callbacks, available_functions)

        # Return the text response if no function call
        return response_data["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")
```

## Using Your Custom LLM with CrewAI

Once you've implemented your custom LLM, you can use it with CrewAI agents and crews:

```python
from crewai import Agent, Task, Crew

# Create your custom LLM instance
jwt_llm = JWTAuthLLM(
    jwt_token="your.jwt.token",
    endpoint="https://your-llm-endpoint.com/v1/chat/completions"
)

# Use it with an agent
agent = Agent(
    role="Research Assistant",
    goal="Find information on a topic",
    backstory="You are a research assistant tasked with finding information.",
    llm=jwt_llm,
)

# Create a task for the agent
task = Task(
    description="Research the benefits of exercise",
    agent=agent,
    expected_output="A summary of the benefits of exercise",
)

# Execute the task
result = agent.execute_task(task)
print(result)

# Or use it with a crew
crew = Crew(
    agents=[agent],
    tasks=[task],
    manager_llm=jwt_llm,  # Use your custom LLM for the manager
)

# Run the crew
result = crew.kickoff()
print(result)
```

## Implementing Your Own Authentication Mechanism

The `BaseLLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:

- OAuth tokens
- Client certificates
- Custom headers
- Session-based authentication
- Any other authentication method required by your LLM provider

Simply implement the appropriate authentication logic in your custom LLM class; a brief sketch follows below.

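As one hedged example, here is a minimal sketch of a custom LLM that authenticates with a client certificate and extra custom headers through a `requests.Session`. The class name, constructor parameters, and endpoint are illustrative assumptions, not a prescribed CrewAI interface:

```python
from typing import Any, Dict, List, Optional, Union

import requests

from crewai import BaseLLM


class CertAuthLLM(BaseLLM):
    """Illustrative sketch: client-certificate auth plus custom headers."""

    def __init__(self, endpoint: str, cert_path: str, key_path: str,
                 extra_headers: Optional[Dict[str, str]] = None):
        super().__init__()
        self.endpoint = endpoint
        self.session = requests.Session()
        # requests supports client certificates via a (cert, key) tuple on the session
        self.session.cert = (cert_path, key_path)
        self.session.headers.update({"Content-Type": "application/json", **(extra_headers or {})})

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        response = self.session.post(
            self.endpoint,
            json={"messages": messages, "tools": tools},
            timeout=30,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def supports_function_calling(self) -> bool:
        return False

    def supports_stop_words(self) -> bool:
        return False

    def get_context_window_size(self) -> int:
        return 8192
```

The same pattern extends to OAuth or session-based flows: acquire or refresh the credential in `__init__` (or lazily in `call()`), attach it to the session, and keep the rest of the `BaseLLM` contract unchanged.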