Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-10 00:28:31 +00:00.
Add support for custom LLM implementations (#2277)
* Add support for custom LLM implementations
* Fix import sorting and type annotations
* Fix linting issues with import sorting
* Fix type errors in crew.py by updating tool-related methods to return List[BaseTool]
* Enhance custom LLM implementation with better error handling, documentation, and test coverage
* Refactor LLM module by extracting BaseLLM to a separate file. This moves the BaseLLM abstract base class from llm.py to a new file, llms/base_llm.py, to improve code organization:
  - Create a new file src/crewai/llms/base_llm.py
  - Move the BaseLLM class to the new file
  - Update imports in __init__.py and llm.py to reflect the new location
  - Update test cases to use the new import path
  The refactoring maintains the existing functionality while improving the project's module structure.
* Add AISuite LLM support and update dependencies
  - Integrate AISuite as a new third-party LLM option
  - Update pyproject.toml and uv.lock to include the aisuite package
  - Modify BaseLLM to support more flexible initialization
  - Remove unnecessary LLM imports across multiple files
  - Implement AISuiteLLM with basic chat completion functionality
* Update AISuiteLLM and LLM utility type handling
  - Modify AISuiteLLM to support more flexible input types for messages
  - Update type hints in AISuiteLLM to allow a string or a list of message dictionaries
  - Enhance the LLM utility function to support broader LLM type annotations
  - Remove the default `self.stop` attribute from BaseLLM initialization
* Update LLM imports and type hints across multiple files
  - Modify imports in crew_chat.py to use LLM instead of BaseLLM
  - Update type hints in llm_utils.py to use the LLM type
  - Add an optional `stop` parameter to BaseLLM initialization
  - Refactor type handling for LLM creation and usage
* Improve stop words handling in CrewAgentExecutor
  - Add support for handling existing stop words in the LLM configuration
  - Ensure stop words are correctly merged and deduplicated
  - Update type hints to support both LLM and BaseLLM types
* Remove abstract method set_callbacks from the BaseLLM class
* Enhance CustomLLM and JWTAuthLLM initialization with a model parameter
  - Update CustomLLM to accept a model parameter during initialization
  - Modify test cases to include the new model argument
  - Ensure JWTAuthLLM and TimeoutHandlingLLM also use the model parameter in their constructors
  - Update type hints in the create_llm function to support both LLM and BaseLLM types
* Enhance the create_llm function to accept both LLM and BaseLLM instances while remaining compatible with existing LLM handling logic
* Update the return type of initialize_chat_llm to allow both LLM and BaseLLM instances, in line with the create_llm changes
* Refactor AISuiteLLM to include a tools parameter in its completion methods
  - Update _prepare_completion_params to accept an optional tools parameter
  - Use the new tools parameter in the chat completion method
  - Clean up print statements for better code clarity
* Remove unused tool_calls handling in the AISuiteLLM chat completion method for cleaner code
* Refactor Crew class and LLM hierarchy for improved type handling and code clarity
  - Update Crew class methods for consistent formatting and type hints
  - Change the LLM class to inherit from BaseLLM for better structure
  - Remove unnecessary type checks and streamline tool handling in CrewAgentExecutor
  - Adjust BaseLLM to provide default implementations for the stop words and context window size methods
  - Clean up AISuiteLLM by removing unused stop words and context window size methods
* Remove unused `stream` method from the `BaseLLM` class to enhance code clarity and maintainability

Each commit co-authored by: Joe Moura <joao@crewai.com>

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Joe Moura <joao@crewai.com>
Co-authored-by: Lorenze Jay <lorenzejaytech@gmail.com>
Co-authored-by: João Moura <joaomdmoura@gmail.com>
Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com>

Committed via GitHub
parent 3dea3d0183
commit 807c13e144

docs/custom_llm.md (new file, 642 lines)
@@ -0,0 +1,642 @@
# Custom LLM Implementations

CrewAI now supports custom LLM implementations through the `BaseLLM` abstract base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.

## Using Custom LLM Implementations

To create a custom LLM implementation, you need to:

1. Inherit from the `BaseLLM` abstract base class and call `super().__init__(model)` in your constructor
2. Implement the required methods:
   - `call()`: The main method to call the LLM with messages (the only abstract method)
   - `supports_function_calling()`: Whether the LLM supports function calling
   - `supports_stop_words()`: Whether the LLM supports stop words (a default returning `True` is provided)
   - `get_context_window_size()`: The context window size of the LLM (a default of 4096 is provided)

## Example: Basic Custom LLM

```python
from crewai import BaseLLM
from typing import Any, Dict, List, Optional, Union


class CustomLLM(BaseLLM):
    def __init__(self, model: str, api_key: str, endpoint: str):
        super().__init__(model)  # Initialize the base class (BaseLLM requires the model name)
        if not api_key or not isinstance(api_key, str):
            raise ValueError("Invalid API key: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.api_key = api_key
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with the given messages.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            Either a text response from the LLM or the result of a tool function call.

        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid.
        """
        # Implement your own logic to call the LLM
        # For example, using requests:
        import requests

        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }

            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]

            data = {
                "messages": messages,
                "tools": tools
            }

            response = requests.post(
                self.endpoint,
                headers=headers,
                json=data,
                timeout=30  # Set a reasonable timeout
            )
            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError, ValueError) as e:
            raise ValueError(f"Invalid response format: {str(e)}")

    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.

        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        # Return True if your LLM supports function calling
        return True

    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.

        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        # Return True if your LLM supports stop words
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.

        Returns:
            The context window size as an integer.
        """
        # Return the context window size of your LLM
        return 8192
```

## Error Handling Best Practices

When implementing custom LLMs, it's important to handle errors properly to ensure robustness and reliability. Here are some best practices:

### 1. Implement Try-Except Blocks for API Calls

Always wrap API calls in try-except blocks to handle different types of errors:

```python
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        # API call implementation (self.headers and self.prepare_payload are
        # helpers you are assumed to define on your own class)
        response = requests.post(
            self.endpoint,
            headers=self.headers,
            json=self.prepare_payload(messages),
            timeout=30  # Set a reasonable timeout
        )
        response.raise_for_status()  # Raise an exception for HTTP errors
        return response.json()["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")
```

### 2. Implement Retry Logic for Transient Failures

For transient failures like network issues or rate limiting, implement retry logic with exponential backoff:

```python
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import time

    max_retries = 3
    retry_delay = 1  # seconds

    for attempt in range(max_retries):
        try:
            response = requests.post(
                self.endpoint,
                headers=self.headers,
                json=self.prepare_payload(messages),
                timeout=30
            )
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]
        except (requests.Timeout, requests.ConnectionError) as e:
            if attempt < max_retries - 1:
                time.sleep(retry_delay * (2 ** attempt))  # Exponential backoff
                continue
            raise TimeoutError(f"LLM request failed after {max_retries} attempts: {str(e)}")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
```

### 3. Validate Input Parameters

Always validate input parameters to prevent runtime errors:

```python
def __init__(self, model: str, api_key: str, endpoint: str):
    super().__init__(model)
    if not api_key or not isinstance(api_key, str):
        raise ValueError("Invalid API key: must be a non-empty string")
    if not endpoint or not isinstance(endpoint, str):
        raise ValueError("Invalid endpoint URL: must be a non-empty string")
    self.api_key = api_key
    self.endpoint = endpoint
```

### 4. Handle Authentication Errors Gracefully

Provide clear error messages for authentication failures:

```python
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        data = self.prepare_payload(messages)  # build the request payload from the messages
        response = requests.post(self.endpoint, headers=self.headers, json=data)
        if response.status_code == 401:
            raise ValueError("Authentication failed: Invalid API key or token")
        elif response.status_code == 403:
            raise ValueError("Authorization failed: Insufficient permissions")
        response.raise_for_status()
        # Process response
    except Exception as e:
        # Handle error
        raise
```

## Example: JWT-based Authentication

For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:

```python
from crewai import BaseLLM, Agent, Task
from typing import Any, Dict, List, Optional, Union


class JWTAuthLLM(BaseLLM):
    def __init__(self, model: str, jwt_token: str, endpoint: str):
        super().__init__(model)  # Initialize the base class (BaseLLM requires the model name)
        if not jwt_token or not isinstance(jwt_token, str):
            raise ValueError("Invalid JWT token: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.jwt_token = jwt_token
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with JWT authentication.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            Either a text response from the LLM or the result of a tool function call.

        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid.
        """
        # Implement your own logic to call the LLM with JWT authentication
        import requests

        try:
            headers = {
                "Authorization": f"Bearer {self.jwt_token}",
                "Content-Type": "application/json"
            }

            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]

            data = {
                "messages": messages,
                "tools": tools
            }

            response = requests.post(
                self.endpoint,
                headers=headers,
                json=data,
                timeout=30  # Set a reasonable timeout
            )

            if response.status_code == 401:
                raise ValueError("Authentication failed: Invalid JWT token")
            elif response.status_code == 403:
                raise ValueError("Authorization failed: Insufficient permissions")

            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError, ValueError) as e:
            raise ValueError(f"Invalid response format: {str(e)}")

    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.

        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        return True

    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.

        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.

        Returns:
            The context window size as an integer.
        """
        return 8192
```

## Troubleshooting

Here are some common issues you might encounter when implementing custom LLMs and how to resolve them:

### 1. Authentication Failures

**Symptoms**: 401 Unauthorized or 403 Forbidden errors

**Solutions**:
- Verify that your API key or JWT token is valid and not expired
- Check that you're using the correct authentication header format
- Ensure that your token has the necessary permissions

### 2. Timeout Issues

**Symptoms**: Requests taking too long or timing out

**Solutions**:
- Implement timeout handling as shown in the examples
- Use retry logic with exponential backoff
- Consider using a more reliable network connection

### 3. Response Parsing Errors

**Symptoms**: KeyError, IndexError, or ValueError when processing responses

**Solutions**:
- Validate the response format before accessing nested fields (see the sketch below)
- Implement proper error handling for malformed responses
- Check the API documentation for the expected response format
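
The following is a minimal sketch of defensive response parsing. It assumes the provider returns an OpenAI-style `choices[0]["message"]["content"]` payload; adjust the field names to match your service:

```python
def _extract_content(self, response_json: dict) -> str:
    """Validate an OpenAI-style response payload before indexing into it."""
    choices = response_json.get("choices")
    if not choices or not isinstance(choices, list):
        raise ValueError(f"Invalid response format: missing 'choices': {response_json}")
    message = choices[0].get("message") or {}
    content = message.get("content")
    if content is None:
        raise ValueError("Invalid response format: missing 'message.content'")
    return content
```
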
### 4. Rate Limiting

**Symptoms**: 429 Too Many Requests errors

**Solutions**:
- Implement rate limiting in your custom LLM
- Add exponential backoff for retries
- Consider using a token bucket algorithm for more precise rate control (a sketch follows below)
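
As a rough illustration of the token bucket idea (this helper class is not part of the CrewAI API, and the numbers are placeholders), a bucket refills at a fixed rate and each request consumes one token:

```python
import time


class TokenBucket:
    """Simple token bucket: refills at `rate` tokens per second, up to `capacity`."""

    def __init__(self, rate: float, capacity: int):
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)
        self.last_refill = time.time()

    def acquire(self) -> None:
        """Block until a token is available, then consume it."""
        while True:
            now = time.time()
            # Refill proportionally to the elapsed time, capped at capacity
            self.tokens = min(self.capacity, self.tokens + (now - self.last_refill) * self.rate)
            self.last_refill = now
            if self.tokens >= 1:
                self.tokens -= 1
                return
            # Wait roughly until one full token has accrued
            time.sleep((1 - self.tokens) / self.rate)
```

A custom LLM could create one bucket per instance and call `bucket.acquire()` at the top of `call()` before issuing the HTTP request.
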
## Advanced Features

### Logging

Adding logging to your custom LLM can help with debugging and monitoring:

```python
import logging
from typing import Any, Dict, List, Optional, Union


class LoggingLLM(BaseLLM):
    def __init__(self, model: str, api_key: str, endpoint: str):
        super().__init__(model)
        self.api_key = api_key
        self.endpoint = endpoint
        self.logger = logging.getLogger("crewai.llm.custom")

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self.logger.info(f"Calling LLM with {len(messages) if isinstance(messages, list) else 1} messages")
        try:
            # API call implementation (self._make_api_call is a helper you define)
            response = self._make_api_call(messages, tools)
            self.logger.debug(f"LLM response received: {response[:100]}...")
            return response
        except Exception as e:
            self.logger.error(f"LLM call failed: {str(e)}")
            raise
```

### Rate Limiting

Implementing rate limiting can help avoid overwhelming the LLM API:

```python
import time
from typing import Any, Dict, List, Optional, Union


class RateLimitedLLM(BaseLLM):
    def __init__(
        self,
        model: str,
        api_key: str,
        endpoint: str,
        requests_per_minute: int = 60
    ):
        super().__init__(model)
        self.api_key = api_key
        self.endpoint = endpoint
        self.requests_per_minute = requests_per_minute
        self.request_times: List[float] = []

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self._enforce_rate_limit()
        # Record this request time
        self.request_times.append(time.time())
        # Make the actual API call
        return self._make_api_call(messages, tools)

    def _enforce_rate_limit(self) -> None:
        """Enforce the rate limit by waiting if necessary."""
        now = time.time()
        # Remove request times older than 1 minute
        self.request_times = [t for t in self.request_times if now - t < 60]

        if len(self.request_times) >= self.requests_per_minute:
            # Calculate how long to wait
            oldest_request = min(self.request_times)
            wait_time = 60 - (now - oldest_request)
            if wait_time > 0:
                time.sleep(wait_time)
```

### Metrics Collection

Collecting metrics can help you monitor your LLM usage:

```python
import time
from typing import Any, Dict, List, Optional, Union


class MetricsCollectingLLM(BaseLLM):
    def __init__(self, model: str, api_key: str, endpoint: str):
        super().__init__(model)
        self.api_key = api_key
        self.endpoint = endpoint
        self.metrics: Dict[str, Any] = {
            "total_calls": 0,
            "total_tokens": 0,
            "errors": 0,
            "latency": []
        }

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        start_time = time.time()
        self.metrics["total_calls"] += 1

        try:
            response = self._make_api_call(messages, tools)
            # Estimate tokens (simplified)
            if isinstance(messages, str):
                token_estimate = len(messages) // 4
            else:
                token_estimate = sum(len(m.get("content", "")) // 4 for m in messages)
            self.metrics["total_tokens"] += token_estimate
            return response
        except Exception as e:
            self.metrics["errors"] += 1
            raise
        finally:
            latency = time.time() - start_time
            self.metrics["latency"].append(latency)

    def get_metrics(self) -> Dict[str, Any]:
        """Return the collected metrics."""
        avg_latency = sum(self.metrics["latency"]) / len(self.metrics["latency"]) if self.metrics["latency"] else 0
        return {
            **self.metrics,
            "avg_latency": avg_latency
        }
```

## Advanced Usage: Function Calling

If your LLM supports function calling, you can implement the function calling logic in your custom LLM:

```python
import json
from typing import Any, Dict, List, Optional, Union


def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import requests

    try:
        headers = {
            "Authorization": f"Bearer {self.jwt_token}",
            "Content-Type": "application/json"
        }

        # Convert string message to proper format if needed
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        data = {
            "messages": messages,
            "tools": tools
        }

        response = requests.post(
            self.endpoint,
            headers=headers,
            json=data,
            timeout=30
        )
        response.raise_for_status()
        response_data = response.json()

        # Check if the LLM wants to call a function
        if response_data["choices"][0]["message"].get("tool_calls"):
            tool_calls = response_data["choices"][0]["message"]["tool_calls"]

            # Process each tool call
            for tool_call in tool_calls:
                function_name = tool_call["function"]["name"]
                function_args = json.loads(tool_call["function"]["arguments"])

                if available_functions and function_name in available_functions:
                    function_to_call = available_functions[function_name]
                    function_response = function_to_call(**function_args)

                    # Add the function response to the messages
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "name": function_name,
                        "content": str(function_response)
                    })

            # Call the LLM again with the updated messages
            return self.call(messages, tools, callbacks, available_functions)

        # Return the text response if no function call
        return response_data["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")
```

## Using Your Custom LLM with CrewAI

Once you've implemented your custom LLM, you can use it with CrewAI agents and crews:

```python
from crewai import Agent, Task, Crew

# Create your custom LLM instance
jwt_llm = JWTAuthLLM(
    model="your-model-name",  # BaseLLM requires a model name
    jwt_token="your.jwt.token",
    endpoint="https://your-llm-endpoint.com/v1/chat/completions"
)

# Use it with an agent
agent = Agent(
    role="Research Assistant",
    goal="Find information on a topic",
    backstory="You are a research assistant tasked with finding information.",
    llm=jwt_llm,
)

# Create a task for the agent
task = Task(
    description="Research the benefits of exercise",
    agent=agent,
    expected_output="A summary of the benefits of exercise",
)

# Execute the task
result = agent.execute_task(task)
print(result)

# Or use it with a crew
crew = Crew(
    agents=[agent],
    tasks=[task],
    manager_llm=jwt_llm,  # Use your custom LLM for the manager
)

# Run the crew
result = crew.kickoff()
print(result)
```

## Implementing Your Own Authentication Mechanism

The `BaseLLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:

- OAuth tokens
- Client certificates
- Custom headers
- Session-based authentication
- Any other authentication method required by your LLM provider

Simply implement the appropriate authentication logic in your custom LLM class, as in the sketch below.
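
For example, here is a minimal sketch of an LLM that authenticates each request with provider-specific custom headers. The header names, endpoint shape, and response format are illustrative assumptions, not a CrewAI API:

```python
import requests
from typing import Any, Dict, List, Optional, Union

from crewai import BaseLLM


class CustomHeaderLLM(BaseLLM):
    """Hypothetical example: authenticate requests with provider-specific headers."""

    def __init__(self, model: str, endpoint: str, client_id: str, client_secret: str):
        super().__init__(model)
        self.endpoint = endpoint
        self.client_id = client_id
        self.client_secret = client_secret

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        # Illustrative header names; replace with whatever your provider expects
        headers = {
            "X-Client-Id": self.client_id,
            "X-Client-Secret": self.client_secret,
            "Content-Type": "application/json",
        }
        response = requests.post(
            self.endpoint,
            headers=headers,
            json={"model": self.model, "messages": messages, "tools": tools},
            timeout=30,
        )
        response.raise_for_status()
        # Assumes an OpenAI-style response payload
        return response.json()["choices"][0]["message"]["content"]
```
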
@@ -64,6 +64,9 @@ mem0 = ["mem0ai>=0.1.29"]
 docling = [
     "docling>=2.12.0",
 ]
+aisuite = [
+    "aisuite>=0.1.10",
+]
 
 [tool.uv]
 dev-dependencies = [

@@ -5,6 +5,7 @@ from crewai.crew import Crew
 from crewai.flow.flow import Flow
 from crewai.knowledge.knowledge import Knowledge
 from crewai.llm import LLM
+from crewai.llms.base_llm import BaseLLM
 from crewai.process import Process
 from crewai.task import Task
 
@@ -21,6 +22,7 @@ __all__ = [
     "Process",
     "Task",
     "LLM",
+    "BaseLLM",
     "Flow",
     "Knowledge",
 ]

@@ -11,7 +11,7 @@ from crewai.agents.crew_agent_executor import CrewAgentExecutor
 from crewai.knowledge.knowledge import Knowledge
 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
-from crewai.llm import LLM
+from crewai.llm import BaseLLM
 from crewai.memory.contextual.contextual_memory import ContextualMemory
 from crewai.security import Fingerprint
 from crewai.task import Task
 
@@ -71,10 +71,10 @@ class Agent(BaseAgent):
         default=True,
         description="Use system prompt for the agent.",
     )
-    llm: Union[str, InstanceOf[LLM], Any] = Field(
+    llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
         description="Language model that will run the agent.", default=None
     )
-    function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
+    function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
         description="Language model that will run the agent.", default=None
     )
     system_template: Optional[str] = Field(
 
@@ -118,7 +118,9 @@ class Agent(BaseAgent):
         self.agent_ops_agent_name = self.role
 
         self.llm = create_llm(self.llm)
-        if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
+        if self.function_calling_llm and not isinstance(
+            self.function_calling_llm, BaseLLM
+        ):
             self.function_calling_llm = create_llm(self.function_calling_llm)
 
         if not self.agent_executor:

@@ -13,7 +13,7 @@ from crewai.agents.parser import (
     OutputParserException,
 )
 from crewai.agents.tools_handler import ToolsHandler
-from crewai.llm import LLM
+from crewai.llm import BaseLLM
 from crewai.tools.base_tool import BaseTool
 from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
 from crewai.utilities import I18N, Printer
 
@@ -61,7 +61,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         callbacks: List[Any] = [],
     ):
         self._i18n: I18N = I18N()
-        self.llm: LLM = llm
+        self.llm: BaseLLM = llm
         self.task = task
         self.agent = agent
         self.crew = crew
 
@@ -87,8 +87,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self.tool_name_to_tool_map: Dict[str, BaseTool] = {
             tool.name: tool for tool in self.tools
         }
-        self.stop = stop_words
-        self.llm.stop = list(set(self.llm.stop + self.stop))
+        existing_stop = self.llm.stop or []
+        self.llm.stop = list(
+            set(
+                existing_stop + self.stop
+                if isinstance(existing_stop, list)
+                else self.stop
+            )
+        )
 
     def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
         if "system" in self.prompt:

@@ -14,7 +14,7 @@ from packaging import version
 from crewai.cli.utils import read_toml
 from crewai.cli.version import get_crewai_version
 from crewai.crew import Crew
-from crewai.llm import LLM
+from crewai.llm import LLM, BaseLLM
 from crewai.types.crew_chat import ChatInputField, ChatInputs
 from crewai.utilities.llm_utils import create_llm
 
@@ -116,7 +116,7 @@ def show_loading(event: threading.Event):
     print()
 
 
-def initialize_chat_llm(crew: Crew) -> Optional[LLM]:
+def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]:
     """Initializes the chat LLM and handles exceptions."""
     try:
         return create_llm(crew.chat_llm)

@@ -6,8 +6,9 @@ import warnings
 from concurrent.futures import Future
 from copy import copy as shallow_copy
 from hashlib import md5
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast
 
+from langchain_core.tools import BaseTool as LangchainBaseTool
 from pydantic import (
     UUID4,
     BaseModel,
 
@@ -26,7 +27,7 @@ from crewai.agents.cache import CacheHandler
 from crewai.crews.crew_output import CrewOutput
 from crewai.knowledge.knowledge import Knowledge
 from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
-from crewai.llm import LLM
+from crewai.llm import LLM, BaseLLM
 from crewai.memory.entity.entity_memory import EntityMemory
 from crewai.memory.long_term.long_term_memory import LongTermMemory
 from crewai.memory.short_term.short_term_memory import ShortTermMemory
 
@@ -37,7 +38,7 @@ from crewai.task import Task
 from crewai.tasks.conditional_task import ConditionalTask
 from crewai.tasks.task_output import TaskOutput
 from crewai.tools.agent_tools.agent_tools import AgentTools
-from crewai.tools.base_tool import Tool
+from crewai.tools.base_tool import BaseTool, Tool
 from crewai.types.usage_metrics import UsageMetrics
 from crewai.utilities import I18N, FileHandler, Logger, RPMController
 from crewai.utilities.constants import TRAINING_DATA_FILE
 
@@ -153,7 +154,7 @@ class Crew(BaseModel):
         default=None,
         description="Metrics for the LLM usage during all tasks execution.",
     )
-    manager_llm: Optional[Any] = Field(
+    manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
         description="Language model that will run the agent.", default=None
     )
     manager_agent: Optional[BaseAgent] = Field(
 
@@ -187,7 +188,7 @@ class Crew(BaseModel):
         default=None,
         description="Maximum number of requests per minute for the crew execution to be respected.",
     )
-    prompt_file: str = Field(
+    prompt_file: Optional[str] = Field(
         default=None,
         description="Path to the prompt json file to be used for the crew.",
     )
 
@@ -199,7 +200,7 @@ class Crew(BaseModel):
         default=False,
         description="Plan the crew execution and add the plan to the crew.",
     )
-    planning_llm: Optional[Any] = Field(
+    planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
         default=None,
         description="Language model that will run the AgentPlanner if planning is True.",
     )
 
@@ -215,7 +216,7 @@ class Crew(BaseModel):
         default=None,
         description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
     )
-    chat_llm: Optional[Any] = Field(
+    chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
         default=None,
         description="LLM used to handle chatting with the crew.",
     )
 
@@ -819,7 +820,12 @@ class Crew(BaseModel):
 
             # Determine which tools to use - task tools take precedence over agent tools
             tools_for_task = task.tools or agent_to_use.tools or []
-            tools_for_task = self._prepare_tools(agent_to_use, task, tools_for_task)
+            # Prepare tools and ensure they're compatible with task execution
+            tools_for_task = self._prepare_tools(
+                agent_to_use,
+                task,
+                cast(Union[List[Tool], List[BaseTool]], tools_for_task),
+            )
 
             self._log_task_start(task, agent_to_use.role)
 
@@ -838,7 +844,7 @@ class Crew(BaseModel):
                 future = task.execute_async(
                     agent=agent_to_use,
                     context=context,
-                    tools=tools_for_task,
+                    tools=cast(List[BaseTool], tools_for_task),
                 )
                 futures.append((task, future, task_index))
             else:
 
@@ -850,7 +856,7 @@ class Crew(BaseModel):
                 task_output = task.execute_sync(
                     agent=agent_to_use,
                     context=context,
-                    tools=tools_for_task,
+                    tools=cast(List[BaseTool], tools_for_task),
                 )
                 task_outputs.append(task_output)
                 self._process_task_result(task, task_output)
 
@@ -888,10 +894,12 @@ class Crew(BaseModel):
         return None
 
     def _prepare_tools(
-        self, agent: BaseAgent, task: Task, tools: List[Tool]
-    ) -> List[Tool]:
+        self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
+    ) -> List[BaseTool]:
         # Add delegation tools if agent allows delegation
-        if agent.allow_delegation:
+        if hasattr(agent, "allow_delegation") and getattr(
+            agent, "allow_delegation", False
+        ):
             if self.process == Process.hierarchical:
                 if self.manager_agent:
                     tools = self._update_manager_tools(task, tools)
 
@@ -900,17 +908,24 @@ class Crew(BaseModel):
                         "Manager agent is required for hierarchical process."
                     )
 
-            elif agent and agent.allow_delegation:
+            elif agent:
                 tools = self._add_delegation_tools(task, tools)
 
         # Add code execution tools if agent allows code execution
-        if agent.allow_code_execution:
+        if hasattr(agent, "allow_code_execution") and getattr(
+            agent, "allow_code_execution", False
+        ):
             tools = self._add_code_execution_tools(agent, tools)
 
-        if agent and agent.multimodal:
+        if (
+            agent
+            and hasattr(agent, "multimodal")
+            and getattr(agent, "multimodal", False)
+        ):
             tools = self._add_multimodal_tools(agent, tools)
 
-        return tools
+        # Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
+        return cast(List[BaseTool], tools)
 
     def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]:
         if self.process == Process.hierarchical:
 
@@ -918,11 +933,13 @@ class Crew(BaseModel):
             return task.agent
 
     def _merge_tools(
-        self, existing_tools: List[Tool], new_tools: List[Tool]
-    ) -> List[Tool]:
+        self,
+        existing_tools: Union[List[Tool], List[BaseTool]],
+        new_tools: Union[List[Tool], List[BaseTool]],
+    ) -> List[BaseTool]:
         """Merge new tools into existing tools list, avoiding duplicates by tool name."""
         if not new_tools:
-            return existing_tools
+            return cast(List[BaseTool], existing_tools)
 
         # Create mapping of tool names to new tools
         new_tool_map = {tool.name: tool for tool in new_tools}
 
@@ -933,23 +950,41 @@ class Crew(BaseModel):
         # Add all new tools
         tools.extend(new_tools)
 
-        return tools
+        return cast(List[BaseTool], tools)
 
     def _inject_delegation_tools(
-        self, tools: List[Tool], task_agent: BaseAgent, agents: List[BaseAgent]
-    ):
+        self,
+        tools: Union[List[Tool], List[BaseTool]],
+        task_agent: BaseAgent,
+        agents: List[BaseAgent],
+    ) -> List[BaseTool]:
+        if hasattr(task_agent, "get_delegation_tools"):
             delegation_tools = task_agent.get_delegation_tools(agents)
-        return self._merge_tools(tools, delegation_tools)
+            # Cast delegation_tools to the expected type for _merge_tools
+            return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
+        return cast(List[BaseTool], tools)
 
-    def _add_multimodal_tools(self, agent: BaseAgent, tools: List[Tool]):
+    def _add_multimodal_tools(
+        self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
+    ) -> List[BaseTool]:
+        if hasattr(agent, "get_multimodal_tools"):
            multimodal_tools = agent.get_multimodal_tools()
-        return self._merge_tools(tools, multimodal_tools)
+            # Cast multimodal_tools to the expected type for _merge_tools
+            return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
+        return cast(List[BaseTool], tools)
 
-    def _add_code_execution_tools(self, agent: BaseAgent, tools: List[Tool]):
+    def _add_code_execution_tools(
+        self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
+    ) -> List[BaseTool]:
+        if hasattr(agent, "get_code_execution_tools"):
            code_tools = agent.get_code_execution_tools()
-        return self._merge_tools(tools, code_tools)
+            # Cast code_tools to the expected type for _merge_tools
+            return self._merge_tools(tools, cast(List[BaseTool], code_tools))
+        return cast(List[BaseTool], tools)
 
-    def _add_delegation_tools(self, task: Task, tools: List[Tool]):
+    def _add_delegation_tools(
+        self, task: Task, tools: Union[List[Tool], List[BaseTool]]
+    ) -> List[BaseTool]:
         agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
         if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
             if not tools:
 
@@ -957,7 +992,7 @@ class Crew(BaseModel):
                 tools = self._inject_delegation_tools(
                     tools, task.agent, agents_for_delegation
                 )
-        return tools
+        return cast(List[BaseTool], tools)
 
     def _log_task_start(self, task: Task, role: str = "None"):
         if self.output_log_file:
 
@@ -965,7 +1000,9 @@ class Crew(BaseModel):
             task_name=task.name, task=task.description, agent=role, status="started"
         )
 
-    def _update_manager_tools(self, task: Task, tools: List[Tool]):
+    def _update_manager_tools(
+        self, task: Task, tools: Union[List[Tool], List[BaseTool]]
+    ) -> List[BaseTool]:
         if self.manager_agent:
             if task.agent:
                 tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
 
@@ -973,7 +1010,7 @@ class Crew(BaseModel):
                 tools = self._inject_delegation_tools(
                     tools, self.manager_agent, self.agents
                 )
-        return tools
+        return cast(List[BaseTool], tools)
 
     def _get_context(self, task: Task, task_outputs: List[TaskOutput]):
         context = (
 
@@ -1214,13 +1251,14 @@ class Crew(BaseModel):
     def test(
         self,
         n_iterations: int,
-        eval_llm: Union[str, InstanceOf[LLM]],
+        eval_llm: Union[str, InstanceOf[BaseLLM]],
         inputs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
         try:
-            eval_llm = create_llm(eval_llm)
-            if not eval_llm:
+            # Create LLM instance and ensure it's of type LLM for CrewEvaluator
+            llm_instance = create_llm(eval_llm)
+            if not llm_instance:
                 raise ValueError("Failed to create LLM instance.")
 
             crewai_event_bus.emit(
 
@@ -1228,12 +1266,12 @@ class Crew(BaseModel):
                 CrewTestStartedEvent(
                     crew_name=self.name or "crew",
                     n_iterations=n_iterations,
-                    eval_llm=eval_llm,
+                    eval_llm=llm_instance,
                     inputs=inputs,
                 ),
             )
             test_crew = self.copy()
-            evaluator = CrewEvaluator(test_crew, eval_llm)  # type: ignore[arg-type]
+            evaluator = CrewEvaluator(test_crew, llm_instance)
 
             for i in range(1, n_iterations + 1):
                 evaluator.set_iteration(i)

@@ -40,6 +40,7 @@ with warnings.catch_warnings():
     from litellm.utils import supports_response_schema
 
 
+from crewai.llms.base_llm import BaseLLM
 from crewai.utilities.events import crewai_event_bus
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededException,
 
@@ -218,7 +219,7 @@ class StreamingChoices(TypedDict):
     finish_reason: Optional[str]
 
 
-class LLM:
+class LLM(BaseLLM):
     def __init__(
         self,
         model: str,

91
src/crewai/llms/base_llm.py
Normal file
91
src/crewai/llms/base_llm.py
Normal file
@@ -0,0 +1,91 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union


class BaseLLM(ABC):
    """Abstract base class for LLM implementations.

    This class defines the interface that all LLM implementations must follow.
    Users can extend this class to create custom LLM implementations that don't
    rely on litellm's authentication mechanism.

    Custom LLM implementations should handle error cases gracefully, including
    timeouts, authentication failures, and malformed responses. They should also
    implement proper validation for input parameters and provide clear error
    messages when things go wrong.

    Attributes:
        stop (list): A list of stop sequences that the LLM should use to stop generation.
            This is used by the CrewAgentExecutor and other components.
    """

    model: str
    temperature: Optional[float] = None
    stop: Optional[List[str]] = None

    def __init__(
        self,
        model: str,
        temperature: Optional[float] = None,
    ):
        """Initialize the BaseLLM with default attributes.

        This constructor sets default values for attributes that are expected
        by the CrewAgentExecutor and other components.

        All custom LLM implementations should call super().__init__() to ensure
        that these default attributes are properly initialized.
        """
        self.model = model
        self.temperature = temperature
        self.stop = []

    @abstractmethod
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with the given messages.

        Args:
            messages: Input messages for the LLM.
                Can be a string or list of message dictionaries.
                If string, it will be converted to a single user message.
                If list, each dict must have 'role' and 'content' keys.
            tools: Optional list of tool schemas for function calling.
                Each tool should define its name, description, and parameters.
            callbacks: Optional list of callback functions to be executed
                during and after the LLM call.
            available_functions: Optional dict mapping function names to callables
                that can be invoked by the LLM.

        Returns:
            Either a text response from the LLM (str) or
            the result of a tool function call (Any).

        Raises:
            ValueError: If the messages format is invalid.
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
        """
        pass

    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.

        Returns:
            bool: True if the LLM supports stop words, False otherwise.
        """
        return True  # Default implementation assumes support for stop words

    def get_context_window_size(self) -> int:
        """Get the context window size for the LLM.

        Returns:
            int: The number of tokens/characters the model can handle.
        """
        # Default implementation - subclasses should override with model-specific values
        return 4096
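As a quick illustration of the contract above, here is a minimal sketch of a custom implementation that validates its input and surfaces timeouts the way the docstring asks for. It is not part of this PR: the HTTP endpoint, the use of the `requests` library, and the response shape are assumptions made only for the example.

# Hypothetical sketch: shows one way to satisfy the BaseLLM contract.
# The endpoint, payload shape, and use of `requests` are assumptions.
from typing import Any, Dict, List, Optional, Union

import requests

from crewai.llms.base_llm import BaseLLM


class HTTPBackedLLM(BaseLLM):
    def __init__(self, model: str, endpoint: str, timeout: float = 30.0):
        super().__init__(model=model)
        self.endpoint = endpoint
        self.timeout = timeout

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        # Normalize a bare string into the single-user-message form documented above.
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        # Validate the documented message shape and fail with a clear error.
        for message in messages:
            if "role" not in message or "content" not in message:
                raise ValueError("Each message needs 'role' and 'content' keys")
        try:
            response = requests.post(
                self.endpoint,
                json={"model": self.model, "messages": messages},
                timeout=self.timeout,
            )
            response.raise_for_status()
        except requests.Timeout as exc:
            # Surface timeouts as TimeoutError, as the docstring specifies.
            raise TimeoutError(f"LLM request timed out after {self.timeout}s") from exc
        except requests.RequestException as exc:
            raise RuntimeError(f"LLM request failed: {exc}") from exc
        # Assumed response shape: {"content": "..."}.
        return response.json()["content"]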
src/crewai/llms/third_party/ai_suite.py (new file, vendored, 38 lines)
@@ -0,0 +1,38 @@
from typing import Any, Dict, List, Optional, Union

import aisuite as ai

from crewai.llms.base_llm import BaseLLM


class AISuiteLLM(BaseLLM):
    def __init__(self, model: str, temperature: Optional[float] = None, **kwargs):
        super().__init__(model, temperature, **kwargs)
        self.client = ai.Client()

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        completion_params = self._prepare_completion_params(messages, tools)
        response = self.client.chat.completions.create(**completion_params)

        return response.choices[0].message.content

    def _prepare_completion_params(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
    ) -> Dict[str, Any]:
        return {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "tools": tools,
        }

    def supports_function_calling(self) -> bool:
        return False
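For context, a rough usage sketch of this adapter with an agent follows. The "openai:gpt-4o-mini" identifier assumes aisuite's provider:model naming convention and, like the agent fields, is only an illustrative value, not something taken from this PR.

# Illustrative usage sketch; the model id and agent fields are example values.
from crewai import Agent
from crewai.llms.third_party.ai_suite import AISuiteLLM

llm = AISuiteLLM(model="openai:gpt-4o-mini", temperature=0.2)

agent = Agent(
    role="Researcher",
    goal="Answer questions concisely",
    backstory="A helpful research assistant",
    llm=llm,  # any BaseLLM subclass can be passed where an LLM is accepted
)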
@@ -6,7 +6,7 @@ from rich.console import Console
 from rich.table import Table
 
 from crewai.agent import Agent
-from crewai.llm import LLM
+from crewai.llm import BaseLLM
 from crewai.task import Task
 from crewai.tasks.task_output import TaskOutput
 from crewai.telemetry import Telemetry
@@ -24,7 +24,7 @@ class CrewEvaluator:
 
     Attributes:
         crew (Crew): The crew of agents to evaluate.
-        eval_llm (LLM): Language model instance to use for evaluations
+        eval_llm (BaseLLM): Language model instance to use for evaluations
        tasks_scores (defaultdict): A dictionary to store the scores of the agents for each task.
        iteration (int): The current iteration of the evaluation.
    """
@@ -33,7 +33,7 @@ class CrewEvaluator:
     run_execution_times: defaultdict = defaultdict(list)
     iteration: int = 0
 
-    def __init__(self, crew, eval_llm: InstanceOf[LLM]):
+    def __init__(self, crew, eval_llm: InstanceOf[BaseLLM]):
         self.crew = crew
         self.llm = eval_llm
         self._telemetry = Telemetry()
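Because eval_llm is now annotated as BaseLLM, the evaluator can be driven by a custom implementation as well. A small sketch follows; the CrewEvaluator import path is not shown in this diff and is assumed here, and the stub LLM exists only for illustration.

# Sketch only: the CrewEvaluator import path is an assumption (not shown in
# this diff), and EvalStubLLM stands in for any real BaseLLM subclass.
from crewai import Agent, Crew, Task
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator  # assumed path


class EvalStubLLM(BaseLLM):
    def call(self, messages, tools=None, callbacks=None, available_functions=None):
        return "ok"


stub = EvalStubLLM(model="stub-model")
agent = Agent(role="Writer", goal="Write a line", backstory="Writes short lines", llm=stub)
task = Task(description="Write one line", expected_output="One line", agent=agent)
crew = Crew(agents=[agent], tasks=[task])

# Any BaseLLM (not just the built-in LLM) can now be passed as the evaluation model.
evaluator = CrewEvaluator(crew, eval_llm=stub)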
@@ -2,28 +2,28 @@ import os
 from typing import Any, Dict, List, Optional, Union
 
 from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
-from crewai.llm import LLM
+from crewai.llm import LLM, BaseLLM
 
 
 def create_llm(
     llm_value: Union[str, LLM, Any, None] = None,
-) -> Optional[LLM]:
+) -> Optional[LLM | BaseLLM]:
     """
     Creates or returns an LLM instance based on the given llm_value.
 
     Args:
-        llm_value (str | LLM | Any | None):
+        llm_value (str | BaseLLM | Any | None):
             - str: The model name (e.g., "gpt-4").
-            - LLM: Already instantiated LLM, returned as-is.
+            - BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
             - Any: Attempt to extract known attributes like model_name, temperature, etc.
             - None: Use environment-based or fallback default model.
 
     Returns:
-        An LLM instance if successful, or None if something fails.
+        A BaseLLM instance if successful, or None if something fails.
     """
 
-    # 1) If llm_value is already an LLM object, return it directly
-    if isinstance(llm_value, LLM):
+    # 1) If llm_value is already a BaseLLM or LLM object, return it directly
+    if isinstance(llm_value, LLM) or isinstance(llm_value, BaseLLM):
         return llm_value
 
     # 2) If llm_value is a string (model name)
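To make the new behaviour concrete, here is a short sketch of the two paths through create_llm. EchoLLM is a hypothetical stand-in for any BaseLLM subclass, and the model string is only an example value.

# Sketch of the two create_llm paths; EchoLLM is hypothetical.
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.llm_utils import create_llm


class EchoLLM(BaseLLM):
    def call(self, messages, tools=None, callbacks=None, available_functions=None):
        return "echo"


custom = EchoLLM(model="echo-model")
assert create_llm(custom) is custom      # BaseLLM instances are returned unchanged
default_llm = create_llm("gpt-4o-mini")  # strings are resolved into a litellm-backed LLM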
tests/cassettes/test_custom_llm_implementation.yaml (new file, 107 lines)
@@ -0,0 +1,107 @@
interactions:
- request:
    body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": "What is the answer to life, the universe, and everything?"}],
      "model": "gpt-4o-mini", "tools": null}'
    headers:
      accept:
      - application/json
      accept-encoding:
      - gzip, deflate
      connection:
      - keep-alive
      content-length:
      - '206'
      content-type:
      - application/json
      host:
      - api.openai.com
      user-agent:
      - OpenAI/Python 1.61.0
      x-stainless-arch:
      - arm64
      x-stainless-async:
      - 'false'
      x-stainless-lang:
      - python
      x-stainless-os:
      - MacOS
      x-stainless-package-version:
      - 1.61.0
      x-stainless-retry-count:
      - '0'
      x-stainless-runtime:
      - CPython
      x-stainless-runtime-version:
      - 3.12.8
    method: POST
    uri: https://api.openai.com/v1/chat/completions
  response:
    content: "{\n \"id\": \"chatcmpl-B7W6FS0wpfndLdg12G3H6ZAXcYhJi\",\n \"object\":
      \"chat.completion\",\n \"created\": 1741131387,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
      \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
      \"assistant\",\n \"content\": \"The answer to life, the universe, and
      everything, famously found in Douglas Adams' \\\"The Hitchhiker's Guide to the
      Galaxy,\\\" is the number 42. However, the question itself is left ambiguous,
      leading to much speculation and humor in the story.\",\n \"refusal\":
      null\n },\n \"logprobs\": null,\n \"finish_reason\": \"stop\"\n
      \ }\n ],\n \"usage\": {\n \"prompt_tokens\": 30,\n \"completion_tokens\":
      54,\n \"total_tokens\": 84,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
      0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": {\n
      \ \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
      0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
      \"default\",\n \"system_fingerprint\": \"fp_06737a9306\"\n}\n"
    headers:
      CF-RAY:
      - 91b532234c18cf1f-SJC
      Connection:
      - keep-alive
      Content-Encoding:
      - gzip
      Content-Type:
      - application/json
      Date:
      - Tue, 04 Mar 2025 23:36:28 GMT
      Server:
      - cloudflare
      Set-Cookie:
      - __cf_bm=DgLb6UAE6W4Oeto1Bi2RiKXQVV5TTzkXdXWFdmAEwQQ-1741131388-1.0.1.1-jWQtsT95wOeQbmIxAK7cv8gJWxYi1tQ.IupuJzBDnZr7iEChwVUQBRfnYUBJPDsNly3bakCDArjD_S.FLKwH6xUfvlxgfd4YSBhBPy7bcgw;
        path=/; expires=Wed, 05-Mar-25 00:06:28 GMT; domain=.api.openai.com; HttpOnly;
        Secure; SameSite=None
      - _cfuvid=Oa59XCmqjKLKwU34la1hkTunN57JW20E.ZHojvRBfow-1741131388236-0.0.1.1-604800000;
        path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
      Transfer-Encoding:
      - chunked
      X-Content-Type-Options:
      - nosniff
      access-control-expose-headers:
      - X-Request-ID
      alt-svc:
      - h3=":443"; ma=86400
      cf-cache-status:
      - DYNAMIC
      openai-organization:
      - crewai-iuxna1
      openai-processing-ms:
      - '776'
      openai-version:
      - '2020-10-01'
      strict-transport-security:
      - max-age=31536000; includeSubDomains; preload
      x-ratelimit-limit-requests:
      - '30000'
      x-ratelimit-limit-tokens:
      - '150000000'
      x-ratelimit-remaining-requests:
      - '29999'
      x-ratelimit-remaining-tokens:
      - '149999960'
      x-ratelimit-reset-requests:
      - 2ms
      x-ratelimit-reset-tokens:
      - 0s
      x-request-id:
      - req_97824e8fe7c1aca3fbcba7c925388b39
    http_version: HTTP/1.1
    status_code: 200
version: 1
tests/cassettes/test_custom_llm_within_crew.yaml (new file, 305 lines)
@@ -0,0 +1,305 @@
interactions:
- request:
    body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": [{"role": "system", "content": "You are Say Hi.
      You just say hi to the user\nYour personal goal is: Say hi to the user\nTo give
      my best complete final answer to the task respond using the exact following
      format:\n\nThought: I now can give a great answer\nFinal Answer: Your final
      answer must be the great and the most complete as possible, it must be outcome
      described.\n\nI MUST use these formats, my job depends on it!"}, {"role": "user",
      "content": "\nCurrent Task: Say hi to the user\n\nThis is the expected criteria
      for your final answer: A greeting to the user\nyou MUST return the actual complete
      content as the final answer, not a summary.\n\nBegin! This is VERY important
      to you, use the tools available and give your best Final Answer, your job depends
      on it!\n\nThought:"}]}], "model": "gpt-4o-mini", "tools": null}'
    headers:
      accept:
      - application/json
      accept-encoding:
      - gzip, deflate
      connection:
      - keep-alive
      content-length:
      - '931'
      content-type:
      - application/json
      host:
      - api.openai.com
      user-agent:
      - OpenAI/Python 1.61.0
      x-stainless-arch:
      - arm64
      x-stainless-async:
      - 'false'
      x-stainless-lang:
      - python
      x-stainless-os:
      - MacOS
      x-stainless-package-version:
      - 1.61.0
      x-stainless-retry-count:
      - '0'
      x-stainless-runtime:
      - CPython
      x-stainless-runtime-version:
      - 3.12.8
    method: POST
    uri: https://api.openai.com/v1/chat/completions
  response:
    content: "{\n \"error\": {\n \"message\": \"Missing required parameter: 'messages[1].content[0].type'.\",\n
      \ \"type\": \"invalid_request_error\",\n \"param\": \"messages[1].content[0].type\",\n
      \ \"code\": \"missing_required_parameter\"\n }\n}"
    headers:
      CF-RAY:
      - 91b54660799a15b4-SJC
      Connection:
      - keep-alive
      Content-Length:
      - '219'
      Content-Type:
      - application/json
      Date:
      - Tue, 04 Mar 2025 23:50:16 GMT
      Server:
      - cloudflare
      Set-Cookie:
      - __cf_bm=OwS.6cyfDpbxxx8vPp4THv5eNoDMQK0qSVN.wSUyOYk-1741132216-1.0.1.1-QBVd08CjfmDBpNnYQM5ILGbTUWKh6SDM9E4ARG4SV2Z9Q4ltFSFLXoo38OGJApUNZmzn4PtRsyAPsHt_dsrHPF6MD17FPcGtrnAHqCjJrfU;
        path=/; expires=Wed, 05-Mar-25 00:20:16 GMT; domain=.api.openai.com; HttpOnly;
        Secure; SameSite=None
      - _cfuvid=n_ebDsAOhJm5Mc7OMx8JDiOaZq5qzHCnVxyS3KN0BwA-1741132216951-0.0.1.1-604800000;
        path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
      X-Content-Type-Options:
      - nosniff
      access-control-expose-headers:
      - X-Request-ID
      alt-svc:
      - h3=":443"; ma=86400
      cf-cache-status:
      - DYNAMIC
      openai-organization:
      - crewai-iuxna1
      openai-processing-ms:
      - '19'
      openai-version:
      - '2020-10-01'
      strict-transport-security:
      - max-age=31536000; includeSubDomains; preload
      x-ratelimit-limit-requests:
      - '30000'
      x-ratelimit-limit-tokens:
      - '150000000'
      x-ratelimit-remaining-requests:
      - '29999'
      x-ratelimit-remaining-tokens:
      - '149999974'
      x-ratelimit-reset-requests:
      - 2ms
      x-ratelimit-reset-tokens:
      - 0s
      x-request-id:
      - req_042a4e8f9432f6fde7a02037bb6caafa
    http_version: HTTP/1.1
    status_code: 400
- request:
    body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": [{"role": "system", "content": "You are Say Hi.
      You just say hi to the user\nYour personal goal is: Say hi to the user\nTo give
      my best complete final answer to the task respond using the exact following
      format:\n\nThought: I now can give a great answer\nFinal Answer: Your final
      answer must be the great and the most complete as possible, it must be outcome
      described.\n\nI MUST use these formats, my job depends on it!"}, {"role": "user",
      "content": "\nCurrent Task: Say hi to the user\n\nThis is the expected criteria
      for your final answer: A greeting to the user\nyou MUST return the actual complete
      content as the final answer, not a summary.\n\nBegin! This is VERY important
      to you, use the tools available and give your best Final Answer, your job depends
      on it!\n\nThought:"}]}], "model": "gpt-4o-mini", "tools": null}'
    headers:
      accept:
      - application/json
      accept-encoding:
      - gzip, deflate
      connection:
      - keep-alive
      content-length:
      - '931'
      content-type:
      - application/json
      host:
      - api.openai.com
      user-agent:
      - OpenAI/Python 1.61.0
      x-stainless-arch:
      - arm64
      x-stainless-async:
      - 'false'
      x-stainless-lang:
      - python
      x-stainless-os:
      - MacOS
      x-stainless-package-version:
      - 1.61.0
      x-stainless-retry-count:
      - '0'
      x-stainless-runtime:
      - CPython
      x-stainless-runtime-version:
      - 3.12.8
    method: POST
    uri: https://api.openai.com/v1/chat/completions
  response:
    content: "{\n \"error\": {\n \"message\": \"Missing required parameter: 'messages[1].content[0].type'.\",\n
      \ \"type\": \"invalid_request_error\",\n \"param\": \"messages[1].content[0].type\",\n
      \ \"code\": \"missing_required_parameter\"\n }\n}"
    headers:
      CF-RAY:
      - 91b54664bb1acef1-SJC
      Connection:
      - keep-alive
      Content-Length:
      - '219'
      Content-Type:
      - application/json
      Date:
      - Tue, 04 Mar 2025 23:50:17 GMT
      Server:
      - cloudflare
      Set-Cookie:
      - __cf_bm=.wGU4pJEajaSzFWjp05TBQwWbCNA2CgpYNu7UYOzbbM-1741132217-1.0.1.1-NoLiAx4qkplllldYYxZCOSQGsX6hsPUJIEyqmt84B3g7hjW1s7.jk9C9PYzXagHWjT0sQ9Ny4LZBA94lDJTfDBZpty8NJQha7ZKW0P_msH8;
        path=/; expires=Wed, 05-Mar-25 00:20:17 GMT; domain=.api.openai.com; HttpOnly;
        Secure; SameSite=None
      - _cfuvid=GAjgJjVLtN49bMeWdWZDYLLkEkK51z5kxK4nKqhAzxY-1741132217161-0.0.1.1-604800000;
        path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
      X-Content-Type-Options:
      - nosniff
      access-control-expose-headers:
      - X-Request-ID
      alt-svc:
      - h3=":443"; ma=86400
      cf-cache-status:
      - DYNAMIC
      openai-organization:
      - crewai-iuxna1
      openai-processing-ms:
      - '25'
      openai-version:
      - '2020-10-01'
      strict-transport-security:
      - max-age=31536000; includeSubDomains; preload
      x-ratelimit-limit-requests:
      - '30000'
      x-ratelimit-limit-tokens:
      - '150000000'
      x-ratelimit-remaining-requests:
      - '29999'
      x-ratelimit-remaining-tokens:
      - '149999974'
      x-ratelimit-reset-requests:
      - 2ms
      x-ratelimit-reset-tokens:
      - 0s
      x-request-id:
      - req_7a1d027da1ef4468e861e570c72e98fb
    http_version: HTTP/1.1
    status_code: 400
- request:
    body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": [{"role": "system", "content": "You are Say Hi.
      You just say hi to the user\nYour personal goal is: Say hi to the user\nTo give
      my best complete final answer to the task respond using the exact following
      format:\n\nThought: I now can give a great answer\nFinal Answer: Your final
      answer must be the great and the most complete as possible, it must be outcome
      described.\n\nI MUST use these formats, my job depends on it!"}, {"role": "user",
      "content": "\nCurrent Task: Say hi to the user\n\nThis is the expected criteria
      for your final answer: A greeting to the user\nyou MUST return the actual complete
      content as the final answer, not a summary.\n\nBegin! This is VERY important
      to you, use the tools available and give your best Final Answer, your job depends
      on it!\n\nThought:"}]}], "model": "gpt-4o-mini", "tools": null}'
    headers:
      accept:
      - application/json
      accept-encoding:
      - gzip, deflate
      connection:
      - keep-alive
      content-length:
      - '931'
      content-type:
      - application/json
      host:
      - api.openai.com
      user-agent:
      - OpenAI/Python 1.61.0
      x-stainless-arch:
      - arm64
      x-stainless-async:
      - 'false'
      x-stainless-lang:
      - python
      x-stainless-os:
      - MacOS
      x-stainless-package-version:
      - 1.61.0
      x-stainless-retry-count:
      - '0'
      x-stainless-runtime:
      - CPython
      x-stainless-runtime-version:
      - 3.12.8
    method: POST
    uri: https://api.openai.com/v1/chat/completions
  response:
    content: "{\n \"error\": {\n \"message\": \"Missing required parameter: 'messages[1].content[0].type'.\",\n
      \ \"type\": \"invalid_request_error\",\n \"param\": \"messages[1].content[0].type\",\n
      \ \"code\": \"missing_required_parameter\"\n }\n}"
    headers:
      CF-RAY:
      - 91b54666183beb22-SJC
      Connection:
      - keep-alive
      Content-Length:
      - '219'
      Content-Type:
      - application/json
      Date:
      - Tue, 04 Mar 2025 23:50:17 GMT
      Server:
      - cloudflare
      Set-Cookie:
      - __cf_bm=VwjWHHpkZMJlosI9RbMqxYDBS1t0JK4tWpAy4lST2QM-1741132217-1.0.1.1-u7PU.ZvVBTXNB5R8vaYfWdPXAjWZ3ZcTAy656VaGDZmKIckk5od._eQdn0W0EGVtEMm3TuF60z4GZAPDwMYvb3_3cw1RuEMmQbp4IIrl7VY;
        path=/; expires=Wed, 05-Mar-25 00:20:17 GMT; domain=.api.openai.com; HttpOnly;
        Secure; SameSite=None
      - _cfuvid=NglAAsQBoiabMuuHFgilRjflSPFqS38VGKnGyweuCuw-1741132217438-0.0.1.1-604800000;
        path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
      X-Content-Type-Options:
      - nosniff
      access-control-expose-headers:
      - X-Request-ID
      alt-svc:
      - h3=":443"; ma=86400
      cf-cache-status:
      - DYNAMIC
      openai-organization:
      - crewai-iuxna1
      openai-processing-ms:
      - '56'
      openai-version:
      - '2020-10-01'
      strict-transport-security:
      - max-age=31536000; includeSubDomains; preload
      x-ratelimit-limit-requests:
      - '30000'
      x-ratelimit-limit-tokens:
      - '150000000'
      x-ratelimit-remaining-requests:
      - '29999'
      x-ratelimit-remaining-tokens:
      - '149999974'
      x-ratelimit-reset-requests:
      - 2ms
      x-ratelimit-reset-tokens:
      - 0s
      x-request-id:
      - req_3c335b308b82cc2214783a4bf2fc0fd4
    http_version: HTTP/1.1
    status_code: 400
version: 1
tests/custom_llm_test.py (new file, 359 lines)
@@ -0,0 +1,359 @@
from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock

import pytest

from crewai import Agent, Crew, Process, Task
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.llm_utils import create_llm


class CustomLLM(BaseLLM):
    """Custom LLM implementation for testing.

    This is a simple implementation of the BaseLLM abstract base class
    that returns a predefined response for testing purposes.
    """

    def __init__(self, response="Default response", model="test-model"):
        """Initialize the CustomLLM with a predefined response.

        Args:
            response: The predefined response to return from call().
        """
        super().__init__(model=model)
        self.response = response
        self.call_count = 0

    def call(
        self,
        messages,
        tools=None,
        callbacks=None,
        available_functions=None,
    ):
        """
        Mock LLM call that returns a predefined response.
        Properly formats messages to match OpenAI's expected structure.
        """
        self.call_count += 1

        # If input is a string, convert to proper message format
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        # Ensure each message has properly formatted content
        for message in messages:
            if isinstance(message["content"], str):
                message["content"] = [{"type": "text", "text": message["content"]}]

        # Return predefined response in expected format
        if "Thought:" in str(messages):
            return f"Thought: I will say hi\nFinal Answer: {self.response}"
        return self.response

    def supports_function_calling(self) -> bool:
        """Return False to indicate that function calling is not supported.

        Returns:
            False, indicating that this LLM does not support function calling.
        """
        return False

    def supports_stop_words(self) -> bool:
        """Return False to indicate that stop words are not supported.

        Returns:
            False, indicating that this LLM does not support stop words.
        """
        return False

    def get_context_window_size(self) -> int:
        """Return a default context window size.

        Returns:
            4096, a typical context window size for modern LLMs.
        """
        return 4096


@pytest.mark.vcr(filter_headers=["authorization"])
def test_custom_llm_implementation():
    """Test that a custom LLM implementation works with create_llm."""
    custom_llm = CustomLLM(response="The answer is 42")

    # Test that create_llm returns the custom LLM instance directly
    result_llm = create_llm(custom_llm)

    assert result_llm is custom_llm

    # Test calling the custom LLM
    response = result_llm.call(
        "What is the answer to life, the universe, and everything?"
    )

    # Verify that the response from the custom LLM was used
    assert "42" in response


@pytest.mark.vcr(filter_headers=["authorization"])
def test_custom_llm_within_crew():
    """Test that a custom LLM implementation works with create_llm."""
    custom_llm = CustomLLM(response="Hello! Nice to meet you!", model="test-model")

    agent = Agent(
        role="Say Hi",
        goal="Say hi to the user",
        backstory="""You just say hi to the user""",
        llm=custom_llm,
    )

    task = Task(
        description="Say hi to the user",
        expected_output="A greeting to the user",
        agent=agent,
    )

    crew = Crew(
        agents=[agent],
        tasks=[task],
        process=Process.sequential,
    )

    result = crew.kickoff()

    # Assert the LLM was called
    assert custom_llm.call_count > 0
    # Assert we got a response
    assert "Hello!" in result.raw


def test_custom_llm_message_formatting():
    """Test that the custom LLM properly formats messages"""
    custom_llm = CustomLLM(response="Test response", model="test-model")

    # Test with string input
    result = custom_llm.call("Test message")
    assert result == "Test response"

    # Test with message list
    messages = [
        {"role": "system", "content": "System message"},
        {"role": "user", "content": "User message"},
    ]
    result = custom_llm.call(messages)
    assert result == "Test response"


class JWTAuthLLM(BaseLLM):
    """Custom LLM implementation with JWT authentication."""

    def __init__(self, jwt_token: str):
        super().__init__(model="test-model")
        if not jwt_token or not isinstance(jwt_token, str):
            raise ValueError("Invalid JWT token")
        self.jwt_token = jwt_token
        self.calls = []
        self.stop = []

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Record the call and return a predefined response."""
        self.calls.append(
            {
                "messages": messages,
                "tools": tools,
                "callbacks": callbacks,
                "available_functions": available_functions,
            }
        )
        # In a real implementation, this would use the JWT token to authenticate
        # with an external service
        return "Response from JWT-authenticated LLM"

    def supports_function_calling(self) -> bool:
        """Return True to indicate that function calling is supported."""
        return True

    def supports_stop_words(self) -> bool:
        """Return True to indicate that stop words are supported."""
        return True

    def get_context_window_size(self) -> int:
        """Return a default context window size."""
        return 8192


def test_custom_llm_with_jwt_auth():
    """Test a custom LLM implementation with JWT authentication."""
    jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")

    # Test that create_llm returns the JWT-authenticated LLM instance directly
    result_llm = create_llm(jwt_llm)

    assert result_llm is jwt_llm

    # Test calling the JWT-authenticated LLM
    response = result_llm.call("Test message")

    # Verify that the JWT-authenticated LLM was called
    assert len(jwt_llm.calls) > 0
    # Verify that the response from the JWT-authenticated LLM was used
    assert response == "Response from JWT-authenticated LLM"


def test_jwt_auth_llm_validation():
    """Test that JWT token validation works correctly."""
    # Test with invalid JWT token (empty string)
    with pytest.raises(ValueError, match="Invalid JWT token"):
        JWTAuthLLM(jwt_token="")

    # Test with invalid JWT token (non-string)
    with pytest.raises(ValueError, match="Invalid JWT token"):
        JWTAuthLLM(jwt_token=None)


class TimeoutHandlingLLM(BaseLLM):
    """Custom LLM implementation with timeout handling and retry logic."""

    def __init__(self, max_retries: int = 3, timeout: int = 30):
        """Initialize the TimeoutHandlingLLM with retry and timeout settings.

        Args:
            max_retries: Maximum number of retry attempts.
            timeout: Timeout in seconds for each API call.
        """
        super().__init__(model="test-model")
        self.max_retries = max_retries
        self.timeout = timeout
        self.calls = []
        self.stop = []
        self.fail_count = 0  # Number of times to simulate failure

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Simulate API calls with timeout handling and retry logic.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            A response string based on whether this is the first attempt or a retry.

        Raises:
            TimeoutError: If all retry attempts fail.
        """
        # Record the initial call
        self.calls.append(
            {
                "messages": messages,
                "tools": tools,
                "callbacks": callbacks,
                "available_functions": available_functions,
                "attempt": 0,
            }
        )

        # Simulate retry logic
        for attempt in range(self.max_retries):
            # Skip the first attempt recording since we already did that above
            if attempt == 0:
                # Simulate a failure if fail_count > 0
                if self.fail_count > 0:
                    self.fail_count -= 1
                    # If we've used all retries, raise an error
                    if attempt == self.max_retries - 1:
                        raise TimeoutError(
                            f"LLM request failed after {self.max_retries} attempts"
                        )
                    # Otherwise, continue to the next attempt (simulating backoff)
                    continue
                else:
                    # Success on first attempt
                    return "First attempt response"
            else:
                # This is a retry attempt (attempt > 0)
                # Always record retry attempts
                self.calls.append(
                    {
                        "retry_attempt": attempt,
                        "messages": messages,
                        "tools": tools,
                        "callbacks": callbacks,
                        "available_functions": available_functions,
                    }
                )

                # Simulate a failure if fail_count > 0
                if self.fail_count > 0:
                    self.fail_count -= 1
                    # If we've used all retries, raise an error
                    if attempt == self.max_retries - 1:
                        raise TimeoutError(
                            f"LLM request failed after {self.max_retries} attempts"
                        )
                    # Otherwise, continue to the next attempt (simulating backoff)
                    continue
                else:
                    # Success on retry
                    return "Response after retry"

    def supports_function_calling(self) -> bool:
        """Return True to indicate that function calling is supported.

        Returns:
            True, indicating that this LLM supports function calling.
        """
        return True

    def supports_stop_words(self) -> bool:
        """Return True to indicate that stop words are supported.

        Returns:
            True, indicating that this LLM supports stop words.
        """
        return True

    def get_context_window_size(self) -> int:
        """Return a default context window size.

        Returns:
            8192, a typical context window size for modern LLMs.
        """
        return 8192


def test_timeout_handling_llm():
    """Test a custom LLM implementation with timeout handling and retry logic."""
    # Test successful first attempt
    llm = TimeoutHandlingLLM()
    response = llm.call("Test message")
    assert response == "First attempt response"
    assert len(llm.calls) == 1

    # Test successful retry
    llm = TimeoutHandlingLLM()
    llm.fail_count = 1  # Fail once, then succeed
    response = llm.call("Test message")
    assert response == "Response after retry"
    assert len(llm.calls) == 2  # Initial call + successful retry call

    # Test failure after all retries
    llm = TimeoutHandlingLLM(max_retries=2)
    llm.fail_count = 2  # Fail twice, which is all retries
    with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
        llm.call("Test message")
    assert len(llm.calls) == 2  # Initial call + failed retry attempt
uv.lock (generated, 16 additions)
@@ -139,6 +139,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/76/ac/a7305707cb852b7e16ff80eaf5692309bde30e2b1100a1fcacdc8f731d97/aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17", size = 7617 },
 ]
 
+[[package]]
+name = "aisuite"
+version = "0.1.10"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "httpx" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/6a/9d/c7a8a76abb9011dd2bc9a5cb8ffa8231640e20bbdae177ce9ab6cb67c66c/aisuite-0.1.10.tar.gz", hash = "sha256:170e62d4c91fecb22e82a04e058154a111cef473681171e5df7346272e77f414", size = 29052 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/58/c2/9a34a01516de107e5f9406dbfd319b6004340708101d67fa107373da4058/aisuite-0.1.10-py3-none-any.whl", hash = "sha256:c8510ebe38d6546b6a06819171e201fcaf0bf9ae020ffcfe19b6bd90430781ad", size = 43984 },
+]
+
 [[package]]
 name = "alembic"
 version = "1.13.3"
@@ -651,6 +663,9 @@ dependencies = [
 agentops = [
     { name = "agentops" },
 ]
+aisuite = [
+    { name = "aisuite" },
+]
 docling = [
     { name = "docling" },
 ]
@@ -698,6 +713,7 @@ dev = [
 [package.metadata]
 requires-dist = [
     { name = "agentops", marker = "extra == 'agentops'", specifier = ">=0.3.0" },
+    { name = "aisuite", marker = "extra == 'aisuite'", specifier = ">=0.1.10" },
     { name = "appdirs", specifier = ">=1.4.4" },
     { name = "auth0-python", specifier = ">=4.7.1" },
     { name = "blinker", specifier = ">=1.9.0" },