Enhance BaseLLM documentation and add model parameter validation

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-05-19 14:53:18 +00:00
parent 8c4f6e3db9
commit 2266980274
2 changed files with 58 additions and 8 deletions

View File

@@ -18,6 +18,15 @@ To create a custom LLM implementation, you need to:
- `get_context_window_size()`: The context window size of the LLM - `get_context_window_size()`: The context window size of the LLM
3. Ensure you pass a model identifier string to the BaseLLM constructor using `super().__init__(model="your-model-name")` 3. Ensure you pass a model identifier string to the BaseLLM constructor using `super().__init__(model="your-model-name")`
## Required Parameters
When creating custom LLM implementations, the following parameters are essential:
- `model`: String identifier for your model implementation.
- Required in the BaseLLM constructor
- Example values: `"gpt-4-custom"`, `"anthropic-claude-custom"`, `"custom-llm-v1.0"`
- Used to identify the model in logs, metrics, and other components
## Example: Basic Custom LLM ## Example: Basic Custom LLM
```python ```python
@@ -25,8 +34,14 @@ from crewai import BaseLLM
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
class CustomLLM(BaseLLM): class CustomLLM(BaseLLM):
"""A custom LLM implementation with basic API key authentication.
Args:
api_key (str): API key for the LLM service.
endpoint (str): Endpoint URL for the LLM service.
"""
def __init__(self, api_key: str, endpoint: str): def __init__(self, api_key: str, endpoint: str):
super().__init__(model="custom-model") # Initialize with required model parameter super().__init__(model="custom-llm-v1.0") # Initialize with required model parameter
if not api_key or not isinstance(api_key, str): if not api_key or not isinstance(api_key, str):
raise ValueError("Invalid API key: must be a non-empty string") raise ValueError("Invalid API key: must be a non-empty string")
if not endpoint or not isinstance(endpoint, str): if not endpoint or not isinstance(endpoint, str):
@@ -196,7 +211,7 @@ Always validate input parameters to prevent runtime errors:
```python ```python
def __init__(self, api_key: str, endpoint: str): def __init__(self, api_key: str, endpoint: str):
super().__init__(model="custom-model") # Initialize with required model parameter super().__init__(model="custom-api-llm-v1.0") # Initialize with required model parameter
if not api_key or not isinstance(api_key, str): if not api_key or not isinstance(api_key, str):
raise ValueError("Invalid API key: must be a non-empty string") raise ValueError("Invalid API key: must be a non-empty string")
if not endpoint or not isinstance(endpoint, str): if not endpoint or not isinstance(endpoint, str):
@@ -239,8 +254,14 @@ from crewai import BaseLLM, Agent, Task
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
class JWTAuthLLM(BaseLLM): class JWTAuthLLM(BaseLLM):
"""A custom LLM implementation with JWT-based authentication.
Args:
jwt_token (str): JWT token for authentication.
endpoint (str): Endpoint URL for the LLM service.
"""
def __init__(self, jwt_token: str, endpoint: str): def __init__(self, jwt_token: str, endpoint: str):
super().__init__(model="custom-jwt-model") # Initialize with required model parameter super().__init__(model="jwt-auth-llm-v1.0") # Initialize with required model parameter
if not jwt_token or not isinstance(jwt_token, str): if not jwt_token or not isinstance(jwt_token, str):
raise ValueError("Invalid JWT token: must be a non-empty string") raise ValueError("Invalid JWT token: must be a non-empty string")
if not endpoint or not isinstance(endpoint, str): if not endpoint or not isinstance(endpoint, str):
@@ -387,8 +408,14 @@ import logging
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
class LoggingLLM(BaseLLM): class LoggingLLM(BaseLLM):
"""A custom LLM implementation with logging capabilities.
Args:
api_key (str): API key for the LLM service.
endpoint (str): Endpoint URL for the LLM service.
"""
def __init__(self, api_key: str, endpoint: str): def __init__(self, api_key: str, endpoint: str):
super().__init__(model="custom-logging-model") # Initialize with required model parameter super().__init__(model="logging-llm-v1.0") # Initialize with required model parameter
self.api_key = api_key self.api_key = api_key
self.endpoint = endpoint self.endpoint = endpoint
self.logger = logging.getLogger("crewai.llm.custom") self.logger = logging.getLogger("crewai.llm.custom")
@@ -420,13 +447,21 @@ import time
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
class RateLimitedLLM(BaseLLM): class RateLimitedLLM(BaseLLM):
"""A custom LLM implementation with rate limiting capabilities.
Args:
api_key (str): API key for the LLM service.
endpoint (str): Endpoint URL for the LLM service.
requests_per_minute (int, optional): Maximum number of requests allowed per minute.
Defaults to 60.
"""
def __init__( def __init__(
self, self,
api_key: str, api_key: str,
endpoint: str, endpoint: str,
requests_per_minute: int = 60 requests_per_minute: int = 60
): ):
super().__init__(model="custom-rate-limited-model") # Initialize with required model parameter super().__init__(model="rate-limited-llm-v1.0") # Initialize with required model parameter
self.api_key = api_key self.api_key = api_key
self.endpoint = endpoint self.endpoint = endpoint
self.requests_per_minute = requests_per_minute self.requests_per_minute = requests_per_minute
@@ -468,8 +503,14 @@ import time
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
class MetricsCollectingLLM(BaseLLM): class MetricsCollectingLLM(BaseLLM):
"""A custom LLM implementation with metrics collection capabilities.
Args:
api_key (str): API key for the LLM service.
endpoint (str): Endpoint URL for the LLM service.
"""
def __init__(self, api_key: str, endpoint: str): def __init__(self, api_key: str, endpoint: str):
super().__init__(model="custom-metrics-model") # Initialize with required model parameter super().__init__(model="metrics-llm-v1.0") # Initialize with required model parameter
self.api_key = api_key self.api_key = api_key
self.endpoint = endpoint self.endpoint = endpoint
self.metrics: Dict[str, Any] = { self.metrics: Dict[str, Any] = {

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
class BaseLLM(ABC): class BaseLLM(ABC):
@@ -27,7 +27,7 @@ class BaseLLM(ABC):
self, self,
model: str, model: str,
temperature: Optional[float] = None, temperature: Optional[float] = None,
): ) -> None:
"""Initialize the BaseLLM with default attributes. """Initialize the BaseLLM with default attributes.
This constructor sets default values for attributes that are expected This constructor sets default values for attributes that are expected
@@ -36,7 +36,16 @@ class BaseLLM(ABC):
All custom LLM implementations should call super().__init__(model="model_name"), All custom LLM implementations should call super().__init__(model="model_name"),
where "model_name" is a string identifier for your model. This parameter where "model_name" is a string identifier for your model. This parameter
is required and cannot be omitted. is required and cannot be omitted.
Args:
model (str): Required. A string identifier for the model.
temperature (Optional[float]): The sampling temperature to use.
Raises:
ValueError: If the model parameter is not provided or empty.
""" """
if not model:
raise ValueError("model parameter is required and must be a non-empty string")
self.model = model self.model = model
self.temperature = temperature self.temperature = temperature
self.stop = [] self.stop = []