Add support for custom LLM implementations

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-03-04 17:09:17 +00:00
parent 00eede0d5d
commit ec8e705bbc
7 changed files with 429 additions and 19 deletions

225
docs/custom_llm.md Normal file
View File

@@ -0,0 +1,225 @@
# Custom LLM Implementations
CrewAI now supports custom LLM implementations through the `BaseLLM` abstract base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.
## Using Custom LLM Implementations
To create a custom LLM implementation, you need to:
1. Inherit from the `BaseLLM` abstract base class
2. Implement the required methods:
- `call()`: The main method to call the LLM with messages
- `supports_function_calling()`: Whether the LLM supports function calling
- `supports_stop_words()`: Whether the LLM supports stop words
- `get_context_window_size()`: The context window size of the LLM
## Example: Basic Custom LLM
```python
from crewai import BaseLLM
from typing import Any, Dict, List, Optional, Union
class CustomLLM(BaseLLM):
def __init__(self, api_key: str, endpoint: str):
super().__init__() # Initialize the base class to set default attributes
self.api_key = api_key
self.endpoint = endpoint
self.stop = [] # You can customize stop words if needed
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
# Implement your own logic to call the LLM
# For example, using requests:
import requests
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# Convert string message to proper format if needed
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
data = {
"messages": messages,
"tools": tools
}
response = requests.post(self.endpoint, headers=headers, json=data)
return response.json()["choices"][0]["message"]["content"]
def supports_function_calling(self) -> bool:
# Return True if your LLM supports function calling
return True
def supports_stop_words(self) -> bool:
# Return True if your LLM supports stop words
return True
def get_context_window_size(self) -> int:
# Return the context window size of your LLM
return 8192
```
## Example: JWT-based Authentication
For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:
```python
from crewai import BaseLLM, Agent, Task
from typing import Any, Dict, List, Optional, Union
class JWTAuthLLM(BaseLLM):
def __init__(self, jwt_token: str, endpoint: str):
super().__init__() # Initialize the base class to set default attributes
self.jwt_token = jwt_token
self.endpoint = endpoint
self.stop = [] # You can customize stop words if needed
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
# Implement your own logic to call the LLM with JWT authentication
import requests
headers = {
"Authorization": f"Bearer {self.jwt_token}",
"Content-Type": "application/json"
}
# Convert string message to proper format if needed
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
data = {
"messages": messages,
"tools": tools
}
response = requests.post(self.endpoint, headers=headers, json=data)
return response.json()["choices"][0]["message"]["content"]
def supports_function_calling(self) -> bool:
# Return True if your LLM supports function calling
return True
def supports_stop_words(self) -> bool:
# Return True if your LLM supports stop words
return True
def get_context_window_size(self) -> int:
# Return the context window size of your LLM
return 8192
```
## Using Your Custom LLM with CrewAI
Once you've implemented your custom LLM, you can use it with CrewAI agents and crews:
```python
# Create your custom LLM instance
jwt_llm = JWTAuthLLM(
jwt_token="your.jwt.token",
endpoint="https://your-llm-endpoint.com/v1/chat/completions"
)
# Use it with an agent
agent = Agent(
role="Research Assistant",
goal="Find information on a topic",
backstory="You are a research assistant tasked with finding information.",
llm=jwt_llm,
)
# Create a task for the agent
task = Task(
description="Research the benefits of exercise",
agent=agent,
expected_output="A summary of the benefits of exercise",
)
# Execute the task
result = agent.execute_task(task)
print(result)
```
## Advanced Usage: Function Calling
If your LLM supports function calling, you can implement the function calling logic in your custom LLM:
```python
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
import requests
headers = {
"Authorization": f"Bearer {self.jwt_token}",
"Content-Type": "application/json"
}
# Convert string message to proper format if needed
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
data = {
"messages": messages,
"tools": tools
}
response = requests.post(self.endpoint, headers=headers, json=data)
response_data = response.json()
# Check if the LLM wants to call a function
if response_data["choices"][0]["message"].get("tool_calls"):
tool_calls = response_data["choices"][0]["message"]["tool_calls"]
# Process each tool call
for tool_call in tool_calls:
function_name = tool_call["function"]["name"]
function_args = json.loads(tool_call["function"]["arguments"])
if available_functions and function_name in available_functions:
function_to_call = available_functions[function_name]
function_response = function_to_call(**function_args)
# Add the function response to the messages
messages.append({
"role": "tool",
"tool_call_id": tool_call["id"],
"name": function_name,
"content": str(function_response)
})
# Call the LLM again with the updated messages
return self.call(messages, tools, callbacks, available_functions)
# Return the text response if no function call
return response_data["choices"][0]["message"]["content"]
```
## Implementing Your Own Authentication Mechanism
The `BaseLLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:
- OAuth tokens
- Client certificates
- Custom headers
- Session-based authentication
- Any other authentication method required by your LLM provider
Simply implement the appropriate authentication logic in your custom LLM class.

View File

@@ -4,7 +4,7 @@ from crewai.agent import Agent
from crewai.crew import Crew
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.llm import BaseLLM, LLM
from crewai.process import Process
from crewai.task import Task
@@ -21,6 +21,7 @@ __all__ = [
"Process",
"Task",
"LLM",
"BaseLLM",
"Flow",
"Knowledge",
]

View File

@@ -11,7 +11,7 @@ from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.llm import LLM
from crewai.llm import BaseLLM, LLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
from crewai.task import Task
from crewai.tools import BaseTool
@@ -70,10 +70,10 @@ class Agent(BaseAgent):
default=True,
description="Use system prompt for the agent.",
)
llm: Union[str, InstanceOf[LLM], Any] = Field(
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
description="Language model that will run the agent.", default=None
)
system_template: Optional[str] = Field(
@@ -117,7 +117,7 @@ class Agent(BaseAgent):
self.agent_ops_agent_name = self.role
self.llm = create_llm(self.llm)
if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
if self.function_calling_llm and not isinstance(self.function_calling_llm, BaseLLM):
self.function_calling_llm = create_llm(self.function_calling_llm)
if not self.agent_executor:

View File

@@ -26,7 +26,7 @@ from crewai.agents.cache import CacheHandler
from crewai.crews.crew_output import CrewOutput
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM
from crewai.llm import BaseLLM, LLM
from crewai.memory.entity.entity_memory import EntityMemory
from crewai.memory.long_term.long_term_memory import LongTermMemory
from crewai.memory.short_term.short_term_memory import ShortTermMemory
@@ -150,7 +150,7 @@ class Crew(BaseModel):
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
manager_llm: Optional[Any] = Field(
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
description="Language model that will run the agent.", default=None
)
manager_agent: Optional[BaseAgent] = Field(
@@ -196,7 +196,7 @@ class Crew(BaseModel):
default=False,
description="Plan the crew execution and add the plan to the crew.",
)
planning_llm: Optional[Any] = Field(
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
default=None,
description="Language model that will run the AgentPlanner if planning is True.",
)
@@ -212,7 +212,7 @@ class Crew(BaseModel):
default=None,
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
)
chat_llm: Optional[Any] = Field(
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
default=None,
description="LLM used to handle chatting with the crew.",
)
@@ -1193,7 +1193,7 @@ class Crew(BaseModel):
def test(
self,
n_iterations: int,
eval_llm: Union[str, InstanceOf[LLM]],
eval_llm: Union[str, InstanceOf[BaseLLM]],
inputs: Optional[Dict[str, Any]] = None,
) -> None:
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""

View File

@@ -4,6 +4,7 @@ import os
import sys
import threading
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
@@ -34,6 +35,78 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
load_dotenv()
class BaseLLM(ABC):
"""Abstract base class for LLM implementations.
This class defines the interface that all LLM implementations must follow.
Users can extend this class to create custom LLM implementations that don't
rely on litellm's authentication mechanism.
"""
def __init__(self):
"""Initialize the BaseLLM with default attributes.
This constructor sets default values for attributes that are expected
by the CrewAgentExecutor and other components.
"""
self.stop = []
@abstractmethod
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Call the LLM with the given messages.
Args:
messages: Input messages for the LLM.
Can be a string or list of message dictionaries.
If string, it will be converted to a single user message.
If list, each dict must have 'role' and 'content' keys.
tools: Optional list of tool schemas for function calling.
Each tool should define its name, description, and parameters.
callbacks: Optional list of callback functions to be executed
during and after the LLM call.
available_functions: Optional dict mapping function names to callables
that can be invoked by the LLM.
Returns:
Either a text response from the LLM (str) or
the result of a tool function call (Any).
"""
pass
@abstractmethod
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
Returns:
True if the LLM supports function calling, False otherwise.
"""
pass
@abstractmethod
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
Returns:
True if the LLM supports stop words, False otherwise.
"""
pass
@abstractmethod
def get_context_window_size(self) -> int:
"""Get the context window size of the LLM.
Returns:
The context window size as an integer.
"""
pass
class FilteredStream:
def __init__(self, original_stream):
self._original_stream = original_stream
@@ -126,7 +199,7 @@ def suppress_warnings():
sys.stderr = old_stderr
class LLM:
class LLM(BaseLLM):
def __init__(
self,
model: str,

View File

@@ -2,28 +2,28 @@ import os
from typing import Any, Dict, List, Optional, Union
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
from crewai.llm import LLM
from crewai.llm import BaseLLM, LLM
def create_llm(
llm_value: Union[str, LLM, Any, None] = None,
) -> Optional[LLM]:
llm_value: Union[str, BaseLLM, Any, None] = None,
) -> Optional[BaseLLM]:
"""
Creates or returns an LLM instance based on the given llm_value.
Args:
llm_value (str | LLM | Any | None):
llm_value (str | BaseLLM | Any | None):
- str: The model name (e.g., "gpt-4").
- LLM: Already instantiated LLM, returned as-is.
- BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
- Any: Attempt to extract known attributes like model_name, temperature, etc.
- None: Use environment-based or fallback default model.
Returns:
An LLM instance if successful, or None if something fails.
A BaseLLM instance if successful, or None if something fails.
"""
# 1) If llm_value is already an LLM object, return it directly
if isinstance(llm_value, LLM):
# 1) If llm_value is already a BaseLLM object, return it directly
if isinstance(llm_value, BaseLLM):
return llm_value
# 2) If llm_value is a string (model name)

111
tests/custom_llm_test.py Normal file
View File

@@ -0,0 +1,111 @@
import pytest
from typing import Any, Dict, List, Optional, Union
from crewai.llm import BaseLLM
from crewai.utilities.llm_utils import create_llm
class CustomLLM(BaseLLM):
"""Custom LLM implementation for testing."""
def __init__(self, response: str = "Custom LLM response"):
self.response = response
self.calls = []
self.stop = []
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Record the call and return the predefined response."""
self.calls.append({
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions
})
return self.response
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported."""
return True
def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported."""
return True
def get_context_window_size(self) -> int:
"""Return a default context window size."""
return 8192
def test_custom_llm_implementation():
"""Test that a custom LLM implementation works with create_llm."""
custom_llm = CustomLLM(response="The answer is 42")
# Test that create_llm returns the custom LLM instance directly
result_llm = create_llm(custom_llm)
assert result_llm is custom_llm
# Test calling the custom LLM
response = result_llm.call("What is the answer to life, the universe, and everything?")
# Verify that the custom LLM was called
assert len(custom_llm.calls) > 0
# Verify that the response from the custom LLM was used
assert response == "The answer is 42"
class JWTAuthLLM(BaseLLM):
def __init__(self, jwt_token: str):
self.jwt_token = jwt_token
self.calls = []
self.stop = []
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
self.calls.append({
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions
})
# In a real implementation, this would use the JWT token to authenticate
# with an external service
return "Response from JWT-authenticated LLM"
def supports_function_calling(self) -> bool:
return True
def supports_stop_words(self) -> bool:
return True
def get_context_window_size(self) -> int:
return 8192
def test_custom_llm_with_jwt_auth():
"""Test a custom LLM implementation with JWT authentication."""
jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")
# Test that create_llm returns the JWT-authenticated LLM instance directly
result_llm = create_llm(jwt_llm)
assert result_llm is jwt_llm
# Test calling the JWT-authenticated LLM
response = result_llm.call("Test message")
# Verify that the JWT-authenticated LLM was called
assert len(jwt_llm.calls) > 0
# Verify that the response from the JWT-authenticated LLM was used
assert response == "Response from JWT-authenticated LLM"