Simplify LLM implementation by consolidating LLM and BaseLLM classes

Co-Authored-By: Joe Moura <joao@crewai.com>
Devin AI
2025-03-14 06:35:42 +00:00
parent 963ed23b63
commit 1b8c07760e
5 changed files with 218 additions and 36 deletions


@@ -1,25 +1,44 @@
# Custom LLM Implementations
CrewAI now supports custom LLM implementations through the `BaseLLM` abstract base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.
CrewAI supports custom LLM implementations through the `LLM` base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.
## Using Custom LLM Implementations
To create a custom LLM implementation, you need to:
1. Inherit from the `BaseLLM` abstract base class
1. Inherit from the `LLM` base class
2. Implement the required methods, as in the sketch after this list:
- `call()`: The main method to call the LLM with messages
- `supports_function_calling()`: Whether the LLM supports function calling
- `supports_stop_words()`: Whether the LLM supports stop words
- `get_context_window_size()`: The context window size of the LLM
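A minimal sketch that satisfies this interface is shown below. The fixed reply, the extra `call()` parameters, and the window size are illustrative assumptions, not values prescribed by CrewAI:
```python
from typing import Any, Dict, List, Optional, Union

from crewai import LLM

class EchoLLM(LLM):
    """Toy LLM that echoes the last user message back."""

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> str:
        # A real implementation would send `messages` to a model API here.
        if isinstance(messages, str):
            return f"echo: {messages}"
        return f"echo: {messages[-1]['content']}"

    def supports_function_calling(self) -> bool:
        return False  # this toy model cannot call tools

    def supports_stop_words(self) -> bool:
        return False  # stop sequences are ignored

    def get_context_window_size(self) -> int:
        return 8192  # assumed window size for the toy model
```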
## Using the Default LLM Implementation
If you don't need a custom LLM implementation, you can use the default implementation provided by CrewAI:
```python
from crewai import LLM

# Create a default LLM instance
llm = LLM.create(model="gpt-4")

# Or with more parameters
llm = LLM.create(
    model="gpt-4",
    temperature=0.7,
    max_tokens=1000,
    api_key="your-api-key"
)
```
## Example: Basic Custom LLM
```python
from crewai import BaseLLM
from crewai import LLM
from typing import Any, Dict, List, Optional, Union
class CustomLLM(BaseLLM):
class CustomLLM(LLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not api_key or not isinstance(api_key, str):
@@ -230,10 +249,10 @@ def call(
For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:
```python
from crewai import BaseLLM, Agent, Task
from crewai import LLM, Agent, Task
from typing import Any, Dict, List, Optional, Union
class JWTAuthLLM(BaseLLM):
class JWTAuthLLM(LLM):
    def __init__(self, jwt_token: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not jwt_token or not isinstance(jwt_token, str):
@@ -631,7 +650,7 @@ print(result)
## Implementing Your Own Authentication Mechanism
The `BaseLLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:
The `LLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:
- OAuth tokens
- Client certificates
@@ -640,3 +659,23 @@ The `BaseLLM` class allows you to implement any authentication mechanism you nee
- Any other authentication method required by your LLM provider
Simply implement the appropriate authentication logic in your custom LLM class.
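As a further sketch, an OAuth-based variant differs from the JWT example above mainly in how the header is built; the `requests` call, endpoint, and response shape here are illustrative assumptions:
```python
import requests

from crewai import LLM

class OAuthLLM(LLM):
    def __init__(self, access_token: str, endpoint: str):
        super().__init__()
        self.access_token = access_token
        self.endpoint = endpoint

    def call(self, messages, tools=None, callbacks=None, available_functions=None):
        # Swap this header for client certificates, custom schemes, etc. as needed.
        response = requests.post(
            self.endpoint,
            headers={"Authorization": f"Bearer {self.access_token}"},
            json={"messages": messages},
            timeout=30,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def supports_function_calling(self) -> bool:
        return False

    def supports_stop_words(self) -> bool:
        return False

    def get_context_window_size(self) -> int:
        return 8192  # assumed; report your provider's real limit
```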
## Migrating from BaseLLM to LLM
If you were previously using `BaseLLM`, you can simply replace it with `LLM`:
```python
# Old code
from crewai import BaseLLM

class CustomLLM(BaseLLM):
    # ...

# New code
from crewai import LLM

class CustomLLM(LLM):
    # ...
```
The `BaseLLM` class is still available for backward compatibility but will be removed in a future release. It now inherits from `LLM` and emits a deprecation warning when instantiated.
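To verify that nothing in a codebase still constructs `BaseLLM` directly, the warning can be escalated to an error in a test; this is plain `warnings` machinery, with only the import path taken from this document:
```python
import warnings

from crewai import BaseLLM  # deprecated alias kept for backward compatibility

with warnings.catch_warnings():
    warnings.simplefilter("error", DeprecationWarning)
    BaseLLM()  # now raises DeprecationWarning if the shim is still being used
```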


@@ -4,7 +4,7 @@ from crewai.agent import Agent
from crewai.crew import Crew
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM, BaseLLM
from crewai.llm import LLM, BaseLLM, DefaultLLM
from crewai.process import Process
from crewai.task import Task
@@ -22,6 +22,7 @@ __all__ = [
"Task",
"LLM",
"BaseLLM",
"DefaultLLM",
"Flow",
"Knowledge",
]


@@ -35,8 +35,8 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
load_dotenv()
class BaseLLM(ABC):
"""Abstract base class for LLM implementations.
class LLM(ABC):
"""Base class for LLM implementations.
This class defines the interface that all LLM implementations must follow.
Users can extend this class to create custom LLM implementations that don't
@@ -52,8 +52,26 @@ class BaseLLM(ABC):
This is used by the CrewAgentExecutor and other components.
"""
def __new__(cls, *args, **kwargs):
"""Create a new LLM instance.
This method handles backward compatibility by creating a DefaultLLM instance
when the LLM class is instantiated directly with parameters.
Args:
*args: Positional arguments.
**kwargs: Keyword arguments.
Returns:
Either a new LLM instance or a DefaultLLM instance for backward compatibility.
"""
if cls is LLM and (args or kwargs.get('model') is not None):
# Imported lazily: DefaultLLM is defined later in this same module.
from crewai.llm import DefaultLLM
return DefaultLLM(*args, **kwargs)
return super().__new__(cls)
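# Dispatch behavior in practice (illustrative):
#   LLM(model="gpt-4")  -> DefaultLLM instance (backward-compatible path)
#   class MyLLM(LLM): ...; MyLLM()  -> plain MyLLM instance, no redirection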
def __init__(self):
"""Initialize the BaseLLM with default attributes.
"""Initialize the LLM with default attributes.
This constructor sets default values for attributes that are expected
by the CrewAgentExecutor and other components.
@@ -63,7 +81,91 @@ class BaseLLM(ABC):
"""
self.stop = []
@abstractmethod
@classmethod
def create(
cls,
model: str,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
**kwargs,
) -> 'DefaultLLM':
"""Create a default LLM instance using litellm.
This factory method creates a default LLM instance using litellm as the backend.
It's the recommended way to create LLM instances for most users.
Args:
model: The model name (e.g., "gpt-4").
timeout: Optional timeout for the LLM call.
temperature: Optional temperature for the LLM call.
top_p: Optional top_p for the LLM call.
n: Optional n for the LLM call.
stop: Optional stop sequences for the LLM call.
max_completion_tokens: Optional max_completion_tokens for the LLM call.
max_tokens: Optional max_tokens for the LLM call.
presence_penalty: Optional presence_penalty for the LLM call.
frequency_penalty: Optional frequency_penalty for the LLM call.
logit_bias: Optional logit_bias for the LLM call.
response_format: Optional response_format for the LLM call.
seed: Optional seed for the LLM call.
logprobs: Optional logprobs for the LLM call.
top_logprobs: Optional top_logprobs for the LLM call.
base_url: Optional base_url for the LLM call.
api_base: Optional api_base for the LLM call.
api_version: Optional api_version for the LLM call.
api_key: Optional api_key for the LLM call.
callbacks: Optional callbacks for the LLM call.
reasoning_effort: Optional reasoning_effort for the LLM call.
**kwargs: Additional keyword arguments for the LLM call.
Returns:
A DefaultLLM instance configured with the provided parameters.
"""
from crewai.llm import DefaultLLM
return DefaultLLM(
model=model,
timeout=timeout,
temperature=temperature,
top_p=top_p,
n=n,
stop=stop,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
response_format=response_format,
seed=seed,
logprobs=logprobs,
top_logprobs=top_logprobs,
base_url=base_url,
api_base=api_base,
api_version=api_version,
api_key=api_key,
callbacks=callbacks,
reasoning_effort=reasoning_effort,
**kwargs,
)
def call(
self,
messages: Union[str, List[Dict[str, str]]],
@@ -93,10 +195,10 @@ class BaseLLM(ABC):
ValueError: If the messages format is invalid.
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
NotImplementedError: If this method is not implemented by a subclass.
"""
pass
raise NotImplementedError("Subclasses must implement call()")
@abstractmethod
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
@@ -107,10 +209,12 @@ class BaseLLM(ABC):
Returns:
True if the LLM supports function calling, False otherwise.
Raises:
NotImplementedError: If this method is not implemented by a subclass.
"""
pass
raise NotImplementedError("Subclasses must implement supports_function_calling()")
@abstractmethod
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
@@ -120,10 +224,12 @@ class BaseLLM(ABC):
Returns:
True if the LLM supports stop words, False otherwise.
Raises:
NotImplementedError: If this method is not implemented by a subclass.
"""
pass
raise NotImplementedError("Subclasses must implement supports_stop_words()")
@abstractmethod
def get_context_window_size(self) -> int:
"""Get the context window size of the LLM.
@@ -133,8 +239,11 @@ class BaseLLM(ABC):
Returns:
The context window size as an integer.
Raises:
NotImplementedError: If this method is not implemented by a subclass.
"""
pass
raise NotImplementedError("Subclasses must implement get_context_window_size()")
class FilteredStream:
@@ -229,7 +338,14 @@ def suppress_warnings():
sys.stderr = old_stderr
class LLM(BaseLLM):
class DefaultLLM(LLM):
"""Default LLM implementation using litellm.
This class provides a concrete implementation of the LLM interface
using litellm as the backend. It's the default implementation used
by CrewAI when no custom LLM is provided.
"""
def __init__(
self,
model: str,
@@ -255,6 +371,8 @@ class LLM(BaseLLM):
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
**kwargs,
):
super().__init__() # Initialize the base class
self.model = model
self.timeout = timeout
self.temperature = temperature
@@ -283,7 +401,7 @@ class LLM(BaseLLM):
# Normalize self.stop to always be a List[str]
if stop is None:
self.stop: List[str] = []
self.stop = [] # Already initialized in base class
elif isinstance(stop, str):
self.stop = [stop]
else:
@@ -667,3 +785,27 @@ class LLM(BaseLLM):
litellm.success_callback = success_callbacks
litellm.failure_callback = failure_callbacks
class BaseLLM(LLM):
"""Deprecated: Use LLM instead.
This class is kept for backward compatibility and will be removed in a future release.
It inherits from LLM and provides the same interface, but emits a deprecation warning
when instantiated.
"""
def __init__(self):
"""Initialize the BaseLLM with a deprecation warning.
This constructor emits a deprecation warning and then calls the parent class's
constructor to initialize the LLM.
"""
import warnings
warnings.warn(
"BaseLLM is deprecated and will be removed in a future release. "
"Use LLM instead for custom implementations.",
DeprecationWarning,
stacklevel=2
)
super().__init__()


@@ -6,30 +6,30 @@ from crewai.llm import LLM, BaseLLM
def create_llm(
llm_value: Union[str, BaseLLM, Any, None] = None,
) -> Optional[BaseLLM]:
llm_value: Union[str, LLM, Any, None] = None,
) -> Optional[LLM]:
"""
Creates or returns an LLM instance based on the given llm_value.
Args:
llm_value (str | BaseLLM | Any | None):
llm_value (str | LLM | Any | None):
- str: The model name (e.g., "gpt-4").
- BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
- LLM: Already instantiated LLM, returned as-is.
- Any: Attempt to extract known attributes like model_name, temperature, etc.
- None: Use environment-based or fallback default model.
Returns:
A BaseLLM instance if successful, or None if something fails.
An LLM instance if successful, or None if something fails.
"""
# 1) If llm_value is already a BaseLLM object, return it directly
if isinstance(llm_value, BaseLLM):
# 1) If llm_value is already an LLM object, return it directly
if isinstance(llm_value, LLM):
return llm_value
# 2) If llm_value is a string (model name)
if isinstance(llm_value, str):
try:
created_llm = LLM(model=llm_value)
created_llm = LLM.create(model=llm_value)
return created_llm
except Exception as e:
print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
@@ -56,7 +56,7 @@ def create_llm(
base_url: Optional[str] = getattr(llm_value, "base_url", None)
api_base: Optional[str] = getattr(llm_value, "api_base", None)
created_llm = LLM(
created_llm = LLM.create(
model=model,
temperature=temperature,
max_tokens=max_tokens,
@@ -175,7 +175,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
# Try creating the LLM
try:
new_llm = LLM(**llm_params)
new_llm = LLM.create(**llm_params)
return new_llm
except Exception as e:
print(


@@ -2,14 +2,14 @@ from typing import Any, Dict, List, Optional, Union
import pytest
from crewai.llm import BaseLLM
from crewai.llm import LLM
from crewai.utilities.llm_utils import create_llm
class CustomLLM(BaseLLM):
class CustomLLM(LLM):
"""Custom LLM implementation for testing.
This is a simple implementation of the BaseLLM abstract base class
This is a simple implementation of the LLM abstract base class
that returns a predefined response for testing purposes.
"""
@@ -93,7 +93,7 @@ def test_custom_llm_implementation():
assert response == "The answer is 42"
class JWTAuthLLM(BaseLLM):
class JWTAuthLLM(LLM):
"""Custom LLM implementation with JWT authentication."""
def __init__(self, jwt_token: str):
@@ -164,7 +164,7 @@ def test_jwt_auth_llm_validation():
JWTAuthLLM(jwt_token=None)
class TimeoutHandlingLLM(BaseLLM):
class TimeoutHandlingLLM(LLM):
"""Custom LLM implementation with timeout handling and retry logic."""
def __init__(self, max_retries: int = 3, timeout: int = 30):