Simplify LLM implementation by consolidating LLM and BaseLLM classes
Co-Authored-By: Joe Moura <joao@crewai.com>
@@ -1,25 +1,44 @@
 # Custom LLM Implementations
 
-CrewAI now supports custom LLM implementations through the `BaseLLM` abstract base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.
+CrewAI supports custom LLM implementations through the `LLM` base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.
 
 ## Using Custom LLM Implementations
 
 To create a custom LLM implementation, you need to:
 
-1. Inherit from the `BaseLLM` abstract base class
+1. Inherit from the `LLM` base class
 2. Implement the required methods:
    - `call()`: The main method to call the LLM with messages
    - `supports_function_calling()`: Whether the LLM supports function calling
    - `supports_stop_words()`: Whether the LLM supports stop words
    - `get_context_window_size()`: The context window size of the LLM
+
+## Using the Default LLM Implementation
+
+If you don't need a custom LLM implementation, you can use the default implementation provided by CrewAI:
+
+```python
+from crewai import LLM
+
+# Create a default LLM instance
+llm = LLM.create(model="gpt-4")
+
+# Or with more parameters
+llm = LLM.create(
+    model="gpt-4",
+    temperature=0.7,
+    max_tokens=1000,
+    api_key="your-api-key"
+)
+```
 
 ## Example: Basic Custom LLM
 
 ```python
-from crewai import BaseLLM
+from crewai import LLM
 from typing import Any, Dict, List, Optional, Union
 
-class CustomLLM(BaseLLM):
+class CustomLLM(LLM):
     def __init__(self, api_key: str, endpoint: str):
         super().__init__()  # Initialize the base class to set default attributes
         if not api_key or not isinstance(api_key, str):
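The hunk cuts off mid-example at the diff boundary. For orientation, a minimal complete subclass of the new `LLM` base class might look like the sketch below; the endpoint shape, response parsing, and `call()` keyword handling are illustrative assumptions, not the repository's actual example code.

```python
# Hypothetical complete subclass; endpoint shape and response parsing
# are assumptions for illustration only.
from typing import Any, Dict, List, Union

import requests

from crewai import LLM


class SketchCustomLLM(LLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()  # sets default attributes such as self.stop
        if not api_key or not isinstance(api_key, str):
            raise ValueError("api_key must be a non-empty string")
        self.api_key = api_key
        self.endpoint = endpoint

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        **kwargs: Any,
    ) -> str:
        # Accept both the bare-string and chat-message forms of `messages`.
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        response = requests.post(
            self.endpoint,
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={"messages": messages},
            timeout=30,
        )
        response.raise_for_status()
        # Assumes an OpenAI-style response payload.
        return response.json()["choices"][0]["message"]["content"]

    def supports_function_calling(self) -> bool:
        return False

    def supports_stop_words(self) -> bool:
        return False

    def get_context_window_size(self) -> int:
        return 8192
```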
@@ -230,10 +249,10 @@ def call(
 For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:
 
 ```python
-from crewai import BaseLLM, Agent, Task
+from crewai import LLM, Agent, Task
 from typing import Any, Dict, List, Optional, Union
 
-class JWTAuthLLM(BaseLLM):
+class JWTAuthLLM(LLM):
     def __init__(self, jwt_token: str, endpoint: str):
         super().__init__()  # Initialize the base class to set default attributes
         if not jwt_token or not isinstance(jwt_token, str):
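This hunk is also truncated at the rename. Since JWTs expire, a common extension of the pattern is to refresh the token before each call; the sketch below is a hypothetical variant, with the endpoint, payload shape, and time-to-live all assumptions for illustration.

```python
import time
from typing import Any, Callable, Dict, List, Union

import requests

from crewai import LLM


class RefreshingJWTAuthLLM(LLM):
    """Hypothetical variant that refreshes its JWT shortly before expiry."""

    def __init__(self, token_provider: Callable[[], str], endpoint: str, ttl: int = 300):
        super().__init__()
        self.token_provider = token_provider  # e.g. a client for your identity provider
        self.endpoint = endpoint
        self.ttl = ttl
        self._token = token_provider()
        self._expires_at = time.time() + ttl

    def _current_token(self) -> str:
        # Re-fetch the token once the assumed time-to-live has elapsed.
        if time.time() >= self._expires_at:
            self._token = self.token_provider()
            self._expires_at = time.time() + self.ttl
        return self._token

    def call(self, messages: Union[str, List[Dict[str, str]]], **kwargs: Any) -> str:
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        response = requests.post(
            self.endpoint,
            # JWTs are conventionally sent as a Bearer token.
            headers={"Authorization": f"Bearer {self._current_token()}"},
            json={"messages": messages},
            timeout=30,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def supports_function_calling(self) -> bool:
        return False

    def supports_stop_words(self) -> bool:
        return False

    def get_context_window_size(self) -> int:
        return 8192
```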
@@ -631,7 +650,7 @@ print(result)
 
 ## Implementing Your Own Authentication Mechanism
 
-The `BaseLLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:
+The `LLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:
 
 - OAuth tokens
 - Client certificates
@@ -640,3 +659,23 @@ The `BaseLLM` class allows you to implement any authentication mechanism you nee
 - Any other authentication method required by your LLM provider
 
 Simply implement the appropriate authentication logic in your custom LLM class.
+
+## Migrating from BaseLLM to LLM
+
+If you were previously using `BaseLLM`, you can simply replace it with `LLM`:
+
+```python
+# Old code
+from crewai import BaseLLM
+
+class CustomLLM(BaseLLM):
+    # ...
+
+# New code
+from crewai import LLM
+
+class CustomLLM(LLM):
+    # ...
+```
+
+The `BaseLLM` class is still available for backward compatibility but will be removed in a future release. It now inherits from `LLM` and emits a deprecation warning when instantiated.
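Because the old name now only warns instead of breaking, migration can be done incrementally. A sketch of keeping the warning quiet for unmigrated subclasses in the meantime, assuming only the shim behavior described above:

```python
import warnings

from crewai import BaseLLM


class LegacyLLM(BaseLLM):
    """Unmigrated subclass; instantiating it emits a DeprecationWarning."""


# Silence just this warning until every subclass has been renamed to LLM.
warnings.filterwarnings(
    "ignore",
    message="BaseLLM is deprecated",
    category=DeprecationWarning,
)
llm = LegacyLLM()  # constructs without warning noise during the transition
```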
@@ -4,7 +4,7 @@ from crewai.agent import Agent
 from crewai.crew import Crew
 from crewai.flow.flow import Flow
 from crewai.knowledge.knowledge import Knowledge
-from crewai.llm import LLM, BaseLLM
+from crewai.llm import LLM, BaseLLM, DefaultLLM
 from crewai.process import Process
 from crewai.task import Task

@@ -22,6 +22,7 @@ __all__ = [
     "Task",
     "LLM",
     "BaseLLM",
+    "DefaultLLM",
     "Flow",
     "Knowledge",
 ]

@@ -35,8 +35,8 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
 load_dotenv()


-class BaseLLM(ABC):
-    """Abstract base class for LLM implementations.
+class LLM(ABC):
+    """Base class for LLM implementations.

     This class defines the interface that all LLM implementations must follow.
     Users can extend this class to create custom LLM implementations that don't

@@ -52,8 +52,26 @@ class BaseLLM(ABC):
     This is used by the CrewAgentExecutor and other components.
     """

+    def __new__(cls, *args, **kwargs):
+        """Create a new LLM instance.
+
+        This method handles backward compatibility by creating a DefaultLLM instance
+        when the LLM class is instantiated directly with parameters.
+
+        Args:
+            *args: Positional arguments.
+            **kwargs: Keyword arguments.
+
+        Returns:
+            Either a new LLM instance or a DefaultLLM instance for backward compatibility.
+        """
+        if cls is LLM and (args or kwargs.get('model') is not None):
+            from crewai.llm import DefaultLLM
+            return DefaultLLM(*args, **kwargs)
+        return super().__new__(cls)
+
     def __init__(self):
-        """Initialize the BaseLLM with default attributes.
+        """Initialize the LLM with default attributes.

         This constructor sets default values for attributes that are expected
         by the CrewAgentExecutor and other components.
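One practical consequence of this `__new__` hook, assuming `DefaultLLM` is exported as shown in the `__init__.py` hunk above: direct instantiation of `LLM` with parameters keeps working for old call sites, while subclasses construct normally.

```python
from crewai import LLM, DefaultLLM

# Old-style direct instantiation is rerouted to the litellm-backed default,
# so pre-existing call sites keep working unchanged.
llm = LLM(model="gpt-4")
assert isinstance(llm, DefaultLLM)


# Subclasses are unaffected: cls is not LLM there, so __new__ falls through
# to the normal construction path.
class MyCustomLLM(LLM):
    pass


custom = MyCustomLLM()
assert not isinstance(custom, DefaultLLM)
```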
@@ -63,7 +81,91 @@ class BaseLLM(ABC):
         """
         self.stop = []

-    @abstractmethod
+    @classmethod
+    def create(
+        cls,
+        model: str,
+        timeout: Optional[Union[float, int]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        n: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
+        response_format: Optional[Type[BaseModel]] = None,
+        seed: Optional[int] = None,
+        logprobs: Optional[int] = None,
+        top_logprobs: Optional[int] = None,
+        base_url: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        api_key: Optional[str] = None,
+        callbacks: List[Any] = [],
+        reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
+        **kwargs,
+    ) -> 'DefaultLLM':
+        """Create a default LLM instance using litellm.
+
+        This factory method creates a default LLM instance using litellm as the backend.
+        It's the recommended way to create LLM instances for most users.
+
+        Args:
+            model: The model name (e.g., "gpt-4").
+            timeout: Optional timeout for the LLM call.
+            temperature: Optional temperature for the LLM call.
+            top_p: Optional top_p for the LLM call.
+            n: Optional n for the LLM call.
+            stop: Optional stop sequences for the LLM call.
+            max_completion_tokens: Optional max_completion_tokens for the LLM call.
+            max_tokens: Optional max_tokens for the LLM call.
+            presence_penalty: Optional presence_penalty for the LLM call.
+            frequency_penalty: Optional frequency_penalty for the LLM call.
+            logit_bias: Optional logit_bias for the LLM call.
+            response_format: Optional response_format for the LLM call.
+            seed: Optional seed for the LLM call.
+            logprobs: Optional logprobs for the LLM call.
+            top_logprobs: Optional top_logprobs for the LLM call.
+            base_url: Optional base_url for the LLM call.
+            api_base: Optional api_base for the LLM call.
+            api_version: Optional api_version for the LLM call.
+            api_key: Optional api_key for the LLM call.
+            callbacks: Optional callbacks for the LLM call.
+            reasoning_effort: Optional reasoning_effort for the LLM call.
+            **kwargs: Additional keyword arguments for the LLM call.
+
+        Returns:
+            A DefaultLLM instance configured with the provided parameters.
+        """
+        from crewai.llm import DefaultLLM
+
+        return DefaultLLM(
+            model=model,
+            timeout=timeout,
+            temperature=temperature,
+            top_p=top_p,
+            n=n,
+            stop=stop,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            response_format=response_format,
+            seed=seed,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
+            base_url=base_url,
+            api_base=api_base,
+            api_version=api_version,
+            api_key=api_key,
+            callbacks=callbacks,
+            reasoning_effort=reasoning_effort,
+            **kwargs,
+        )
+
     def call(
         self,
         messages: Union[str, List[Dict[str, str]]],
@@ -93,10 +195,10 @@ class BaseLLM(ABC):
             ValueError: If the messages format is invalid.
             TimeoutError: If the LLM request times out.
             RuntimeError: If the LLM request fails for other reasons.
+            NotImplementedError: If this method is not implemented by a subclass.
         """
-        pass
+        raise NotImplementedError("Subclasses must implement call()")

-    @abstractmethod
     def supports_function_calling(self) -> bool:
         """Check if the LLM supports function calling.

@@ -107,10 +209,12 @@
         Returns:
             True if the LLM supports function calling, False otherwise.
+
+        Raises:
+            NotImplementedError: If this method is not implemented by a subclass.
         """
-        pass
+        raise NotImplementedError("Subclasses must implement supports_function_calling()")

-    @abstractmethod
     def supports_stop_words(self) -> bool:
         """Check if the LLM supports stop words.

@@ -120,10 +224,12 @@
         Returns:
             True if the LLM supports stop words, False otherwise.
+
+        Raises:
+            NotImplementedError: If this method is not implemented by a subclass.
         """
-        pass
+        raise NotImplementedError("Subclasses must implement supports_stop_words()")

-    @abstractmethod
     def get_context_window_size(self) -> int:
         """Get the context window size of the LLM.

@@ -133,8 +239,11 @@
         Returns:
             The context window size as an integer.
+
+        Raises:
+            NotImplementedError: If this method is not implemented by a subclass.
         """
-        pass
+        raise NotImplementedError("Subclasses must implement get_context_window_size()")


 class FilteredStream:
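Worth noting about the hunks above: with `@abstractmethod` removed, an incomplete subclass no longer fails at instantiation with a `TypeError`; it fails at first use with `NotImplementedError`. A small sketch of the changed failure mode:

```python
from crewai import LLM


class IncompleteLLM(LLM):
    """Implements nothing; relies entirely on the base-class stubs."""


# Under the old ABC design this line would have raised TypeError
# immediately; after this change, instantiation succeeds...
llm = IncompleteLLM()

# ...and the failure is deferred to the first call of a missing method.
try:
    llm.call("hello")
except NotImplementedError as exc:
    print(exc)  # "Subclasses must implement call()"
```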
@@ -229,7 +338,14 @@ def suppress_warnings():
         sys.stderr = old_stderr


-class LLM(BaseLLM):
+class DefaultLLM(LLM):
+    """Default LLM implementation using litellm.
+
+    This class provides a concrete implementation of the LLM interface
+    using litellm as the backend. It's the default implementation used
+    by CrewAI when no custom LLM is provided.
+    """
+
     def __init__(
         self,
         model: str,

@@ -255,6 +371,8 @@ class LLM(BaseLLM):
         reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
         **kwargs,
     ):
+        super().__init__()  # Initialize the base class
+
         self.model = model
         self.timeout = timeout
         self.temperature = temperature

@@ -283,7 +401,7 @@ class LLM(BaseLLM):
         # Normalize self.stop to always be a List[str]
         if stop is None:
-            self.stop: List[str] = []
+            self.stop = []  # Already initialized in base class
         elif isinstance(stop, str):
             self.stop = [stop]
         else:
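For reference, the normalization above means callers can pass either a single stop string or a list; a small sketch of the resulting `stop` attribute, assuming construction via `LLM.create` as in the earlier docs hunk:

```python
from crewai import LLM

# A single string is wrapped into a one-element list...
llm_a = LLM.create(model="gpt-4", stop="Observation:")
assert llm_a.stop == ["Observation:"]

# ...and omitting `stop` leaves the default empty list set in LLM.__init__.
llm_b = LLM.create(model="gpt-4")
assert llm_b.stop == []
```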
@@ -667,3 +785,27 @@ class LLM(BaseLLM):
 
         litellm.success_callback = success_callbacks
         litellm.failure_callback = failure_callbacks
+
+
+class BaseLLM(LLM):
+    """Deprecated: Use LLM instead.
+
+    This class is kept for backward compatibility and will be removed in a future release.
+    It inherits from LLM and provides the same interface, but emits a deprecation warning
+    when instantiated.
+    """
+
+    def __init__(self):
+        """Initialize the BaseLLM with a deprecation warning.
+
+        This constructor emits a deprecation warning and then calls the parent class's
+        constructor to initialize the LLM.
+        """
+        import warnings
+        warnings.warn(
+            "BaseLLM is deprecated and will be removed in a future release. "
+            "Use LLM instead for custom implementations.",
+            DeprecationWarning,
+            stacklevel=2
+        )
+        super().__init__()
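A sketch of how this shim could be exercised in a test; the pytest usage here is an assumption, not a test from this commit:

```python
import pytest

from crewai import BaseLLM


def test_base_llm_emits_deprecation_warning():
    # The shim warns at construction time, then defers to LLM.__init__,
    # so the resulting instance is otherwise fully usable.
    with pytest.warns(DeprecationWarning, match="BaseLLM is deprecated"):
        llm = BaseLLM()
    assert llm.stop == []
```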
@@ -6,30 +6,30 @@ from crewai.llm import LLM, BaseLLM


 def create_llm(
-    llm_value: Union[str, BaseLLM, Any, None] = None,
-) -> Optional[BaseLLM]:
+    llm_value: Union[str, LLM, Any, None] = None,
+) -> Optional[LLM]:
     """
     Creates or returns an LLM instance based on the given llm_value.

     Args:
-        llm_value (str | BaseLLM | Any | None):
+        llm_value (str | LLM | Any | None):
             - str: The model name (e.g., "gpt-4").
-            - BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
+            - LLM: Already instantiated LLM, returned as-is.
             - Any: Attempt to extract known attributes like model_name, temperature, etc.
             - None: Use environment-based or fallback default model.

     Returns:
-        A BaseLLM instance if successful, or None if something fails.
+        A LLM instance if successful, or None if something fails.
     """

-    # 1) If llm_value is already a BaseLLM object, return it directly
-    if isinstance(llm_value, BaseLLM):
+    # 1) If llm_value is already a LLM object, return it directly
+    if isinstance(llm_value, LLM):
         return llm_value

     # 2) If llm_value is a string (model name)
     if isinstance(llm_value, str):
         try:
-            created_llm = LLM(model=llm_value)
+            created_llm = LLM.create(model=llm_value)
             return created_llm
         except Exception as e:
             print(f"Failed to instantiate LLM with model='{llm_value}': {e}")

@@ -56,7 +56,7 @@ def create_llm(
     base_url: Optional[str] = getattr(llm_value, "base_url", None)
     api_base: Optional[str] = getattr(llm_value, "api_base", None)

-    created_llm = LLM(
+    created_llm = LLM.create(
         model=model,
         temperature=temperature,
         max_tokens=max_tokens,

@@ -175,7 +175,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
     # Try creating the LLM
     try:
-        new_llm = LLM(**llm_params)
+        new_llm = LLM.create(**llm_params)
         return new_llm
     except Exception as e:
         print(
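Taken together, `create_llm` now accepts a model-name string (routed through `LLM.create`) or any `LLM` instance (returned unchanged). A hedged sketch of both paths; `EchoLLM` is a hypothetical stand-in, not a class from this commit:

```python
from typing import Any

from crewai import LLM
from crewai.utilities.llm_utils import create_llm


class EchoLLM(LLM):
    def call(self, messages: Any, **kwargs: Any) -> str:
        return "echo"

    def supports_function_calling(self) -> bool:
        return False

    def supports_stop_words(self) -> bool:
        return False

    def get_context_window_size(self) -> int:
        return 4096


# 1) A model-name string is routed through LLM.create(...).
llm = create_llm("gpt-4")

# 2) An existing LLM (or subclass) instance is passed through unchanged.
custom = EchoLLM()
assert create_llm(custom) is custom
```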
@@ -2,14 +2,14 @@ from typing import Any, Dict, List, Optional, Union

 import pytest

-from crewai.llm import BaseLLM
+from crewai.llm import LLM
 from crewai.utilities.llm_utils import create_llm


-class CustomLLM(BaseLLM):
+class CustomLLM(LLM):
     """Custom LLM implementation for testing.

-    This is a simple implementation of the BaseLLM abstract base class
+    This is a simple implementation of the LLM abstract base class
     that returns a predefined response for testing purposes.
     """

@@ -93,7 +93,7 @@ def test_custom_llm_implementation():
     assert response == "The answer is 42"


-class JWTAuthLLM(BaseLLM):
+class JWTAuthLLM(LLM):
     """Custom LLM implementation with JWT authentication."""

     def __init__(self, jwt_token: str):

@@ -164,7 +164,7 @@ def test_jwt_auth_llm_validation():
         JWTAuthLLM(jwt_token=None)


-class TimeoutHandlingLLM(BaseLLM):
+class TimeoutHandlingLLM(LLM):
     """Custom LLM implementation with timeout handling and retry logic."""

     def __init__(self, max_retries: int = 3, timeout: int = 30):
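The renamed test helpers all follow the same pattern: subclass `LLM` and stub the four required methods. A self-contained sketch of that pattern (not code from this commit):

```python
from typing import Any

from crewai.llm import LLM


class StubLLM(LLM):
    """Minimal stub in the style of the test helpers above."""

    def __init__(self, response: str = "ok"):
        super().__init__()
        self.response = response

    def call(self, messages: Any, **kwargs: Any) -> str:
        return self.response

    def supports_function_calling(self) -> bool:
        return False

    def supports_stop_words(self) -> bool:
        return True

    def get_context_window_size(self) -> int:
        return 4096


def test_stub_llm_round_trip():
    llm = StubLLM(response="The answer is 42")
    assert llm.call("What is 6 x 7?") == "The answer is 42"
```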