mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
Enhance BaseLLM documentation and add model parameter validation
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -18,6 +18,15 @@ To create a custom LLM implementation, you need to:
|
|||||||
- `get_context_window_size()`: The context window size of the LLM
|
- `get_context_window_size()`: The context window size of the LLM
|
||||||
3. Ensure you pass a model identifier string to the BaseLLM constructor using `super().__init__(model="your-model-name")`
|
3. Ensure you pass a model identifier string to the BaseLLM constructor using `super().__init__(model="your-model-name")`
|
||||||
|
|
||||||
|
## Required Parameters
|
||||||
|
|
||||||
|
When creating custom LLM implementations, the following parameters are essential:
|
||||||
|
|
||||||
|
- `model`: String identifier for your model implementation.
|
||||||
|
- Required in the BaseLLM constructor
|
||||||
|
- Example values: `"gpt-4-custom"`, `"anthropic-claude-custom"`, `"custom-llm-v1.0"`
|
||||||
|
- Used to identify the model in logs, metrics, and other components
|
||||||
|
|
||||||
## Example: Basic Custom LLM
|
## Example: Basic Custom LLM
|
||||||
|
|
||||||
```python
|
```python
|
||||||
@@ -25,8 +34,14 @@ from crewai import BaseLLM
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
class CustomLLM(BaseLLM):
|
class CustomLLM(BaseLLM):
|
||||||
|
"""A custom LLM implementation with basic API key authentication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key (str): API key for the LLM service.
|
||||||
|
endpoint (str): Endpoint URL for the LLM service.
|
||||||
|
"""
|
||||||
def __init__(self, api_key: str, endpoint: str):
|
def __init__(self, api_key: str, endpoint: str):
|
||||||
super().__init__(model="custom-model") # Initialize with required model parameter
|
super().__init__(model="custom-llm-v1.0") # Initialize with required model parameter
|
||||||
if not api_key or not isinstance(api_key, str):
|
if not api_key or not isinstance(api_key, str):
|
||||||
raise ValueError("Invalid API key: must be a non-empty string")
|
raise ValueError("Invalid API key: must be a non-empty string")
|
||||||
if not endpoint or not isinstance(endpoint, str):
|
if not endpoint or not isinstance(endpoint, str):
|
||||||
@@ -196,7 +211,7 @@ Always validate input parameters to prevent runtime errors:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
def __init__(self, api_key: str, endpoint: str):
|
def __init__(self, api_key: str, endpoint: str):
|
||||||
super().__init__(model="custom-model") # Initialize with required model parameter
|
super().__init__(model="custom-api-llm-v1.0") # Initialize with required model parameter
|
||||||
if not api_key or not isinstance(api_key, str):
|
if not api_key or not isinstance(api_key, str):
|
||||||
raise ValueError("Invalid API key: must be a non-empty string")
|
raise ValueError("Invalid API key: must be a non-empty string")
|
||||||
if not endpoint or not isinstance(endpoint, str):
|
if not endpoint or not isinstance(endpoint, str):
|
||||||
@@ -239,8 +254,14 @@ from crewai import BaseLLM, Agent, Task
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
class JWTAuthLLM(BaseLLM):
|
class JWTAuthLLM(BaseLLM):
|
||||||
|
"""A custom LLM implementation with JWT-based authentication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
jwt_token (str): JWT token for authentication.
|
||||||
|
endpoint (str): Endpoint URL for the LLM service.
|
||||||
|
"""
|
||||||
def __init__(self, jwt_token: str, endpoint: str):
|
def __init__(self, jwt_token: str, endpoint: str):
|
||||||
super().__init__(model="custom-jwt-model") # Initialize with required model parameter
|
super().__init__(model="jwt-auth-llm-v1.0") # Initialize with required model parameter
|
||||||
if not jwt_token or not isinstance(jwt_token, str):
|
if not jwt_token or not isinstance(jwt_token, str):
|
||||||
raise ValueError("Invalid JWT token: must be a non-empty string")
|
raise ValueError("Invalid JWT token: must be a non-empty string")
|
||||||
if not endpoint or not isinstance(endpoint, str):
|
if not endpoint or not isinstance(endpoint, str):
|
||||||
@@ -387,8 +408,14 @@ import logging
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
class LoggingLLM(BaseLLM):
|
class LoggingLLM(BaseLLM):
|
||||||
|
"""A custom LLM implementation with logging capabilities.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key (str): API key for the LLM service.
|
||||||
|
endpoint (str): Endpoint URL for the LLM service.
|
||||||
|
"""
|
||||||
def __init__(self, api_key: str, endpoint: str):
|
def __init__(self, api_key: str, endpoint: str):
|
||||||
super().__init__(model="custom-logging-model") # Initialize with required model parameter
|
super().__init__(model="logging-llm-v1.0") # Initialize with required model parameter
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.logger = logging.getLogger("crewai.llm.custom")
|
self.logger = logging.getLogger("crewai.llm.custom")
|
||||||
@@ -420,13 +447,21 @@ import time
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
class RateLimitedLLM(BaseLLM):
|
class RateLimitedLLM(BaseLLM):
|
||||||
|
"""A custom LLM implementation with rate limiting capabilities.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key (str): API key for the LLM service.
|
||||||
|
endpoint (str): Endpoint URL for the LLM service.
|
||||||
|
requests_per_minute (int, optional): Maximum number of requests allowed per minute.
|
||||||
|
Defaults to 60.
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
requests_per_minute: int = 60
|
requests_per_minute: int = 60
|
||||||
):
|
):
|
||||||
super().__init__(model="custom-rate-limited-model") # Initialize with required model parameter
|
super().__init__(model="rate-limited-llm-v1.0") # Initialize with required model parameter
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.requests_per_minute = requests_per_minute
|
self.requests_per_minute = requests_per_minute
|
||||||
@@ -468,8 +503,14 @@ import time
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
class MetricsCollectingLLM(BaseLLM):
|
class MetricsCollectingLLM(BaseLLM):
|
||||||
|
"""A custom LLM implementation with metrics collection capabilities.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key (str): API key for the LLM service.
|
||||||
|
endpoint (str): Endpoint URL for the LLM service.
|
||||||
|
"""
|
||||||
def __init__(self, api_key: str, endpoint: str):
|
def __init__(self, api_key: str, endpoint: str):
|
||||||
super().__init__(model="custom-metrics-model") # Initialize with required model parameter
|
super().__init__(model="metrics-llm-v1.0") # Initialize with required model parameter
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.metrics: Dict[str, Any] = {
|
self.metrics: Dict[str, Any] = {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
class BaseLLM(ABC):
|
class BaseLLM(ABC):
|
||||||
@@ -27,7 +27,7 @@ class BaseLLM(ABC):
|
|||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
):
|
) -> None:
|
||||||
"""Initialize the BaseLLM with default attributes.
|
"""Initialize the BaseLLM with default attributes.
|
||||||
|
|
||||||
This constructor sets default values for attributes that are expected
|
This constructor sets default values for attributes that are expected
|
||||||
@@ -36,7 +36,16 @@ class BaseLLM(ABC):
|
|||||||
All custom LLM implementations should call super().__init__(model="model_name"),
|
All custom LLM implementations should call super().__init__(model="model_name"),
|
||||||
where "model_name" is a string identifier for your model. This parameter
|
where "model_name" is a string identifier for your model. This parameter
|
||||||
is required and cannot be omitted.
|
is required and cannot be omitted.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (str): Required. A string identifier for the model.
|
||||||
|
temperature (Optional[float]): The sampling temperature to use.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the model parameter is not provided or empty.
|
||||||
"""
|
"""
|
||||||
|
if not model:
|
||||||
|
raise ValueError("model parameter is required and must be a non-empty string")
|
||||||
self.model = model
|
self.model = model
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.stop = []
|
self.stop = []
|
||||||
|
|||||||
Reference in New Issue
Block a user