diff --git a/docs/custom_llm.md b/docs/custom_llm.md
new file mode 100644
index 000000000..245aa3ca4
--- /dev/null
+++ b/docs/custom_llm.md
@@ -0,0 +1,225 @@
+# Custom LLM Implementations
+
+CrewAI supports custom LLM implementations through the `BaseLLM` abstract base class. This lets you plug in your own LLM integration, including ones that don't rely on litellm's authentication mechanism.
+
+## Using Custom LLM Implementations
+
+To create a custom LLM implementation, you need to:
+
+1. Inherit from the `BaseLLM` abstract base class
+2. Implement the required methods:
+   - `call()`: The main method to call the LLM with messages
+   - `supports_function_calling()`: Whether the LLM supports function calling
+   - `supports_stop_words()`: Whether the LLM supports stop words
+   - `get_context_window_size()`: The context window size of the LLM
+
+## Example: Basic Custom LLM
+
+```python
+from crewai import BaseLLM
+from typing import Any, Dict, List, Optional, Union
+
+class CustomLLM(BaseLLM):
+    def __init__(self, api_key: str, endpoint: str):
+        super().__init__()  # Initialize the base class to set default attributes
+        self.api_key = api_key
+        self.endpoint = endpoint
+        self.stop = []  # You can customize stop words if needed
+
+    def call(
+        self,
+        messages: Union[str, List[Dict[str, str]]],
+        tools: Optional[List[dict]] = None,
+        callbacks: Optional[List[Any]] = None,
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> Union[str, Any]:
+        # Implement your own logic to call the LLM.
+        # For example, using requests:
+        import requests
+
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json"
+        }
+
+        # Convert a string message to the proper format if needed
+        if isinstance(messages, str):
+            messages = [{"role": "user", "content": messages}]
+
+        data = {"messages": messages}
+        # Only send "tools" when provided; many APIs reject a null tools field
+        if tools:
+            data["tools"] = tools
+
+        response = requests.post(self.endpoint, headers=headers, json=data)
+        response.raise_for_status()  # Fail fast on HTTP errors
+        return response.json()["choices"][0]["message"]["content"]
+
+    def supports_function_calling(self) -> bool:
+        # Return True if your LLM supports function calling
+        return True
+
+    def supports_stop_words(self) -> bool:
+        # Return True if your LLM supports stop words
+        return True
+
+    def get_context_window_size(self) -> int:
+        # Return the context window size of your LLM
+        return 8192
+```
+
+## Example: JWT-based Authentication
+
+For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:
+
+```python
+from crewai import BaseLLM, Agent, Task
+from typing import Any, Dict, List, Optional, Union
+
+class JWTAuthLLM(BaseLLM):
+    def __init__(self, jwt_token: str, endpoint: str):
+        super().__init__()  # Initialize the base class to set default attributes
+        self.jwt_token = jwt_token
+        self.endpoint = endpoint
+        self.stop = []  # You can customize stop words if needed
+
+    def call(
+        self,
+        messages: Union[str, List[Dict[str, str]]],
+        tools: Optional[List[dict]] = None,
+        callbacks: Optional[List[Any]] = None,
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> Union[str, Any]:
+        # Implement your own logic to call the LLM with JWT authentication
+        import requests
+
+        headers = {
+            "Authorization": f"Bearer {self.jwt_token}",
+            "Content-Type": "application/json"
+        }
+
+        # Convert a string message to the proper format if needed
+        if isinstance(messages, str):
+            messages = [{"role": "user", "content": messages}]
+
+        data = {"messages": messages}
+        # Only send "tools" when provided; many APIs reject a null tools field
+        if tools:
+            data["tools"] = tools
+
+        response = requests.post(self.endpoint, headers=headers, json=data)
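+        # Fail fast on HTTP errors instead of failing later on a missing
+        # "choices" key (assumes an OpenAI-compatible response shape)
+        response.raise_for_status()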
return response.json()["choices"][0]["message"]["content"] + + def supports_function_calling(self) -> bool: + # Return True if your LLM supports function calling + return True + + def supports_stop_words(self) -> bool: + # Return True if your LLM supports stop words + return True + + def get_context_window_size(self) -> int: + # Return the context window size of your LLM + return 8192 +``` + +## Using Your Custom LLM with CrewAI + +Once you've implemented your custom LLM, you can use it with CrewAI agents and crews: + +```python +# Create your custom LLM instance +jwt_llm = JWTAuthLLM( + jwt_token="your.jwt.token", + endpoint="https://your-llm-endpoint.com/v1/chat/completions" +) + +# Use it with an agent +agent = Agent( + role="Research Assistant", + goal="Find information on a topic", + backstory="You are a research assistant tasked with finding information.", + llm=jwt_llm, +) + +# Create a task for the agent +task = Task( + description="Research the benefits of exercise", + agent=agent, + expected_output="A summary of the benefits of exercise", +) + +# Execute the task +result = agent.execute_task(task) +print(result) +``` + +## Advanced Usage: Function Calling + +If your LLM supports function calling, you can implement the function calling logic in your custom LLM: + +```python +def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, +) -> Union[str, Any]: + import requests + + headers = { + "Authorization": f"Bearer {self.jwt_token}", + "Content-Type": "application/json" + } + + # Convert string message to proper format if needed + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + data = { + "messages": messages, + "tools": tools + } + + response = requests.post(self.endpoint, headers=headers, json=data) + response_data = response.json() + + # Check if the LLM wants to call a function + if response_data["choices"][0]["message"].get("tool_calls"): + tool_calls = response_data["choices"][0]["message"]["tool_calls"] + + # Process each tool call + for tool_call in tool_calls: + function_name = tool_call["function"]["name"] + function_args = json.loads(tool_call["function"]["arguments"]) + + if available_functions and function_name in available_functions: + function_to_call = available_functions[function_name] + function_response = function_to_call(**function_args) + + # Add the function response to the messages + messages.append({ + "role": "tool", + "tool_call_id": tool_call["id"], + "name": function_name, + "content": str(function_response) + }) + + # Call the LLM again with the updated messages + return self.call(messages, tools, callbacks, available_functions) + + # Return the text response if no function call + return response_data["choices"][0]["message"]["content"] +``` + +## Implementing Your Own Authentication Mechanism + +The `BaseLLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use: + +- OAuth tokens +- Client certificates +- Custom headers +- Session-based authentication +- Any other authentication method required by your LLM provider + +Simply implement the appropriate authentication logic in your custom LLM class. 
diff --git a/src/crewai/__init__.py b/src/crewai/__init__.py index 662af2563..ca5b9ccb7 100644 --- a/src/crewai/__init__.py +++ b/src/crewai/__init__.py @@ -4,7 +4,7 @@ from crewai.agent import Agent from crewai.crew import Crew from crewai.flow.flow import Flow from crewai.knowledge.knowledge import Knowledge -from crewai.llm import LLM +from crewai.llm import BaseLLM, LLM from crewai.process import Process from crewai.task import Task @@ -21,6 +21,7 @@ __all__ = [ "Process", "Task", "LLM", + "BaseLLM", "Flow", "Knowledge", ] diff --git a/src/crewai/agent.py b/src/crewai/agent.py index cfebc18e5..fe1f829e9 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -11,7 +11,7 @@ from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context -from crewai.llm import LLM +from crewai.llm import BaseLLM, LLM from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.task import Task from crewai.tools import BaseTool @@ -70,10 +70,10 @@ class Agent(BaseAgent): default=True, description="Use system prompt for the agent.", ) - llm: Union[str, InstanceOf[LLM], Any] = Field( + llm: Union[str, InstanceOf[BaseLLM], Any] = Field( description="Language model that will run the agent.", default=None ) - function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field( + function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( description="Language model that will run the agent.", default=None ) system_template: Optional[str] = Field( @@ -117,7 +117,7 @@ class Agent(BaseAgent): self.agent_ops_agent_name = self.role self.llm = create_llm(self.llm) - if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM): + if self.function_calling_llm and not isinstance(self.function_calling_llm, BaseLLM): self.function_calling_llm = create_llm(self.function_calling_llm) if not self.agent_executor: diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 9cecfed3a..2e3308177 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -26,7 +26,7 @@ from crewai.agents.cache import CacheHandler from crewai.crews.crew_output import CrewOutput from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource -from crewai.llm import LLM +from crewai.llm import BaseLLM, LLM from crewai.memory.entity.entity_memory import EntityMemory from crewai.memory.long_term.long_term_memory import LongTermMemory from crewai.memory.short_term.short_term_memory import ShortTermMemory @@ -150,7 +150,7 @@ class Crew(BaseModel): default=None, description="Metrics for the LLM usage during all tasks execution.", ) - manager_llm: Optional[Any] = Field( + manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( description="Language model that will run the agent.", default=None ) manager_agent: Optional[BaseAgent] = Field( @@ -196,7 +196,7 @@ class Crew(BaseModel): default=False, description="Plan the crew execution and add the plan to the crew.", ) - planning_llm: Optional[Any] = Field( + planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( default=None, description="Language model that will run the AgentPlanner if planning is True.", ) @@ -212,7 +212,7 @@ class Crew(BaseModel): default=None, description="Knowledge sources for the crew. 
Add knowledge sources to the knowledge object.", ) - chat_llm: Optional[Any] = Field( + chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( default=None, description="LLM used to handle chatting with the crew.", ) @@ -1193,7 +1193,7 @@ class Crew(BaseModel): def test( self, n_iterations: int, - eval_llm: Union[str, InstanceOf[LLM]], + eval_llm: Union[str, InstanceOf[BaseLLM]], inputs: Optional[Dict[str, Any]] = None, ) -> None: """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures.""" diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 0c8a46214..a4b3e637e 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -4,6 +4,7 @@ import os import sys import threading import warnings +from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Dict, List, Literal, Optional, Type, Union, cast @@ -34,6 +35,78 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import ( load_dotenv() +class BaseLLM(ABC): + """Abstract base class for LLM implementations. + + This class defines the interface that all LLM implementations must follow. + Users can extend this class to create custom LLM implementations that don't + rely on litellm's authentication mechanism. + """ + + def __init__(self): + """Initialize the BaseLLM with default attributes. + + This constructor sets default values for attributes that are expected + by the CrewAgentExecutor and other components. + """ + self.stop = [] + + @abstractmethod + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Call the LLM with the given messages. + + Args: + messages: Input messages for the LLM. + Can be a string or list of message dictionaries. + If string, it will be converted to a single user message. + If list, each dict must have 'role' and 'content' keys. + tools: Optional list of tool schemas for function calling. + Each tool should define its name, description, and parameters. + callbacks: Optional list of callback functions to be executed + during and after the LLM call. + available_functions: Optional dict mapping function names to callables + that can be invoked by the LLM. + + Returns: + Either a text response from the LLM (str) or + the result of a tool function call (Any). + """ + pass + + @abstractmethod + def supports_function_calling(self) -> bool: + """Check if the LLM supports function calling. + + Returns: + True if the LLM supports function calling, False otherwise. + """ + pass + + @abstractmethod + def supports_stop_words(self) -> bool: + """Check if the LLM supports stop words. + + Returns: + True if the LLM supports stop words, False otherwise. + """ + pass + + @abstractmethod + def get_context_window_size(self) -> int: + """Get the context window size of the LLM. + + Returns: + The context window size as an integer. 
+ """ + pass + + class FilteredStream: def __init__(self, original_stream): self._original_stream = original_stream @@ -126,7 +199,7 @@ def suppress_warnings(): sys.stderr = old_stderr -class LLM: +class LLM(BaseLLM): def __init__( self, model: str, diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py index 4d34d789c..3271e8bcb 100644 --- a/src/crewai/utilities/llm_utils.py +++ b/src/crewai/utilities/llm_utils.py @@ -2,28 +2,28 @@ import os from typing import Any, Dict, List, Optional, Union from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS -from crewai.llm import LLM +from crewai.llm import BaseLLM, LLM def create_llm( - llm_value: Union[str, LLM, Any, None] = None, -) -> Optional[LLM]: + llm_value: Union[str, BaseLLM, Any, None] = None, +) -> Optional[BaseLLM]: """ Creates or returns an LLM instance based on the given llm_value. Args: - llm_value (str | LLM | Any | None): + llm_value (str | BaseLLM | Any | None): - str: The model name (e.g., "gpt-4"). - - LLM: Already instantiated LLM, returned as-is. + - BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is. - Any: Attempt to extract known attributes like model_name, temperature, etc. - None: Use environment-based or fallback default model. Returns: - An LLM instance if successful, or None if something fails. + A BaseLLM instance if successful, or None if something fails. """ - # 1) If llm_value is already an LLM object, return it directly - if isinstance(llm_value, LLM): + # 1) If llm_value is already a BaseLLM object, return it directly + if isinstance(llm_value, BaseLLM): return llm_value # 2) If llm_value is a string (model name) diff --git a/tests/custom_llm_test.py b/tests/custom_llm_test.py new file mode 100644 index 000000000..afedafbd5 --- /dev/null +++ b/tests/custom_llm_test.py @@ -0,0 +1,111 @@ +import pytest +from typing import Any, Dict, List, Optional, Union + +from crewai.llm import BaseLLM +from crewai.utilities.llm_utils import create_llm + + +class CustomLLM(BaseLLM): + """Custom LLM implementation for testing.""" + + def __init__(self, response: str = "Custom LLM response"): + self.response = response + self.calls = [] + self.stop = [] + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Record the call and return the predefined response.""" + self.calls.append({ + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions + }) + return self.response + + def supports_function_calling(self) -> bool: + """Return True to indicate that function calling is supported.""" + return True + + def supports_stop_words(self) -> bool: + """Return True to indicate that stop words are supported.""" + return True + + def get_context_window_size(self) -> int: + """Return a default context window size.""" + return 8192 + + +def test_custom_llm_implementation(): + """Test that a custom LLM implementation works with create_llm.""" + custom_llm = CustomLLM(response="The answer is 42") + + # Test that create_llm returns the custom LLM instance directly + result_llm = create_llm(custom_llm) + + assert result_llm is custom_llm + + # Test calling the custom LLM + response = result_llm.call("What is the answer to life, the universe, and everything?") + + # Verify that the custom LLM was called + assert len(custom_llm.calls) > 0 + # Verify that 
the response from the custom LLM was used + assert response == "The answer is 42" + + +class JWTAuthLLM(BaseLLM): + def __init__(self, jwt_token: str): + self.jwt_token = jwt_token + self.calls = [] + self.stop = [] + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + self.calls.append({ + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions + }) + # In a real implementation, this would use the JWT token to authenticate + # with an external service + return "Response from JWT-authenticated LLM" + + def supports_function_calling(self) -> bool: + return True + + def supports_stop_words(self) -> bool: + return True + + def get_context_window_size(self) -> int: + return 8192 + + +def test_custom_llm_with_jwt_auth(): + """Test a custom LLM implementation with JWT authentication.""" + jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token") + + # Test that create_llm returns the JWT-authenticated LLM instance directly + result_llm = create_llm(jwt_llm) + + assert result_llm is jwt_llm + + # Test calling the JWT-authenticated LLM + response = result_llm.call("Test message") + + # Verify that the JWT-authenticated LLM was called + assert len(jwt_llm.calls) > 0 + # Verify that the response from the JWT-authenticated LLM was used + assert response == "Response from JWT-authenticated LLM"
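+
+
+def test_incomplete_custom_llm_cannot_be_instantiated():
+    """A subclass of BaseLLM that omits any of the abstract methods
+    (call, supports_function_calling, supports_stop_words,
+    get_context_window_size) cannot be instantiated; Python's ABC
+    machinery raises TypeError."""
+
+    class IncompleteLLM(BaseLLM):
+        pass
+
+    with pytest.raises(TypeError):
+        IncompleteLLM()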