Simplify LLM implementation by consolidating LLM and BaseLLM classes

Co-Authored-By: Joe Moura <joao@crewai.com>
Devin AI authored on 2025-03-14 06:35:42 +00:00
parent 963ed23b63
commit 1b8c07760e
5 changed files with 218 additions and 36 deletions

View File

@@ -1,25 +1,44 @@
 # Custom LLM Implementations
-CrewAI now supports custom LLM implementations through the `BaseLLM` abstract base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.
+CrewAI supports custom LLM implementations through the `LLM` base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.
 ## Using Custom LLM Implementations
 To create a custom LLM implementation, you need to:
-1. Inherit from the `BaseLLM` abstract base class
+1. Inherit from the `LLM` base class
 2. Implement the required methods:
 - `call()`: The main method to call the LLM with messages
 - `supports_function_calling()`: Whether the LLM supports function calling
 - `supports_stop_words()`: Whether the LLM supports stop words
 - `get_context_window_size()`: The context window size of the LLM
+## Using the Default LLM Implementation
+If you don't need a custom LLM implementation, you can use the default implementation provided by CrewAI:
+```python
+from crewai import LLM
+
+# Create a default LLM instance
+llm = LLM.create(model="gpt-4")
+
+# Or with more parameters
+llm = LLM.create(
+    model="gpt-4",
+    temperature=0.7,
+    max_tokens=1000,
+    api_key="your-api-key"
+)
+```
 ## Example: Basic Custom LLM
 ```python
-from crewai import BaseLLM
+from crewai import LLM
 from typing import Any, Dict, List, Optional, Union
-class CustomLLM(BaseLLM):
+class CustomLLM(LLM):
     def __init__(self, api_key: str, endpoint: str):
         super().__init__()  # Initialize the base class to set default attributes
         if not api_key or not isinstance(api_key, str):
@@ -230,10 +249,10 @@ def call(
 For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:
 ```python
-from crewai import BaseLLM, Agent, Task
+from crewai import LLM, Agent, Task
 from typing import Any, Dict, List, Optional, Union
-class JWTAuthLLM(BaseLLM):
+class JWTAuthLLM(LLM):
     def __init__(self, jwt_token: str, endpoint: str):
         super().__init__()  # Initialize the base class to set default attributes
         if not jwt_token or not isinstance(jwt_token, str):
@@ -631,7 +650,7 @@ print(result)
 ## Implementing Your Own Authentication Mechanism
-The `BaseLLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:
+The `LLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:
 - OAuth tokens
 - Client certificates
@@ -640,3 +659,23 @@ The `BaseLLM` class allows you to implement any authentication mechanism you nee
 - Any other authentication method required by your LLM provider
 Simply implement the appropriate authentication logic in your custom LLM class.
+## Migrating from BaseLLM to LLM
+If you were previously using `BaseLLM`, you can simply replace it with `LLM`:
+```python
+# Old code
+from crewai import BaseLLM
+
+class CustomLLM(BaseLLM):
+    # ...
+
+# New code
+from crewai import LLM
+
+class CustomLLM(LLM):
+    # ...
+```
+The `BaseLLM` class is still available for backward compatibility but will be removed in a future release. It now inherits from `LLM` and emits a deprecation warning when instantiated.
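If a codebase cannot migrate right away, the deprecation warning can be silenced during the transition. A minimal sketch, assuming the post-commit `crewai` package and the warning text introduced later in this diff:

```python
# Sketch: temporarily silencing the BaseLLM deprecation warning while migrating.
# The filter text matches the warning message added in this commit.
import warnings

warnings.filterwarnings(
    "ignore",
    message="BaseLLM is deprecated",
    category=DeprecationWarning,
)

from crewai import BaseLLM  # legacy import still works for now

class LegacyCustomLLM(BaseLLM):
    def call(self, messages, **kwargs):  # simplified signature for this sketch
        return "stub response"

llm = LegacyCustomLLM()  # would emit DeprecationWarning without the filter above
```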

View File

@@ -4,7 +4,7 @@ from crewai.agent import Agent
 from crewai.crew import Crew
 from crewai.flow.flow import Flow
 from crewai.knowledge.knowledge import Knowledge
-from crewai.llm import LLM, BaseLLM
+from crewai.llm import LLM, BaseLLM, DefaultLLM
 from crewai.process import Process
 from crewai.task import Task
@@ -22,6 +22,7 @@ __all__ = [
     "Task",
     "LLM",
     "BaseLLM",
+    "DefaultLLM",
     "Flow",
     "Knowledge",
 ]
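With the updated export list, all three names resolve from the package root; a trivial check, assuming the post-commit package:

```python
# All three classes are exported from the package root after this change.
from crewai import LLM, BaseLLM, DefaultLLM

assert issubclass(DefaultLLM, LLM) and issubclass(BaseLLM, LLM)
```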

View File

@@ -35,8 +35,8 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
 load_dotenv()
-class BaseLLM(ABC):
-    """Abstract base class for LLM implementations.
+class LLM(ABC):
+    """Base class for LLM implementations.
     This class defines the interface that all LLM implementations must follow.
     Users can extend this class to create custom LLM implementations that don't
@@ -52,8 +52,26 @@ class BaseLLM(ABC):
     This is used by the CrewAgentExecutor and other components.
     """
+    def __new__(cls, *args, **kwargs):
+        """Create a new LLM instance.
+
+        This method handles backward compatibility by creating a DefaultLLM instance
+        when the LLM class is instantiated directly with parameters.
+
+        Args:
+            *args: Positional arguments.
+            **kwargs: Keyword arguments.
+
+        Returns:
+            Either a new LLM instance or a DefaultLLM instance for backward compatibility.
+        """
+        if cls is LLM and (args or kwargs.get('model') is not None):
+            from crewai.llm import DefaultLLM
+            return DefaultLLM(*args, **kwargs)
+        return super().__new__(cls)
+
     def __init__(self):
-        """Initialize the BaseLLM with default attributes.
+        """Initialize the LLM with default attributes.
         This constructor sets default values for attributes that are expected
         by the CrewAgentExecutor and other components.
@@ -63,7 +81,91 @@ class BaseLLM(ABC):
         """
         self.stop = []
-    @abstractmethod
+    @classmethod
+    def create(
+        cls,
+        model: str,
+        timeout: Optional[Union[float, int]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        n: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
+        response_format: Optional[Type[BaseModel]] = None,
+        seed: Optional[int] = None,
+        logprobs: Optional[int] = None,
+        top_logprobs: Optional[int] = None,
+        base_url: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        api_key: Optional[str] = None,
+        callbacks: List[Any] = [],
+        reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
+        **kwargs,
+    ) -> 'DefaultLLM':
+        """Create a default LLM instance using litellm.
+
+        This factory method creates a default LLM instance using litellm as the backend.
+        It's the recommended way to create LLM instances for most users.
+
+        Args:
+            model: The model name (e.g., "gpt-4").
+            timeout: Optional timeout for the LLM call.
+            temperature: Optional temperature for the LLM call.
+            top_p: Optional top_p for the LLM call.
+            n: Optional n for the LLM call.
+            stop: Optional stop sequences for the LLM call.
+            max_completion_tokens: Optional max_completion_tokens for the LLM call.
+            max_tokens: Optional max_tokens for the LLM call.
+            presence_penalty: Optional presence_penalty for the LLM call.
+            frequency_penalty: Optional frequency_penalty for the LLM call.
+            logit_bias: Optional logit_bias for the LLM call.
+            response_format: Optional response_format for the LLM call.
+            seed: Optional seed for the LLM call.
+            logprobs: Optional logprobs for the LLM call.
+            top_logprobs: Optional top_logprobs for the LLM call.
+            base_url: Optional base_url for the LLM call.
+            api_base: Optional api_base for the LLM call.
+            api_version: Optional api_version for the LLM call.
+            api_key: Optional api_key for the LLM call.
+            callbacks: Optional callbacks for the LLM call.
+            reasoning_effort: Optional reasoning_effort for the LLM call.
+            **kwargs: Additional keyword arguments for the LLM call.
+
+        Returns:
+            A DefaultLLM instance configured with the provided parameters.
+        """
+        from crewai.llm import DefaultLLM
+        return DefaultLLM(
+            model=model,
+            timeout=timeout,
+            temperature=temperature,
+            top_p=top_p,
+            n=n,
+            stop=stop,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            response_format=response_format,
+            seed=seed,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
+            base_url=base_url,
+            api_base=api_base,
+            api_version=api_version,
+            api_key=api_key,
+            callbacks=callbacks,
+            reasoning_effort=reasoning_effort,
+            **kwargs,
+        )
+
     def call(
         self,
         messages: Union[str, List[Dict[str, str]]],
@@ -93,10 +195,10 @@ class BaseLLM(ABC):
             ValueError: If the messages format is invalid.
             TimeoutError: If the LLM request times out.
             RuntimeError: If the LLM request fails for other reasons.
+            NotImplementedError: If this method is not implemented by a subclass.
         """
-        pass
+        raise NotImplementedError("Subclasses must implement call()")
-    @abstractmethod
     def supports_function_calling(self) -> bool:
         """Check if the LLM supports function calling.
@@ -107,10 +209,12 @@ class BaseLLM(ABC):
         Returns:
             True if the LLM supports function calling, False otherwise.
+
+        Raises:
+            NotImplementedError: If this method is not implemented by a subclass.
         """
-        pass
+        raise NotImplementedError("Subclasses must implement supports_function_calling()")
-    @abstractmethod
     def supports_stop_words(self) -> bool:
         """Check if the LLM supports stop words.
@@ -120,10 +224,12 @@ class BaseLLM(ABC):
         Returns:
             True if the LLM supports stop words, False otherwise.
+
+        Raises:
+            NotImplementedError: If this method is not implemented by a subclass.
         """
-        pass
+        raise NotImplementedError("Subclasses must implement supports_stop_words()")
-    @abstractmethod
     def get_context_window_size(self) -> int:
         """Get the context window size of the LLM.
@@ -133,8 +239,11 @@ class BaseLLM(ABC):
         Returns:
             The context window size as an integer.
+
+        Raises:
+            NotImplementedError: If this method is not implemented by a subclass.
         """
-        pass
+        raise NotImplementedError("Subclasses must implement get_context_window_size()")
 class FilteredStream:
@@ -229,7 +338,14 @@ def suppress_warnings():
         sys.stderr = old_stderr
-class LLM(BaseLLM):
+class DefaultLLM(LLM):
+    """Default LLM implementation using litellm.
+
+    This class provides a concrete implementation of the LLM interface
+    using litellm as the backend. It's the default implementation used
+    by CrewAI when no custom LLM is provided.
+    """
+
     def __init__(
         self,
         model: str,
@@ -255,6 +371,8 @@ class LLM(BaseLLM):
         reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
         **kwargs,
     ):
+        super().__init__()  # Initialize the base class
         self.model = model
         self.timeout = timeout
         self.temperature = temperature
@@ -283,7 +401,7 @@ class LLM(BaseLLM):
         # Normalize self.stop to always be a List[str]
         if stop is None:
-            self.stop: List[str] = []
+            self.stop = []  # Already initialized in base class
         elif isinstance(stop, str):
             self.stop = [stop]
         else:
@@ -667,3 +785,27 @@ class LLM(BaseLLM):
         litellm.success_callback = success_callbacks
         litellm.failure_callback = failure_callbacks
+class BaseLLM(LLM):
+    """Deprecated: Use LLM instead.
+
+    This class is kept for backward compatibility and will be removed in a future release.
+    It inherits from LLM and provides the same interface, but emits a deprecation warning
+    when instantiated.
+    """
+
+    def __init__(self):
+        """Initialize the BaseLLM with a deprecation warning.
+
+        This constructor emits a deprecation warning and then calls the parent class's
+        constructor to initialize the LLM.
+        """
+        import warnings
+        warnings.warn(
+            "BaseLLM is deprecated and will be removed in a future release. "
+            "Use LLM instead for custom implementations.",
+            DeprecationWarning,
+            stacklevel=2
+        )
+        super().__init__()
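Taken together, the `__new__` dispatch, the `create()` factory, and the deprecation shim give three entry points to the same backend. A sketch of the intended behavior, assuming the post-commit `crewai.llm` module:

```python
# Sketch of the dispatch behavior defined above (post-commit crewai.llm).
from crewai.llm import LLM, DefaultLLM

# Direct instantiation with parameters hits LLM.__new__, which sees that
# cls is LLM and a model was supplied, and returns a DefaultLLM instead.
legacy_style = LLM(model="gpt-4")
assert isinstance(legacy_style, DefaultLLM)

# LLM.create(...) is the explicit factory for the same thing.
factory_style = LLM.create(model="gpt-4", temperature=0.7)
assert isinstance(factory_style, DefaultLLM)

# Subclasses bypass the dispatch entirely, because cls is not LLM in __new__.
class MyLLM(LLM):
    def call(self, messages, **kwargs):  # simplified signature for this sketch
        return "ok"

assert not isinstance(MyLLM(), DefaultLLM)
```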

View File

@@ -6,30 +6,30 @@ from crewai.llm import LLM, BaseLLM
 def create_llm(
-    llm_value: Union[str, BaseLLM, Any, None] = None,
-) -> Optional[BaseLLM]:
+    llm_value: Union[str, LLM, Any, None] = None,
+) -> Optional[LLM]:
     """
     Creates or returns an LLM instance based on the given llm_value.
     Args:
-        llm_value (str | BaseLLM | Any | None):
+        llm_value (str | LLM | Any | None):
             - str: The model name (e.g., "gpt-4").
-            - BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
+            - LLM: Already instantiated LLM, returned as-is.
             - Any: Attempt to extract known attributes like model_name, temperature, etc.
             - None: Use environment-based or fallback default model.
     Returns:
-        A BaseLLM instance if successful, or None if something fails.
+        An LLM instance if successful, or None if something fails.
     """
-    # 1) If llm_value is already a BaseLLM object, return it directly
-    if isinstance(llm_value, BaseLLM):
+    # 1) If llm_value is already an LLM object, return it directly
+    if isinstance(llm_value, LLM):
         return llm_value
     # 2) If llm_value is a string (model name)
     if isinstance(llm_value, str):
         try:
-            created_llm = LLM(model=llm_value)
+            created_llm = LLM.create(model=llm_value)
             return created_llm
         except Exception as e:
             print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
@@ -56,7 +56,7 @@ def create_llm(
     base_url: Optional[str] = getattr(llm_value, "base_url", None)
     api_base: Optional[str] = getattr(llm_value, "api_base", None)
-    created_llm = LLM(
+    created_llm = LLM.create(
         model=model,
         temperature=temperature,
         max_tokens=max_tokens,
@@ -175,7 +175,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
     # Try creating the LLM
     try:
-        new_llm = LLM(**llm_params)
+        new_llm = LLM.create(**llm_params)
         return new_llm
     except Exception as e:
         print(
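For reference, a short sketch of how `create_llm` behaves after this change; the model name and config object here are illustrative, not taken from the diff:

```python
# Hypothetical usage of create_llm after this change.
from crewai.llm import LLM
from crewai.utilities.llm_utils import create_llm

# A string goes through LLM.create(...), yielding a DefaultLLM.
llm = create_llm("gpt-4")

# An existing LLM (or subclass) instance is returned unchanged.
assert create_llm(llm) is llm

# An arbitrary object has known attributes (model_name, temperature, ...)
# extracted and forwarded to LLM.create(...).
class AgentConfig:  # illustrative stand-in, not from the diff
    model_name = "gpt-4"
    temperature = 0.2

assert isinstance(create_llm(AgentConfig()), LLM)
```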

View File

@@ -2,14 +2,14 @@ from typing import Any, Dict, List, Optional, Union
 import pytest
-from crewai.llm import BaseLLM
+from crewai.llm import LLM
 from crewai.utilities.llm_utils import create_llm
-class CustomLLM(BaseLLM):
+class CustomLLM(LLM):
     """Custom LLM implementation for testing.
-    This is a simple implementation of the BaseLLM abstract base class
+    This is a simple implementation of the LLM abstract base class
     that returns a predefined response for testing purposes.
     """
@@ -93,7 +93,7 @@ def test_custom_llm_implementation():
     assert response == "The answer is 42"
-class JWTAuthLLM(BaseLLM):
+class JWTAuthLLM(LLM):
     """Custom LLM implementation with JWT authentication."""
     def __init__(self, jwt_token: str):
@@ -164,7 +164,7 @@ def test_jwt_auth_llm_validation():
         JWTAuthLLM(jwt_token=None)
-class TimeoutHandlingLLM(BaseLLM):
+class TimeoutHandlingLLM(LLM):
     """Custom LLM implementation with timeout handling and retry logic."""
     def __init__(self, max_retries: int = 3, timeout: int = 30):
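A natural companion test for the shim (a sketch, not part of this commit's test file) would assert that instantiating `BaseLLM` emits the documented warning:

```python
# Hypothetical test for the deprecation shim added in this commit.
import pytest

from crewai.llm import BaseLLM


def test_base_llm_emits_deprecation_warning():
    with pytest.warns(DeprecationWarning, match="BaseLLM is deprecated"):
        BaseLLM()
```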