Simplify LLM implementation by consolidating LLM and BaseLLM classes

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is authored by:
Devin AI
Date: 2025-03-14 06:35:42 +00:00
parent 963ed23b63
commit 1b8c07760e
5 changed files with 218 additions and 36 deletions

View File

@@ -4,7 +4,7 @@ from crewai.agent import Agent
from crewai.crew import Crew
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM, BaseLLM
from crewai.llm import LLM, BaseLLM, DefaultLLM
from crewai.process import Process
from crewai.task import Task
@@ -22,6 +22,7 @@ __all__ = [
"Task",
"LLM",
"BaseLLM",
"DefaultLLM",
"Flow",
"Knowledge",
]

View File

@@ -35,8 +35,8 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
load_dotenv()
class BaseLLM(ABC):
"""Abstract base class for LLM implementations.
class LLM(ABC):
"""Base class for LLM implementations.
This class defines the interface that all LLM implementations must follow.
Users can extend this class to create custom LLM implementations that don't
@@ -52,8 +52,26 @@ class BaseLLM(ABC):
This is used by the CrewAgentExecutor and other components.
"""
def __new__(cls, *args, **kwargs):
"""Create a new LLM instance.
This method handles backward compatibility by creating a DefaultLLM instance
when the LLM class is instantiated directly with parameters.
Args:
*args: Positional arguments.
**kwargs: Keyword arguments.
Returns:
Either a new LLM instance or a DefaultLLM instance for backward compatibility.
"""
if cls is LLM and (args or kwargs.get('model') is not None):
from crewai.llm import DefaultLLM
return DefaultLLM(*args, **kwargs)
return super().__new__(cls)
def __init__(self):
"""Initialize the BaseLLM with default attributes.
"""Initialize the LLM with default attributes.
This constructor sets default values for attributes that are expected
by the CrewAgentExecutor and other components.
@@ -63,7 +81,91 @@ class BaseLLM(ABC):
"""
self.stop = []
@abstractmethod
@classmethod
def create(
cls,
model: str,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
**kwargs,
) -> 'DefaultLLM':
"""Create a default LLM instance using litellm.
This factory method creates a default LLM instance using litellm as the backend.
It's the recommended way to create LLM instances for most users.
Args:
model: The model name (e.g., "gpt-4").
timeout: Optional timeout for the LLM call.
temperature: Optional temperature for the LLM call.
top_p: Optional top_p for the LLM call.
n: Optional n for the LLM call.
stop: Optional stop sequences for the LLM call.
max_completion_tokens: Optional max_completion_tokens for the LLM call.
max_tokens: Optional max_tokens for the LLM call.
presence_penalty: Optional presence_penalty for the LLM call.
frequency_penalty: Optional frequency_penalty for the LLM call.
logit_bias: Optional logit_bias for the LLM call.
response_format: Optional response_format for the LLM call.
seed: Optional seed for the LLM call.
logprobs: Optional logprobs for the LLM call.
top_logprobs: Optional top_logprobs for the LLM call.
base_url: Optional base_url for the LLM call.
api_base: Optional api_base for the LLM call.
api_version: Optional api_version for the LLM call.
api_key: Optional api_key for the LLM call.
callbacks: Optional callbacks for the LLM call.
reasoning_effort: Optional reasoning_effort for the LLM call.
**kwargs: Additional keyword arguments for the LLM call.
Returns:
A DefaultLLM instance configured with the provided parameters.
"""
from crewai.llm import DefaultLLM
return DefaultLLM(
model=model,
timeout=timeout,
temperature=temperature,
top_p=top_p,
n=n,
stop=stop,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
response_format=response_format,
seed=seed,
logprobs=logprobs,
top_logprobs=top_logprobs,
base_url=base_url,
api_base=api_base,
api_version=api_version,
api_key=api_key,
callbacks=callbacks,
reasoning_effort=reasoning_effort,
**kwargs,
)
def call(
self,
messages: Union[str, List[Dict[str, str]]],
@@ -93,10 +195,10 @@ class BaseLLM(ABC):
ValueError: If the messages format is invalid.
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
NotImplementedError: If this method is not implemented by a subclass.
"""
pass
raise NotImplementedError("Subclasses must implement call()")
@abstractmethod
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
@@ -107,10 +209,12 @@ class BaseLLM(ABC):
Returns:
True if the LLM supports function calling, False otherwise.
Raises:
NotImplementedError: If this method is not implemented by a subclass.
"""
pass
raise NotImplementedError("Subclasses must implement supports_function_calling()")
@abstractmethod
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
@@ -120,10 +224,12 @@ class BaseLLM(ABC):
Returns:
True if the LLM supports stop words, False otherwise.
Raises:
NotImplementedError: If this method is not implemented by a subclass.
"""
pass
raise NotImplementedError("Subclasses must implement supports_stop_words()")
@abstractmethod
def get_context_window_size(self) -> int:
"""Get the context window size of the LLM.
@@ -133,8 +239,11 @@ class BaseLLM(ABC):
Returns:
The context window size as an integer.
Raises:
NotImplementedError: If this method is not implemented by a subclass.
"""
pass
raise NotImplementedError("Subclasses must implement get_context_window_size()")
class FilteredStream:
@@ -229,7 +338,14 @@ def suppress_warnings():
sys.stderr = old_stderr
class LLM(BaseLLM):
class DefaultLLM(LLM):
"""Default LLM implementation using litellm.
This class provides a concrete implementation of the LLM interface
using litellm as the backend. It's the default implementation used
by CrewAI when no custom LLM is provided.
"""
def __init__(
self,
model: str,
@@ -255,6 +371,8 @@ class LLM(BaseLLM):
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
**kwargs,
):
super().__init__() # Initialize the base class
self.model = model
self.timeout = timeout
self.temperature = temperature
@@ -283,7 +401,7 @@ class LLM(BaseLLM):
# Normalize self.stop to always be a List[str]
if stop is None:
self.stop: List[str] = []
self.stop = [] # Already initialized in base class
elif isinstance(stop, str):
self.stop = [stop]
else:
@@ -667,3 +785,27 @@ class LLM(BaseLLM):
litellm.success_callback = success_callbacks
litellm.failure_callback = failure_callbacks
class BaseLLM(LLM):
"""Deprecated: Use LLM instead.
This class is kept for backward compatibility and will be removed in a future release.
It inherits from LLM and provides the same interface, but emits a deprecation warning
when instantiated.
"""
def __init__(self):
"""Initialize the BaseLLM with a deprecation warning.
This constructor emits a deprecation warning and then calls the parent class's
constructor to initialize the LLM.
"""
import warnings
warnings.warn(
"BaseLLM is deprecated and will be removed in a future release. "
"Use LLM instead for custom implementations.",
DeprecationWarning,
stacklevel=2
)
super().__init__()

View File

@@ -6,30 +6,30 @@ from crewai.llm import LLM, BaseLLM
def create_llm(
llm_value: Union[str, BaseLLM, Any, None] = None,
) -> Optional[BaseLLM]:
llm_value: Union[str, LLM, Any, None] = None,
) -> Optional[LLM]:
"""
Creates or returns an LLM instance based on the given llm_value.
Args:
llm_value (str | BaseLLM | Any | None):
llm_value (str | LLM | Any | None):
- str: The model name (e.g., "gpt-4").
- BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
- LLM: Already instantiated LLM, returned as-is.
- Any: Attempt to extract known attributes like model_name, temperature, etc.
- None: Use environment-based or fallback default model.
Returns:
A BaseLLM instance if successful, or None if something fails.
An LLM instance if successful, or None if something fails.
"""
# 1) If llm_value is already a BaseLLM object, return it directly
if isinstance(llm_value, BaseLLM):
# 1) If llm_value is already a LLM object, return it directly
if isinstance(llm_value, LLM):
return llm_value
# 2) If llm_value is a string (model name)
if isinstance(llm_value, str):
try:
created_llm = LLM(model=llm_value)
created_llm = LLM.create(model=llm_value)
return created_llm
except Exception as e:
print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
@@ -56,7 +56,7 @@ def create_llm(
base_url: Optional[str] = getattr(llm_value, "base_url", None)
api_base: Optional[str] = getattr(llm_value, "api_base", None)
created_llm = LLM(
created_llm = LLM.create(
model=model,
temperature=temperature,
max_tokens=max_tokens,
@@ -175,7 +175,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
# Try creating the LLM
try:
new_llm = LLM(**llm_params)
new_llm = LLM.create(**llm_params)
return new_llm
except Exception as e:
print(