mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Add support for custom LLM implementations
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
225
docs/custom_llm.md
Normal file
225
docs/custom_llm.md
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
# Custom LLM Implementations
|
||||||
|
|
||||||
|
CrewAI now supports custom LLM implementations through the `BaseLLM` abstract base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.
|
||||||
|
|
||||||
|
## Using Custom LLM Implementations
|
||||||
|
|
||||||
|
To create a custom LLM implementation, you need to:
|
||||||
|
|
||||||
|
1. Inherit from the `BaseLLM` abstract base class
|
||||||
|
2. Implement the required methods:
|
||||||
|
- `call()`: The main method to call the LLM with messages
|
||||||
|
- `supports_function_calling()`: Whether the LLM supports function calling
|
||||||
|
- `supports_stop_words()`: Whether the LLM supports stop words
|
||||||
|
- `get_context_window_size()`: The context window size of the LLM
|
||||||
|
|
||||||
|
## Example: Basic Custom LLM
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import BaseLLM
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
class CustomLLM(BaseLLM):
|
||||||
|
def __init__(self, api_key: str, endpoint: str):
|
||||||
|
super().__init__() # Initialize the base class to set default attributes
|
||||||
|
self.api_key = api_key
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.stop = [] # You can customize stop words if needed
|
||||||
|
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
messages: Union[str, List[Dict[str, str]]],
|
||||||
|
tools: Optional[List[dict]] = None,
|
||||||
|
callbacks: Optional[List[Any]] = None,
|
||||||
|
available_functions: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Union[str, Any]:
|
||||||
|
# Implement your own logic to call the LLM
|
||||||
|
# For example, using requests:
|
||||||
|
import requests
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Convert string message to proper format if needed
|
||||||
|
if isinstance(messages, str):
|
||||||
|
messages = [{"role": "user", "content": messages}]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(self.endpoint, headers=headers, json=data)
|
||||||
|
return response.json()["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
def supports_function_calling(self) -> bool:
|
||||||
|
# Return True if your LLM supports function calling
|
||||||
|
return True
|
||||||
|
|
||||||
|
def supports_stop_words(self) -> bool:
|
||||||
|
# Return True if your LLM supports stop words
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_context_window_size(self) -> int:
|
||||||
|
# Return the context window size of your LLM
|
||||||
|
return 8192
|
||||||
|
```
|
||||||
|
|
||||||
|
## Example: JWT-based Authentication
|
||||||
|
|
||||||
|
For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from crewai import BaseLLM, Agent, Task
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
class JWTAuthLLM(BaseLLM):
|
||||||
|
def __init__(self, jwt_token: str, endpoint: str):
|
||||||
|
super().__init__() # Initialize the base class to set default attributes
|
||||||
|
self.jwt_token = jwt_token
|
||||||
|
self.endpoint = endpoint
|
||||||
|
self.stop = [] # You can customize stop words if needed
|
||||||
|
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
messages: Union[str, List[Dict[str, str]]],
|
||||||
|
tools: Optional[List[dict]] = None,
|
||||||
|
callbacks: Optional[List[Any]] = None,
|
||||||
|
available_functions: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Union[str, Any]:
|
||||||
|
# Implement your own logic to call the LLM with JWT authentication
|
||||||
|
import requests
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.jwt_token}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Convert string message to proper format if needed
|
||||||
|
if isinstance(messages, str):
|
||||||
|
messages = [{"role": "user", "content": messages}]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(self.endpoint, headers=headers, json=data)
|
||||||
|
return response.json()["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
def supports_function_calling(self) -> bool:
|
||||||
|
# Return True if your LLM supports function calling
|
||||||
|
return True
|
||||||
|
|
||||||
|
def supports_stop_words(self) -> bool:
|
||||||
|
# Return True if your LLM supports stop words
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_context_window_size(self) -> int:
|
||||||
|
# Return the context window size of your LLM
|
||||||
|
return 8192
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using Your Custom LLM with CrewAI
|
||||||
|
|
||||||
|
Once you've implemented your custom LLM, you can use it with CrewAI agents and crews:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Create your custom LLM instance
|
||||||
|
jwt_llm = JWTAuthLLM(
|
||||||
|
jwt_token="your.jwt.token",
|
||||||
|
endpoint="https://your-llm-endpoint.com/v1/chat/completions"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use it with an agent
|
||||||
|
agent = Agent(
|
||||||
|
role="Research Assistant",
|
||||||
|
goal="Find information on a topic",
|
||||||
|
backstory="You are a research assistant tasked with finding information.",
|
||||||
|
llm=jwt_llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a task for the agent
|
||||||
|
task = Task(
|
||||||
|
description="Research the benefits of exercise",
|
||||||
|
agent=agent,
|
||||||
|
expected_output="A summary of the benefits of exercise",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute the task
|
||||||
|
result = agent.execute_task(task)
|
||||||
|
print(result)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Advanced Usage: Function Calling
|
||||||
|
|
||||||
|
If your LLM supports function calling, you can implement the function calling logic in your custom LLM:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
messages: Union[str, List[Dict[str, str]]],
|
||||||
|
tools: Optional[List[dict]] = None,
|
||||||
|
callbacks: Optional[List[Any]] = None,
|
||||||
|
available_functions: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Union[str, Any]:
|
||||||
|
import requests
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.jwt_token}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Convert string message to proper format if needed
|
||||||
|
if isinstance(messages, str):
|
||||||
|
messages = [{"role": "user", "content": messages}]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(self.endpoint, headers=headers, json=data)
|
||||||
|
response_data = response.json()
|
||||||
|
|
||||||
|
# Check if the LLM wants to call a function
|
||||||
|
if response_data["choices"][0]["message"].get("tool_calls"):
|
||||||
|
tool_calls = response_data["choices"][0]["message"]["tool_calls"]
|
||||||
|
|
||||||
|
# Process each tool call
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
function_name = tool_call["function"]["name"]
|
||||||
|
function_args = json.loads(tool_call["function"]["arguments"])
|
||||||
|
|
||||||
|
if available_functions and function_name in available_functions:
|
||||||
|
function_to_call = available_functions[function_name]
|
||||||
|
function_response = function_to_call(**function_args)
|
||||||
|
|
||||||
|
# Add the function response to the messages
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_call["id"],
|
||||||
|
"name": function_name,
|
||||||
|
"content": str(function_response)
|
||||||
|
})
|
||||||
|
|
||||||
|
# Call the LLM again with the updated messages
|
||||||
|
return self.call(messages, tools, callbacks, available_functions)
|
||||||
|
|
||||||
|
# Return the text response if no function call
|
||||||
|
return response_data["choices"][0]["message"]["content"]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Implementing Your Own Authentication Mechanism
|
||||||
|
|
||||||
|
The `BaseLLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:
|
||||||
|
|
||||||
|
- OAuth tokens
|
||||||
|
- Client certificates
|
||||||
|
- Custom headers
|
||||||
|
- Session-based authentication
|
||||||
|
- Any other authentication method required by your LLM provider
|
||||||
|
|
||||||
|
Simply implement the appropriate authentication logic in your custom LLM class.
|
||||||
@@ -4,7 +4,7 @@ from crewai.agent import Agent
|
|||||||
from crewai.crew import Crew
|
from crewai.crew import Crew
|
||||||
from crewai.flow.flow import Flow
|
from crewai.flow.flow import Flow
|
||||||
from crewai.knowledge.knowledge import Knowledge
|
from crewai.knowledge.knowledge import Knowledge
|
||||||
from crewai.llm import LLM
|
from crewai.llm import BaseLLM, LLM
|
||||||
from crewai.process import Process
|
from crewai.process import Process
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
|
|
||||||
@@ -21,6 +21,7 @@ __all__ = [
|
|||||||
"Process",
|
"Process",
|
||||||
"Task",
|
"Task",
|
||||||
"LLM",
|
"LLM",
|
||||||
|
"BaseLLM",
|
||||||
"Flow",
|
"Flow",
|
||||||
"Knowledge",
|
"Knowledge",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
|||||||
from crewai.knowledge.knowledge import Knowledge
|
from crewai.knowledge.knowledge import Knowledge
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||||
from crewai.llm import LLM
|
from crewai.llm import BaseLLM, LLM
|
||||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
@@ -70,10 +70,10 @@ class Agent(BaseAgent):
|
|||||||
default=True,
|
default=True,
|
||||||
description="Use system prompt for the agent.",
|
description="Use system prompt for the agent.",
|
||||||
)
|
)
|
||||||
llm: Union[str, InstanceOf[LLM], Any] = Field(
|
llm: Union[str, InstanceOf[BaseLLM], Any] = Field(
|
||||||
description="Language model that will run the agent.", default=None
|
description="Language model that will run the agent.", default=None
|
||||||
)
|
)
|
||||||
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||||
description="Language model that will run the agent.", default=None
|
description="Language model that will run the agent.", default=None
|
||||||
)
|
)
|
||||||
system_template: Optional[str] = Field(
|
system_template: Optional[str] = Field(
|
||||||
@@ -117,7 +117,7 @@ class Agent(BaseAgent):
|
|||||||
self.agent_ops_agent_name = self.role
|
self.agent_ops_agent_name = self.role
|
||||||
|
|
||||||
self.llm = create_llm(self.llm)
|
self.llm = create_llm(self.llm)
|
||||||
if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
|
if self.function_calling_llm and not isinstance(self.function_calling_llm, BaseLLM):
|
||||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||||
|
|
||||||
if not self.agent_executor:
|
if not self.agent_executor:
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from crewai.agents.cache import CacheHandler
|
|||||||
from crewai.crews.crew_output import CrewOutput
|
from crewai.crews.crew_output import CrewOutput
|
||||||
from crewai.knowledge.knowledge import Knowledge
|
from crewai.knowledge.knowledge import Knowledge
|
||||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
from crewai.llm import LLM
|
from crewai.llm import BaseLLM, LLM
|
||||||
from crewai.memory.entity.entity_memory import EntityMemory
|
from crewai.memory.entity.entity_memory import EntityMemory
|
||||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||||
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
from crewai.memory.short_term.short_term_memory import ShortTermMemory
|
||||||
@@ -150,7 +150,7 @@ class Crew(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="Metrics for the LLM usage during all tasks execution.",
|
description="Metrics for the LLM usage during all tasks execution.",
|
||||||
)
|
)
|
||||||
manager_llm: Optional[Any] = Field(
|
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||||
description="Language model that will run the agent.", default=None
|
description="Language model that will run the agent.", default=None
|
||||||
)
|
)
|
||||||
manager_agent: Optional[BaseAgent] = Field(
|
manager_agent: Optional[BaseAgent] = Field(
|
||||||
@@ -196,7 +196,7 @@ class Crew(BaseModel):
|
|||||||
default=False,
|
default=False,
|
||||||
description="Plan the crew execution and add the plan to the crew.",
|
description="Plan the crew execution and add the plan to the crew.",
|
||||||
)
|
)
|
||||||
planning_llm: Optional[Any] = Field(
|
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Language model that will run the AgentPlanner if planning is True.",
|
description="Language model that will run the AgentPlanner if planning is True.",
|
||||||
)
|
)
|
||||||
@@ -212,7 +212,7 @@ class Crew(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
|
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
|
||||||
)
|
)
|
||||||
chat_llm: Optional[Any] = Field(
|
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="LLM used to handle chatting with the crew.",
|
description="LLM used to handle chatting with the crew.",
|
||||||
)
|
)
|
||||||
@@ -1193,7 +1193,7 @@ class Crew(BaseModel):
|
|||||||
def test(
|
def test(
|
||||||
self,
|
self,
|
||||||
n_iterations: int,
|
n_iterations: int,
|
||||||
eval_llm: Union[str, InstanceOf[LLM]],
|
eval_llm: Union[str, InstanceOf[BaseLLM]],
|
||||||
inputs: Optional[Dict[str, Any]] = None,
|
inputs: Optional[Dict[str, Any]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
|
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
|
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
|
||||||
|
|
||||||
@@ -34,6 +35,78 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseLLM(ABC):
|
||||||
|
"""Abstract base class for LLM implementations.
|
||||||
|
|
||||||
|
This class defines the interface that all LLM implementations must follow.
|
||||||
|
Users can extend this class to create custom LLM implementations that don't
|
||||||
|
rely on litellm's authentication mechanism.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the BaseLLM with default attributes.
|
||||||
|
|
||||||
|
This constructor sets default values for attributes that are expected
|
||||||
|
by the CrewAgentExecutor and other components.
|
||||||
|
"""
|
||||||
|
self.stop = []
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
messages: Union[str, List[Dict[str, str]]],
|
||||||
|
tools: Optional[List[dict]] = None,
|
||||||
|
callbacks: Optional[List[Any]] = None,
|
||||||
|
available_functions: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Union[str, Any]:
|
||||||
|
"""Call the LLM with the given messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Input messages for the LLM.
|
||||||
|
Can be a string or list of message dictionaries.
|
||||||
|
If string, it will be converted to a single user message.
|
||||||
|
If list, each dict must have 'role' and 'content' keys.
|
||||||
|
tools: Optional list of tool schemas for function calling.
|
||||||
|
Each tool should define its name, description, and parameters.
|
||||||
|
callbacks: Optional list of callback functions to be executed
|
||||||
|
during and after the LLM call.
|
||||||
|
available_functions: Optional dict mapping function names to callables
|
||||||
|
that can be invoked by the LLM.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either a text response from the LLM (str) or
|
||||||
|
the result of a tool function call (Any).
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def supports_function_calling(self) -> bool:
|
||||||
|
"""Check if the LLM supports function calling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the LLM supports function calling, False otherwise.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def supports_stop_words(self) -> bool:
|
||||||
|
"""Check if the LLM supports stop words.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the LLM supports stop words, False otherwise.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_context_window_size(self) -> int:
|
||||||
|
"""Get the context window size of the LLM.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The context window size as an integer.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class FilteredStream:
|
class FilteredStream:
|
||||||
def __init__(self, original_stream):
|
def __init__(self, original_stream):
|
||||||
self._original_stream = original_stream
|
self._original_stream = original_stream
|
||||||
@@ -126,7 +199,7 @@ def suppress_warnings():
|
|||||||
sys.stderr = old_stderr
|
sys.stderr = old_stderr
|
||||||
|
|
||||||
|
|
||||||
class LLM:
|
class LLM(BaseLLM):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
|||||||
@@ -2,28 +2,28 @@ import os
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
|
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
|
||||||
from crewai.llm import LLM
|
from crewai.llm import BaseLLM, LLM
|
||||||
|
|
||||||
|
|
||||||
def create_llm(
|
def create_llm(
|
||||||
llm_value: Union[str, LLM, Any, None] = None,
|
llm_value: Union[str, BaseLLM, Any, None] = None,
|
||||||
) -> Optional[LLM]:
|
) -> Optional[BaseLLM]:
|
||||||
"""
|
"""
|
||||||
Creates or returns an LLM instance based on the given llm_value.
|
Creates or returns an LLM instance based on the given llm_value.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
llm_value (str | LLM | Any | None):
|
llm_value (str | BaseLLM | Any | None):
|
||||||
- str: The model name (e.g., "gpt-4").
|
- str: The model name (e.g., "gpt-4").
|
||||||
- LLM: Already instantiated LLM, returned as-is.
|
- BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
|
||||||
- Any: Attempt to extract known attributes like model_name, temperature, etc.
|
- Any: Attempt to extract known attributes like model_name, temperature, etc.
|
||||||
- None: Use environment-based or fallback default model.
|
- None: Use environment-based or fallback default model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An LLM instance if successful, or None if something fails.
|
A BaseLLM instance if successful, or None if something fails.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 1) If llm_value is already an LLM object, return it directly
|
# 1) If llm_value is already a BaseLLM object, return it directly
|
||||||
if isinstance(llm_value, LLM):
|
if isinstance(llm_value, BaseLLM):
|
||||||
return llm_value
|
return llm_value
|
||||||
|
|
||||||
# 2) If llm_value is a string (model name)
|
# 2) If llm_value is a string (model name)
|
||||||
|
|||||||
111
tests/custom_llm_test.py
Normal file
111
tests/custom_llm_test.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
import pytest
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from crewai.llm import BaseLLM
|
||||||
|
from crewai.utilities.llm_utils import create_llm
|
||||||
|
|
||||||
|
|
||||||
|
class CustomLLM(BaseLLM):
|
||||||
|
"""Custom LLM implementation for testing."""
|
||||||
|
|
||||||
|
def __init__(self, response: str = "Custom LLM response"):
|
||||||
|
self.response = response
|
||||||
|
self.calls = []
|
||||||
|
self.stop = []
|
||||||
|
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
messages: Union[str, List[Dict[str, str]]],
|
||||||
|
tools: Optional[List[dict]] = None,
|
||||||
|
callbacks: Optional[List[Any]] = None,
|
||||||
|
available_functions: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Union[str, Any]:
|
||||||
|
"""Record the call and return the predefined response."""
|
||||||
|
self.calls.append({
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools,
|
||||||
|
"callbacks": callbacks,
|
||||||
|
"available_functions": available_functions
|
||||||
|
})
|
||||||
|
return self.response
|
||||||
|
|
||||||
|
def supports_function_calling(self) -> bool:
|
||||||
|
"""Return True to indicate that function calling is supported."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def supports_stop_words(self) -> bool:
|
||||||
|
"""Return True to indicate that stop words are supported."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_context_window_size(self) -> int:
|
||||||
|
"""Return a default context window size."""
|
||||||
|
return 8192
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_llm_implementation():
|
||||||
|
"""Test that a custom LLM implementation works with create_llm."""
|
||||||
|
custom_llm = CustomLLM(response="The answer is 42")
|
||||||
|
|
||||||
|
# Test that create_llm returns the custom LLM instance directly
|
||||||
|
result_llm = create_llm(custom_llm)
|
||||||
|
|
||||||
|
assert result_llm is custom_llm
|
||||||
|
|
||||||
|
# Test calling the custom LLM
|
||||||
|
response = result_llm.call("What is the answer to life, the universe, and everything?")
|
||||||
|
|
||||||
|
# Verify that the custom LLM was called
|
||||||
|
assert len(custom_llm.calls) > 0
|
||||||
|
# Verify that the response from the custom LLM was used
|
||||||
|
assert response == "The answer is 42"
|
||||||
|
|
||||||
|
|
||||||
|
class JWTAuthLLM(BaseLLM):
|
||||||
|
def __init__(self, jwt_token: str):
|
||||||
|
self.jwt_token = jwt_token
|
||||||
|
self.calls = []
|
||||||
|
self.stop = []
|
||||||
|
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
messages: Union[str, List[Dict[str, str]]],
|
||||||
|
tools: Optional[List[dict]] = None,
|
||||||
|
callbacks: Optional[List[Any]] = None,
|
||||||
|
available_functions: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Union[str, Any]:
|
||||||
|
self.calls.append({
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools,
|
||||||
|
"callbacks": callbacks,
|
||||||
|
"available_functions": available_functions
|
||||||
|
})
|
||||||
|
# In a real implementation, this would use the JWT token to authenticate
|
||||||
|
# with an external service
|
||||||
|
return "Response from JWT-authenticated LLM"
|
||||||
|
|
||||||
|
def supports_function_calling(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def supports_stop_words(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_context_window_size(self) -> int:
|
||||||
|
return 8192
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_llm_with_jwt_auth():
|
||||||
|
"""Test a custom LLM implementation with JWT authentication."""
|
||||||
|
jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")
|
||||||
|
|
||||||
|
# Test that create_llm returns the JWT-authenticated LLM instance directly
|
||||||
|
result_llm = create_llm(jwt_llm)
|
||||||
|
|
||||||
|
assert result_llm is jwt_llm
|
||||||
|
|
||||||
|
# Test calling the JWT-authenticated LLM
|
||||||
|
response = result_llm.call("Test message")
|
||||||
|
|
||||||
|
# Verify that the JWT-authenticated LLM was called
|
||||||
|
assert len(jwt_llm.calls) > 0
|
||||||
|
# Verify that the response from the JWT-authenticated LLM was used
|
||||||
|
assert response == "Response from JWT-authenticated LLM"
|
||||||
Reference in New Issue
Block a user