mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Add support for multiple model configurations with litellm Router (#2808)
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
115
docs/multiple_model_config.md
Normal file
115
docs/multiple_model_config.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# Multiple Model Configuration in CrewAI
|
||||
|
||||
CrewAI now supports configuring multiple language models with different API keys and configurations. This feature allows you to:
|
||||
|
||||
1. Load-balance across multiple model deployments
|
||||
2. Set up fallback models in case of rate limits or errors
|
||||
3. Configure different routing strategies for model selection
|
||||
4. Maintain fine-grained control over model selection and usage
|
||||
|
||||
## Basic Usage
|
||||
|
||||
You can configure multiple models at the agent level:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
|
||||
# Define model configurations
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini", # Required: model name must be specified here
|
||||
"api_key": "your-openai-api-key-1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo", # Required: model name must be specified here
|
||||
"api_key": "your-openai-api-key-2"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-sonnet-20240229",
|
||||
"litellm_params": {
|
||||
"model": "claude-3-sonnet-20240229", # Required: model name must be specified here
|
||||
"api_key": "your-anthropic-api-key"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Create an agent with multiple model configurations
|
||||
agent = Agent(
|
||||
role="Data Analyst",
|
||||
goal="Analyze the data and provide insights",
|
||||
backstory="You are an expert data analyst with years of experience.",
|
||||
model_list=model_list,
|
||||
routing_strategy="simple-shuffle" # Optional routing strategy
|
||||
)
|
||||
```
|
||||
|
||||
## Routing Strategies
|
||||
|
||||
CrewAI supports the following routing strategies for precise control over model selection:
|
||||
|
||||
- `simple-shuffle`: Randomly selects a model from the list
|
||||
- `least-busy`: Routes to the model with the least number of ongoing requests
|
||||
- `usage-based`: Routes based on token usage across models
|
||||
- `latency-based`: Routes to the model with the lowest latency
|
||||
- `cost-based`: Routes to the model with the lowest cost
|
||||
|
||||
Example with latency-based routing:
|
||||
|
||||
```python
|
||||
agent = Agent(
|
||||
role="Data Analyst",
|
||||
goal="Analyze the data and provide insights",
|
||||
backstory="You are an expert data analyst with years of experience.",
|
||||
model_list=model_list,
|
||||
routing_strategy="latency-based"
|
||||
)
|
||||
```
|
||||
|
||||
## Direct LLM Configuration
|
||||
|
||||
You can also configure multiple models directly with the LLM class for more flexibility:
|
||||
|
||||
```python
|
||||
from crewai import LLM
|
||||
|
||||
llm = LLM(
|
||||
model="gpt-4o-mini",
|
||||
model_list=model_list,
|
||||
routing_strategy="simple-shuffle"
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
For more advanced configurations, you can specify additional parameters for each model to handle complex use cases:
|
||||
|
||||
```python
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini", # Required: model name must be specified here
|
||||
"api_key": "your-openai-api-key-1",
|
||||
"temperature": 0.7
|
||||
},
|
||||
"tpm": 100000, # Tokens per minute limit
|
||||
"rpm": 1000 # Requests per minute limit
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo", # Required: model name must be specified here
|
||||
"api_key": "your-openai-api-key-2",
|
||||
"temperature": 0.5
|
||||
}
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
This feature leverages litellm's Router functionality under the hood, providing robust load balancing and fallback capabilities for your CrewAI agents. The implementation ensures predictability and consistency in model selection while maintaining security through proper API key management.
|
||||
@@ -86,7 +86,13 @@ class Agent(BaseAgent):
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: Optional[Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
description="Language model that will handle function calling for the agent.", default=None
|
||||
)
|
||||
model_list: Optional[List[Dict[str, Any]]] = Field(
|
||||
default=None, description="List of model configurations for routing between multiple models."
|
||||
)
|
||||
routing_strategy: Optional[str] = Field(
|
||||
default=None, description="Strategy for routing between multiple models (e.g., 'simple-shuffle', 'least-busy', 'usage-based', 'latency-based')."
|
||||
)
|
||||
system_template: Optional[str] = Field(
|
||||
default=None, description="System format for the agent."
|
||||
@@ -148,10 +154,17 @@ class Agent(BaseAgent):
|
||||
# Handle different cases for self.llm
|
||||
if isinstance(self.llm, str):
|
||||
# If it's a string, create an LLM instance
|
||||
self.llm = LLM(model=self.llm)
|
||||
self.llm = LLM(
|
||||
model=self.llm,
|
||||
model_list=self.model_list,
|
||||
routing_strategy=self.routing_strategy
|
||||
)
|
||||
elif isinstance(self.llm, LLM):
|
||||
# If it's already an LLM instance, keep it as is
|
||||
pass
|
||||
if self.model_list and not getattr(self.llm, "model_list", None):
|
||||
self.llm.model_list = self.model_list
|
||||
self.llm.routing_strategy = self.routing_strategy
|
||||
self.llm._initialize_router()
|
||||
elif self.llm is None:
|
||||
# Determine the model name from environment variables or use default
|
||||
model_name = (
|
||||
@@ -159,7 +172,11 @@ class Agent(BaseAgent):
|
||||
or os.environ.get("MODEL")
|
||||
or "gpt-4o-mini"
|
||||
)
|
||||
llm_params = {"model": model_name}
|
||||
llm_params = {
|
||||
"model": model_name,
|
||||
"model_list": self.model_list,
|
||||
"routing_strategy": self.routing_strategy
|
||||
}
|
||||
|
||||
api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get(
|
||||
"OPENAI_BASE_URL"
|
||||
@@ -207,6 +224,8 @@ class Agent(BaseAgent):
|
||||
"api_key": getattr(self.llm, "api_key", None),
|
||||
"base_url": getattr(self.llm, "base_url", None),
|
||||
"organization": getattr(self.llm, "organization", None),
|
||||
"model_list": self.model_list,
|
||||
"routing_strategy": self.routing_strategy,
|
||||
}
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
llm_params = {k: v for k, v in llm_params.items() if v is not None}
|
||||
|
||||
@@ -7,6 +7,7 @@ from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm import Router as LiteLLMRouter
|
||||
from litellm import get_supported_openai_params
|
||||
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
@@ -113,6 +114,8 @@ class LLM:
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
model_list: Optional[List[Dict[str, Any]]] = None,
|
||||
routing_strategy: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
@@ -136,11 +139,30 @@ class LLM:
|
||||
self.callbacks = callbacks
|
||||
self.context_window_size = 0
|
||||
self.kwargs = kwargs
|
||||
self.model_list = model_list
|
||||
self.routing_strategy = routing_strategy
|
||||
self.router = None
|
||||
|
||||
litellm.drop_params = True
|
||||
litellm.set_verbose = False
|
||||
self.set_callbacks(callbacks)
|
||||
self.set_env_callbacks()
|
||||
|
||||
if self.model_list:
|
||||
self._initialize_router()
|
||||
|
||||
def _initialize_router(self):
|
||||
"""
|
||||
Initialize the litellm Router with the provided model_list and routing_strategy.
|
||||
"""
|
||||
router_kwargs = {}
|
||||
if self.routing_strategy:
|
||||
router_kwargs["routing_strategy"] = self.routing_strategy
|
||||
|
||||
self.router = LiteLLMRouter(
|
||||
model_list=self.model_list,
|
||||
**router_kwargs
|
||||
)
|
||||
|
||||
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
|
||||
with suppress_warnings():
|
||||
@@ -149,7 +171,6 @@ class LLM:
|
||||
|
||||
try:
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"timeout": self.timeout,
|
||||
"temperature": self.temperature,
|
||||
@@ -164,9 +185,6 @@ class LLM:
|
||||
"seed": self.seed,
|
||||
"logprobs": self.logprobs,
|
||||
"top_logprobs": self.top_logprobs,
|
||||
"api_base": self.base_url,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
"stream": False,
|
||||
**self.kwargs,
|
||||
}
|
||||
@@ -174,7 +192,20 @@ class LLM:
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
response = litellm.completion(**params)
|
||||
if self.router:
|
||||
response = self.router.completion(
|
||||
model=self.model,
|
||||
**params
|
||||
)
|
||||
else:
|
||||
params.update({
|
||||
"model": self.model,
|
||||
"api_base": self.base_url,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
})
|
||||
response = litellm.completion(**params)
|
||||
|
||||
return response["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
if not LLMContextLengthExceededException(
|
||||
|
||||
177
tests/multiple_model_config_test.py
Normal file
177
tests/multiple_model_config_test.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
from crewai.llm import LLM
|
||||
from crewai.agent import Agent
|
||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@patch("litellm.Router")
|
||||
@patch.object(LLM, '_initialize_router')
|
||||
def test_llm_with_model_list(mock_initialize_router, mock_router):
|
||||
"""Test that LLM can be initialized with a model_list for multiple model configurations."""
|
||||
mock_initialize_router.return_value = None
|
||||
|
||||
mock_router_instance = MagicMock()
|
||||
mock_router.return_value = mock_router_instance
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini",
|
||||
"api_key": "test-key-1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "test-key-2"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
llm = LLM(model="gpt-4o-mini", model_list=model_list)
|
||||
llm.router = mock_router_instance
|
||||
|
||||
assert llm.model == "gpt-4o-mini"
|
||||
assert llm.model_list == model_list
|
||||
assert llm.router is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@patch("litellm.Router")
|
||||
@patch.object(LLM, '_initialize_router')
|
||||
def test_llm_with_routing_strategy(mock_initialize_router, mock_router):
|
||||
"""Test that LLM can be initialized with a routing strategy."""
|
||||
mock_initialize_router.return_value = None
|
||||
|
||||
mock_router_instance = MagicMock()
|
||||
mock_router.return_value = mock_router_instance
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini",
|
||||
"api_key": "test-key-1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "test-key-2"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
llm = LLM(
|
||||
model="gpt-4o-mini",
|
||||
model_list=model_list,
|
||||
routing_strategy="simple-shuffle"
|
||||
)
|
||||
llm.router = mock_router_instance
|
||||
|
||||
assert llm.routing_strategy == "simple-shuffle"
|
||||
assert llm.router is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@patch("litellm.Router")
|
||||
@patch.object(LLM, '_initialize_router')
|
||||
def test_agent_with_model_list(mock_initialize_router, mock_router):
|
||||
"""Test that Agent can be initialized with a model_list for multiple model configurations."""
|
||||
mock_initialize_router.return_value = None
|
||||
|
||||
mock_router_instance = MagicMock()
|
||||
mock_router.return_value = mock_router_instance
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini",
|
||||
"api_key": "test-key-1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "test-key-2"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(Agent, 'post_init_setup', wraps=Agent.post_init_setup) as mock_post_init:
|
||||
agent = Agent(
|
||||
role="test",
|
||||
goal="test",
|
||||
backstory="test",
|
||||
model_list=model_list
|
||||
)
|
||||
|
||||
agent.llm.router = mock_router_instance
|
||||
|
||||
assert agent.model_list == model_list
|
||||
assert agent.llm.model_list == model_list
|
||||
assert agent.llm.router is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@patch("litellm.Router")
|
||||
@patch.object(LLM, '_initialize_router')
|
||||
def test_llm_call_with_router(mock_initialize_router, mock_router):
|
||||
"""Test that LLM.call uses the router when model_list is provided."""
|
||||
mock_initialize_router.return_value = None
|
||||
|
||||
mock_router_instance = MagicMock()
|
||||
mock_router.return_value = mock_router_instance
|
||||
|
||||
mock_response = {
|
||||
"choices": [{"message": {"content": "Test response"}}]
|
||||
}
|
||||
mock_router_instance.completion.return_value = mock_response
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini",
|
||||
"api_key": "test-key-1"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Create LLM with model_list
|
||||
llm = LLM(model="gpt-4o-mini", model_list=model_list)
|
||||
|
||||
llm.router = mock_router_instance
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = llm.call(messages)
|
||||
|
||||
mock_router_instance.completion.assert_called_once()
|
||||
assert response == "Test response"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@patch("litellm.completion")
|
||||
def test_llm_call_without_router(mock_completion):
|
||||
"""Test that LLM.call uses litellm.completion when no model_list is provided."""
|
||||
mock_response = {
|
||||
"choices": [{"message": {"content": "Test response"}}]
|
||||
}
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = llm.call(messages)
|
||||
|
||||
mock_completion.assert_called_once()
|
||||
assert response == "Test response"
|
||||
Reference in New Issue
Block a user