diff --git a/docs/custom_llm.md b/docs/custom_llm.md
index 3d0fdc0c4..0e44cca5d 100644
--- a/docs/custom_llm.md
+++ b/docs/custom_llm.md
@@ -1,25 +1,44 @@
 # Custom LLM Implementations
 
-CrewAI now supports custom LLM implementations through the `BaseLLM` abstract base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.
+CrewAI supports custom LLM implementations through the `LLM` base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.
 
 ## Using Custom LLM Implementations
 
 To create a custom LLM implementation, you need to:
 
-1. Inherit from the `BaseLLM` abstract base class
+1. Inherit from the `LLM` base class
 2. Implement the required methods:
    - `call()`: The main method to call the LLM with messages
    - `supports_function_calling()`: Whether the LLM supports function calling
    - `supports_stop_words()`: Whether the LLM supports stop words
    - `get_context_window_size()`: The context window size of the LLM
 
+## Using the Default LLM Implementation
+
+If you don't need a custom LLM implementation, you can use the default implementation provided by CrewAI:
+
+```python
+from crewai import LLM
+
+# Create a default LLM instance
+llm = LLM.create(model="gpt-4")
+
+# Or with more parameters
+llm = LLM.create(
+    model="gpt-4",
+    temperature=0.7,
+    max_tokens=1000,
+    api_key="your-api-key"
+)
+```
+
 ## Example: Basic Custom LLM
 
 ```python
-from crewai import BaseLLM
+from crewai import LLM
 from typing import Any, Dict, List, Optional, Union
 
-class CustomLLM(BaseLLM):
+class CustomLLM(LLM):
     def __init__(self, api_key: str, endpoint: str):
         super().__init__()  # Initialize the base class to set default attributes
         if not api_key or not isinstance(api_key, str):
@@ -230,10 +249,10 @@ def call(
 For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:
 
 ```python
-from crewai import BaseLLM, Agent, Task
+from crewai import LLM, Agent, Task
 from typing import Any, Dict, List, Optional, Union
 
-class JWTAuthLLM(BaseLLM):
+class JWTAuthLLM(LLM):
     def __init__(self, jwt_token: str, endpoint: str):
         super().__init__()  # Initialize the base class to set default attributes
         if not jwt_token or not isinstance(jwt_token, str):
@@ -631,7 +650,7 @@ print(result)
 
 ## Implementing Your Own Authentication Mechanism
 
-The `BaseLLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:
+The `LLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:
 
 - OAuth tokens
 - Client certificates
@@ -640,3 +659,23 @@ The `BaseLLM` class allows you to implement any authentication mechanism you nee
 - Any other authentication method required by your LLM provider
 
 Simply implement the appropriate authentication logic in your custom LLM class.
+
+## Migrating from BaseLLM to LLM
+
+If you were previously using `BaseLLM`, you can simply replace it with `LLM`:
+
+```python
+# Old code
+from crewai import BaseLLM
+
+class CustomLLM(BaseLLM):
+    # ...
+
+# New code
+from crewai import LLM
+
+class CustomLLM(LLM):
+    # ...
+```
+
+The `BaseLLM` class is still available for backward compatibility but will be removed in a future release. It now inherits from `LLM` and emits a deprecation warning when instantiated.
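+
+For reference, a migrated class looks like any other custom `LLM` subclass. Below is a minimal, self-contained sketch — the `requests` client, the endpoint URL, and the response shape are illustrative placeholders, not a real provider API, and the optional `call()` parameters mirror the interface described above:
+
+```python
+import requests  # placeholder HTTP client; any client works
+
+from crewai import LLM
+
+class MigratedCustomLLM(LLM):
+    def __init__(self, api_key: str, endpoint: str):
+        super().__init__()  # Initialize the base class to set default attributes
+        self.api_key = api_key
+        self.endpoint = endpoint
+
+    def call(self, messages, tools=None, callbacks=None, available_functions=None):
+        # Normalize a plain string prompt into chat-message format.
+        if isinstance(messages, str):
+            messages = [{"role": "user", "content": messages}]
+        response = requests.post(
+            self.endpoint,  # e.g. "https://api.example.com/v1/chat" (placeholder)
+            headers={"Authorization": f"Bearer {self.api_key}"},
+            json={"messages": messages},
+            timeout=30,
+        )
+        response.raise_for_status()
+        return response.json()["content"]  # response shape is an assumption
+
+    def supports_function_calling(self) -> bool:
+        return False
+
+    def supports_stop_words(self) -> bool:
+        return False
+
+    def get_context_window_size(self) -> int:
+        return 8192  # report your model's real context window here
+```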
diff --git a/src/crewai/__init__.py b/src/crewai/__init__.py
index 0d6b06961..98ad92ca3 100644
--- a/src/crewai/__init__.py
+++ b/src/crewai/__init__.py
@@ -4,7 +4,7 @@ from crewai.agent import Agent
 from crewai.crew import Crew
 from crewai.flow.flow import Flow
 from crewai.knowledge.knowledge import Knowledge
-from crewai.llm import LLM, BaseLLM
+from crewai.llm import LLM, BaseLLM, DefaultLLM
 from crewai.process import Process
 from crewai.task import Task
 
@@ -22,6 +22,7 @@ __all__ = [
     "Task",
     "LLM",
     "BaseLLM",
+    "DefaultLLM",
     "Flow",
     "Knowledge",
 ]
diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index 7146b73ae..fa64ccd6c 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -35,8 +35,8 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
 load_dotenv()
 
 
-class BaseLLM(ABC):
-    """Abstract base class for LLM implementations.
+class LLM(ABC):
+    """Base class for LLM implementations.
 
     This class defines the interface that all LLM implementations must follow.
     Users can extend this class to create custom LLM implementations that don't
@@ -52,8 +52,26 @@ class BaseLLM(ABC):
         This is used by the CrewAgentExecutor and other components.
     """
 
+    def __new__(cls, *args, **kwargs):
+        """Create a new LLM instance.
+
+        This method handles backward compatibility by creating a DefaultLLM instance
+        when the LLM class is instantiated directly with parameters.
+
+        Args:
+            *args: Positional arguments.
+            **kwargs: Keyword arguments.
+
+        Returns:
+            Either a new LLM instance or a DefaultLLM instance for backward compatibility.
+        """
+        if cls is LLM and (args or kwargs.get('model') is not None):
+            from crewai.llm import DefaultLLM
+            return DefaultLLM(*args, **kwargs)
+        return super().__new__(cls)
+
     def __init__(self):
-        """Initialize the BaseLLM with default attributes.
+        """Initialize the LLM with default attributes.
 
         This constructor sets default values for attributes that are expected
         by the CrewAgentExecutor and other components.
@@ -63,7 +81,91 @@
         """
         self.stop = []
 
-    @abstractmethod
+    @classmethod
+    def create(
+        cls,
+        model: str,
+        timeout: Optional[Union[float, int]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        n: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
+        response_format: Optional[Type[BaseModel]] = None,
+        seed: Optional[int] = None,
+        logprobs: Optional[int] = None,
+        top_logprobs: Optional[int] = None,
+        base_url: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        api_key: Optional[str] = None,
+        callbacks: List[Any] = [],
+        reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
+        **kwargs,
+    ) -> 'DefaultLLM':
+        """Create a default LLM instance using litellm.
+
+        This factory method creates a default LLM instance using litellm as the backend.
+        It's the recommended way to create LLM instances for most users.
+
+        Args:
+            model: The model name (e.g., "gpt-4").
+            timeout: Optional timeout for the LLM call.
+            temperature: Optional temperature for the LLM call.
+            top_p: Optional top_p for the LLM call.
+            n: Optional n for the LLM call.
+            stop: Optional stop sequences for the LLM call.
+            max_completion_tokens: Optional max_completion_tokens for the LLM call.
+            max_tokens: Optional max_tokens for the LLM call.
+            presence_penalty: Optional presence_penalty for the LLM call.
+            frequency_penalty: Optional frequency_penalty for the LLM call.
+            logit_bias: Optional logit_bias for the LLM call.
+            response_format: Optional response_format for the LLM call.
+            seed: Optional seed for the LLM call.
+            logprobs: Optional logprobs for the LLM call.
+            top_logprobs: Optional top_logprobs for the LLM call.
+            base_url: Optional base_url for the LLM call.
+            api_base: Optional api_base for the LLM call.
+            api_version: Optional api_version for the LLM call.
+            api_key: Optional api_key for the LLM call.
+            callbacks: Optional callbacks for the LLM call.
+            reasoning_effort: Optional reasoning_effort for the LLM call.
+            **kwargs: Additional keyword arguments for the LLM call.
+
+        Returns:
+            A DefaultLLM instance configured with the provided parameters.
+        """
+        from crewai.llm import DefaultLLM
+
+        return DefaultLLM(
+            model=model,
+            timeout=timeout,
+            temperature=temperature,
+            top_p=top_p,
+            n=n,
+            stop=stop,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            response_format=response_format,
+            seed=seed,
+            logprobs=logprobs,
+            top_logprobs=top_logprobs,
+            base_url=base_url,
+            api_base=api_base,
+            api_version=api_version,
+            api_key=api_key,
+            callbacks=callbacks,
+            reasoning_effort=reasoning_effort,
+            **kwargs,
+        )
+
     def call(
         self,
         messages: Union[str, List[Dict[str, str]]],
@@ -93,10 +195,10 @@
             ValueError: If the messages format is invalid.
             TimeoutError: If the LLM request times out.
             RuntimeError: If the LLM request fails for other reasons.
+            NotImplementedError: If this method is not implemented by a subclass.
         """
-        pass
+        raise NotImplementedError("Subclasses must implement call()")
 
-    @abstractmethod
     def supports_function_calling(self) -> bool:
         """Check if the LLM supports function calling.
 
@@ -107,10 +209,12 @@
 
         Returns:
             True if the LLM supports function calling, False otherwise.
+
+        Raises:
+            NotImplementedError: If this method is not implemented by a subclass.
         """
-        pass
+        raise NotImplementedError("Subclasses must implement supports_function_calling()")
 
-    @abstractmethod
     def supports_stop_words(self) -> bool:
         """Check if the LLM supports stop words.
 
@@ -120,10 +224,12 @@
 
         Returns:
             True if the LLM supports stop words, False otherwise.
+
+        Raises:
+            NotImplementedError: If this method is not implemented by a subclass.
         """
-        pass
+        raise NotImplementedError("Subclasses must implement supports_stop_words()")
 
-    @abstractmethod
     def get_context_window_size(self) -> int:
         """Get the context window size of the LLM.
 
@@ -133,8 +239,11 @@
 
         Returns:
             The context window size as an integer.
+
+        Raises:
+            NotImplementedError: If this method is not implemented by a subclass.
         """
-        pass
+        raise NotImplementedError("Subclasses must implement get_context_window_size()")
 
 
 class FilteredStream:
@@ -229,7 +338,14 @@ def suppress_warnings():
     sys.stderr = old_stderr
 
 
-class LLM(BaseLLM):
+class DefaultLLM(LLM):
+    """Default LLM implementation using litellm.
+
+    This class provides a concrete implementation of the LLM interface
+    using litellm as the backend. It's the default implementation used
+    by CrewAI when no custom LLM is provided.
+ """ + def __init__( self, model: str, @@ -255,6 +371,8 @@ class LLM(BaseLLM): reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None, **kwargs, ): + super().__init__() # Initialize the base class + self.model = model self.timeout = timeout self.temperature = temperature @@ -283,7 +401,7 @@ class LLM(BaseLLM): # Normalize self.stop to always be a List[str] if stop is None: - self.stop: List[str] = [] + self.stop = [] # Already initialized in base class elif isinstance(stop, str): self.stop = [stop] else: @@ -667,3 +785,27 @@ class LLM(BaseLLM): litellm.success_callback = success_callbacks litellm.failure_callback = failure_callbacks + + +class BaseLLM(LLM): + """Deprecated: Use LLM instead. + + This class is kept for backward compatibility and will be removed in a future release. + It inherits from LLM and provides the same interface, but emits a deprecation warning + when instantiated. + """ + + def __init__(self): + """Initialize the BaseLLM with a deprecation warning. + + This constructor emits a deprecation warning and then calls the parent class's + constructor to initialize the LLM. + """ + import warnings + warnings.warn( + "BaseLLM is deprecated and will be removed in a future release. " + "Use LLM instead for custom implementations.", + DeprecationWarning, + stacklevel=2 + ) + super().__init__() diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py index 8035b6593..63c5a9441 100644 --- a/src/crewai/utilities/llm_utils.py +++ b/src/crewai/utilities/llm_utils.py @@ -6,30 +6,30 @@ from crewai.llm import LLM, BaseLLM def create_llm( - llm_value: Union[str, BaseLLM, Any, None] = None, -) -> Optional[BaseLLM]: + llm_value: Union[str, LLM, Any, None] = None, +) -> Optional[LLM]: """ Creates or returns an LLM instance based on the given llm_value. Args: - llm_value (str | BaseLLM | Any | None): + llm_value (str | LLM | Any | None): - str: The model name (e.g., "gpt-4"). - - BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is. + - LLM: Already instantiated LLM, returned as-is. - Any: Attempt to extract known attributes like model_name, temperature, etc. - None: Use environment-based or fallback default model. Returns: - A BaseLLM instance if successful, or None if something fails. + A LLM instance if successful, or None if something fails. 
""" - # 1) If llm_value is already a BaseLLM object, return it directly - if isinstance(llm_value, BaseLLM): + # 1) If llm_value is already a LLM object, return it directly + if isinstance(llm_value, LLM): return llm_value # 2) If llm_value is a string (model name) if isinstance(llm_value, str): try: - created_llm = LLM(model=llm_value) + created_llm = LLM.create(model=llm_value) return created_llm except Exception as e: print(f"Failed to instantiate LLM with model='{llm_value}': {e}") @@ -56,7 +56,7 @@ def create_llm( base_url: Optional[str] = getattr(llm_value, "base_url", None) api_base: Optional[str] = getattr(llm_value, "api_base", None) - created_llm = LLM( + created_llm = LLM.create( model=model, temperature=temperature, max_tokens=max_tokens, @@ -175,7 +175,7 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]: # Try creating the LLM try: - new_llm = LLM(**llm_params) + new_llm = LLM.create(**llm_params) return new_llm except Exception as e: print( diff --git a/tests/custom_llm_test.py b/tests/custom_llm_test.py index c3b0de1c0..7cef215fa 100644 --- a/tests/custom_llm_test.py +++ b/tests/custom_llm_test.py @@ -2,14 +2,14 @@ from typing import Any, Dict, List, Optional, Union import pytest -from crewai.llm import BaseLLM +from crewai.llm import LLM from crewai.utilities.llm_utils import create_llm -class CustomLLM(BaseLLM): +class CustomLLM(LLM): """Custom LLM implementation for testing. - This is a simple implementation of the BaseLLM abstract base class + This is a simple implementation of the LLM abstract base class that returns a predefined response for testing purposes. """ @@ -93,7 +93,7 @@ def test_custom_llm_implementation(): assert response == "The answer is 42" -class JWTAuthLLM(BaseLLM): +class JWTAuthLLM(LLM): """Custom LLM implementation with JWT authentication.""" def __init__(self, jwt_token: str): @@ -164,7 +164,7 @@ def test_jwt_auth_llm_validation(): JWTAuthLLM(jwt_token=None) -class TimeoutHandlingLLM(BaseLLM): +class TimeoutHandlingLLM(LLM): """Custom LLM implementation with timeout handling and retry logic.""" def __init__(self, max_retries: int = 3, timeout: int = 30):