diff --git a/src/crewai/agent.py b/src/crewai/agent.py index fe1f829e9..f0ee25718 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -11,7 +11,7 @@ from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context -from crewai.llm import BaseLLM, LLM +from crewai.llm import LLM, BaseLLM from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.task import Task from crewai.tools import BaseTool diff --git a/src/crewai/cli/crew_chat.py b/src/crewai/cli/crew_chat.py index 34a0ae2ce..e730935f3 100644 --- a/src/crewai/cli/crew_chat.py +++ b/src/crewai/cli/crew_chat.py @@ -14,7 +14,7 @@ from packaging import version from crewai.cli.utils import read_toml from crewai.cli.version import get_crewai_version from crewai.crew import Crew -from crewai.llm import BaseLLM, LLM +from crewai.llm import LLM, BaseLLM from crewai.types.crew_chat import ChatInputField, ChatInputs from crewai.utilities.llm_utils import create_llm diff --git a/src/crewai/llm.py b/src/crewai/llm.py index a4b3e637e..7146b73ae 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -41,6 +41,15 @@ class BaseLLM(ABC): This class defines the interface that all LLM implementations must follow. Users can extend this class to create custom LLM implementations that don't rely on litellm's authentication mechanism. + + Custom LLM implementations should handle error cases gracefully, including + timeouts, authentication failures, and malformed responses. They should also + implement proper validation for input parameters and provide clear error + messages when things go wrong. + + Attributes: + stop (list): A list of stop sequences that the LLM should use to stop generation. + This is used by the CrewAgentExecutor and other components. """ def __init__(self): @@ -48,6 +57,9 @@ class BaseLLM(ABC): This constructor sets default values for attributes that are expected by the CrewAgentExecutor and other components. + + All custom LLM implementations should call super().__init__() to ensure + that these default attributes are properly initialized. """ self.stop = [] @@ -76,6 +88,11 @@ class BaseLLM(ABC): Returns: Either a text response from the LLM (str) or the result of a tool function call (Any). + + Raises: + ValueError: If the messages format is invalid. + TimeoutError: If the LLM request times out. + RuntimeError: If the LLM request fails for other reasons. """ pass @@ -83,6 +100,11 @@ class BaseLLM(ABC): def supports_function_calling(self) -> bool: """Check if the LLM supports function calling. + This method should return True if the LLM implementation supports + function calling (tools), and False otherwise. If this method returns + True, the LLM should be able to handle the 'tools' parameter in the + call() method. + Returns: True if the LLM supports function calling, False otherwise. """ @@ -92,6 +114,10 @@ class BaseLLM(ABC): def supports_stop_words(self) -> bool: """Check if the LLM supports stop words. + This method should return True if the LLM implementation supports + stop words, and False otherwise. If this method returns True, the + LLM should respect the 'stop' attribute when generating responses. + Returns: True if the LLM supports stop words, False otherwise. """ @@ -101,6 +127,10 @@ class BaseLLM(ABC): def get_context_window_size(self) -> int: """Get the context window size of the LLM. + This method should return the maximum number of tokens that the LLM + can process in a single request. This is used by CrewAI to ensure + that messages don't exceed the LLM's context window. + Returns: The context window size as an integer. """ diff --git a/tests/custom_llm_test.py b/tests/custom_llm_test.py index fcbdfc52a..b833a57a0 100644 --- a/tests/custom_llm_test.py +++ b/tests/custom_llm_test.py @@ -62,7 +62,12 @@ def test_custom_llm_implementation(): class JWTAuthLLM(BaseLLM): + """Custom LLM implementation with JWT authentication.""" + def __init__(self, jwt_token: str): + super().__init__() + if not jwt_token or not isinstance(jwt_token, str): + raise ValueError("Invalid JWT token") self.jwt_token = jwt_token self.calls = [] self.stop = [] @@ -74,6 +79,7 @@ class JWTAuthLLM(BaseLLM): callbacks: Optional[List[Any]] = None, available_functions: Optional[Dict[str, Any]] = None, ) -> Union[str, Any]: + """Record the call and return a predefined response.""" self.calls.append({ "messages": messages, "tools": tools, @@ -85,12 +91,15 @@ class JWTAuthLLM(BaseLLM): return "Response from JWT-authenticated LLM" def supports_function_calling(self) -> bool: + """Return True to indicate that function calling is supported.""" return True def supports_stop_words(self) -> bool: + """Return True to indicate that stop words are supported.""" return True def get_context_window_size(self) -> int: + """Return a default context window size.""" return 8192