mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 07:08:31 +00:00
Compare commits
2 Commits
devin/1744
...
devin/1746
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8476fb2c64 | ||
|
|
8f3162b8e8 |
213
docs/multiple_model_config.md
Normal file
213
docs/multiple_model_config.md
Normal file
@@ -0,0 +1,213 @@
|
||||
# Multiple Model Configuration in CrewAI
|
||||
|
||||
CrewAI now supports configuring multiple language models with different API keys and configurations. This feature allows you to:
|
||||
|
||||
1. Load-balance across multiple model deployments
|
||||
2. Set up fallback models in case of rate limits or errors
|
||||
3. Configure different routing strategies for model selection
|
||||
4. Maintain fine-grained control over model selection and usage
|
||||
|
||||
## Basic Usage
|
||||
|
||||
You can configure multiple models at the agent level:
|
||||
|
||||
```python
|
||||
from crewai import Agent
|
||||
|
||||
# Define model configurations
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini", # Required: model name must be specified here
|
||||
"api_key": "your-openai-api-key-1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo", # Required: model name must be specified here
|
||||
"api_key": "your-openai-api-key-2"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-sonnet-20240229",
|
||||
"litellm_params": {
|
||||
"model": "claude-3-sonnet-20240229", # Required: model name must be specified here
|
||||
"api_key": "your-anthropic-api-key"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Create an agent with multiple model configurations
|
||||
agent = Agent(
|
||||
role="Data Analyst",
|
||||
goal="Analyze the data and provide insights",
|
||||
backstory="You are an expert data analyst with years of experience.",
|
||||
model_list=model_list,
|
||||
routing_strategy="simple-shuffle" # Optional routing strategy
|
||||
)
|
||||
```
|
||||
|
||||
## Routing Strategies
|
||||
|
||||
CrewAI supports the following routing strategies for precise control over model selection:
|
||||
|
||||
- `simple-shuffle`: Randomly selects a model from the list
|
||||
- `least-busy`: Routes to the model with the least number of ongoing requests
|
||||
- `usage-based`: Routes based on token usage across models
|
||||
- `latency-based`: Routes to the model with the lowest latency
|
||||
- `cost-based`: Routes to the model with the lowest cost
|
||||
|
||||
Example with latency-based routing:
|
||||
|
||||
```python
|
||||
agent = Agent(
|
||||
role="Data Analyst",
|
||||
goal="Analyze the data and provide insights",
|
||||
backstory="You are an expert data analyst with years of experience.",
|
||||
model_list=model_list,
|
||||
routing_strategy="latency-based"
|
||||
)
|
||||
```
|
||||
|
||||
## Direct LLM Configuration
|
||||
|
||||
You can also configure multiple models directly with the LLM class for more flexibility:
|
||||
|
||||
```python
|
||||
from crewai import LLM
|
||||
|
||||
llm = LLM(
|
||||
model="gpt-4o-mini",
|
||||
model_list=model_list,
|
||||
routing_strategy="simple-shuffle"
|
||||
)
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
For more advanced configurations, you can specify additional parameters for each model to handle complex use cases:
|
||||
|
||||
```python
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini", # Required: model name must be specified here
|
||||
"api_key": "your-openai-api-key-1",
|
||||
"temperature": 0.7
|
||||
},
|
||||
"tpm": 100000, # Tokens per minute limit
|
||||
"rpm": 1000 # Requests per minute limit
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo", # Required: model name must be specified here
|
||||
"api_key": "your-openai-api-key-2",
|
||||
"temperature": 0.5
|
||||
}
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## Error Handling and Troubleshooting
|
||||
|
||||
When working with multiple model configurations, you may encounter various issues. Here are some common problems and their solutions:
|
||||
|
||||
### Missing Required Parameters
|
||||
|
||||
**Problem**: Router initialization fails with an error about missing parameters.
|
||||
|
||||
**Solution**: Ensure each model configuration in `model_list` includes both `model_name` and `litellm_params` with the required `model` parameter:
|
||||
|
||||
```python
|
||||
# Correct configuration
|
||||
model_config = {
|
||||
"model_name": "gpt-4o-mini", # Required
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini", # Required
|
||||
"api_key": "your-api-key"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Invalid Routing Strategy
|
||||
|
||||
**Problem**: Error when specifying an unsupported routing strategy.
|
||||
|
||||
**Solution**: Use only the supported routing strategies:
|
||||
|
||||
```python
|
||||
# Valid routing strategies
|
||||
valid_strategies = [
|
||||
"simple-shuffle",
|
||||
"least-busy",
|
||||
"usage-based",
|
||||
"latency-based",
|
||||
"cost-based"
|
||||
]
|
||||
```
|
||||
|
||||
### API Key Authentication Errors
|
||||
|
||||
**Problem**: Authentication errors when making API calls.
|
||||
|
||||
**Solution**: Verify that all API keys are valid and have the necessary permissions:
|
||||
|
||||
```python
|
||||
# Check environment variables first
|
||||
import os
|
||||
os.environ.get("OPENAI_API_KEY") # Should be set if using OpenAI models
|
||||
|
||||
# Or explicitly provide in the configuration
|
||||
model_list = [{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini",
|
||||
"api_key": "valid-api-key-here" # Ensure this is correct
|
||||
}
|
||||
}]
|
||||
```
|
||||
|
||||
### Rate Limit Handling
|
||||
|
||||
**Problem**: Encountering rate limits with multiple models.
|
||||
|
||||
**Solution**: Configure rate limits and implement fallback mechanisms:
|
||||
|
||||
```python
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "primary-model",
|
||||
"litellm_params": {"model": "primary-model", "api_key": "key1"},
|
||||
"rpm": 100 # Requests per minute
|
||||
},
|
||||
{
|
||||
"model_name": "fallback-model",
|
||||
"litellm_params": {"model": "fallback-model", "api_key": "key2"}
|
||||
}
|
||||
]
|
||||
|
||||
# Configure with fallback
|
||||
llm = LLM(
|
||||
model="primary-model",
|
||||
model_list=model_list,
|
||||
routing_strategy="least-busy" # Will route to fallback when primary is busy
|
||||
)
|
||||
```
|
||||
|
||||
### Debugging Router Issues
|
||||
|
||||
If you're experiencing issues with the router, you can enable verbose logging to get more information:
|
||||
|
||||
```python
|
||||
import litellm
|
||||
litellm.set_verbose = True
|
||||
|
||||
# Then initialize your LLM
|
||||
llm = LLM(model="gpt-4o-mini", model_list=model_list)
|
||||
```
|
||||
|
||||
This feature leverages litellm's Router functionality under the hood, providing robust load balancing and fallback capabilities for your CrewAI agents. The implementation ensures predictability and consistency in model selection while maintaining security through proper API key management.
|
||||
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator, field_validator
|
||||
|
||||
from crewai.agents import CacheHandler
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
@@ -86,7 +87,20 @@ class Agent(BaseAgent):
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: Optional[Any] = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
description="Language model that will handle function calling for the agent.", default=None
|
||||
)
|
||||
class RoutingStrategy(str, Enum):
|
||||
SIMPLE_SHUFFLE = "simple-shuffle"
|
||||
LEAST_BUSY = "least-busy"
|
||||
USAGE_BASED = "usage-based"
|
||||
LATENCY_BASED = "latency-based"
|
||||
COST_BASED = "cost-based"
|
||||
|
||||
model_list: Optional[List[Dict[str, Any]]] = Field(
|
||||
default=None, description="List of model configurations for routing between multiple models."
|
||||
)
|
||||
routing_strategy: Optional[RoutingStrategy] = Field(
|
||||
default=None, description="Strategy for routing between multiple models (e.g., 'simple-shuffle', 'least-busy', 'usage-based', 'latency-based', 'cost-based')."
|
||||
)
|
||||
system_template: Optional[str] = Field(
|
||||
default=None, description="System format for the agent."
|
||||
@@ -148,10 +162,17 @@ class Agent(BaseAgent):
|
||||
# Handle different cases for self.llm
|
||||
if isinstance(self.llm, str):
|
||||
# If it's a string, create an LLM instance
|
||||
self.llm = LLM(model=self.llm)
|
||||
self.llm = LLM(
|
||||
model=self.llm,
|
||||
model_list=self.model_list,
|
||||
routing_strategy=self.routing_strategy
|
||||
)
|
||||
elif isinstance(self.llm, LLM):
|
||||
# If it's already an LLM instance, keep it as is
|
||||
pass
|
||||
if self.model_list and not getattr(self.llm, "model_list", None):
|
||||
self.llm.model_list = self.model_list
|
||||
self.llm.routing_strategy = self.routing_strategy
|
||||
self.llm._initialize_router()
|
||||
elif self.llm is None:
|
||||
# Determine the model name from environment variables or use default
|
||||
model_name = (
|
||||
@@ -159,7 +180,11 @@ class Agent(BaseAgent):
|
||||
or os.environ.get("MODEL")
|
||||
or "gpt-4o-mini"
|
||||
)
|
||||
llm_params = {"model": model_name}
|
||||
llm_params = {
|
||||
"model": model_name,
|
||||
"model_list": self.model_list,
|
||||
"routing_strategy": self.routing_strategy
|
||||
}
|
||||
|
||||
api_base = os.environ.get("OPENAI_API_BASE") or os.environ.get(
|
||||
"OPENAI_BASE_URL"
|
||||
@@ -207,6 +232,8 @@ class Agent(BaseAgent):
|
||||
"api_key": getattr(self.llm, "api_key", None),
|
||||
"base_url": getattr(self.llm, "base_url", None),
|
||||
"organization": getattr(self.llm, "organization", None),
|
||||
"model_list": self.model_list,
|
||||
"routing_strategy": self.routing_strategy,
|
||||
}
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
llm_params = {k: v for k, v in llm_params.items() if v is not None}
|
||||
|
||||
17
src/crewai/agents/cache/cache_handler.py
vendored
17
src/crewai/agents/cache/cache_handler.py
vendored
@@ -1,28 +1,15 @@
|
||||
from typing import Any, Dict, Optional
|
||||
import threading
|
||||
from threading import local
|
||||
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
|
||||
_thread_local = local()
|
||||
|
||||
|
||||
class CacheHandler(BaseModel):
|
||||
"""Callback handler for tool usage."""
|
||||
|
||||
_cache: Dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
|
||||
def _get_lock(self):
|
||||
"""Get a thread-local lock to avoid pickling issues."""
|
||||
if not hasattr(_thread_local, "cache_lock"):
|
||||
_thread_local.cache_lock = threading.Lock()
|
||||
return _thread_local.cache_lock
|
||||
|
||||
def add(self, tool, input, output):
|
||||
with self._get_lock():
|
||||
self._cache[f"{tool}-{input}"] = output
|
||||
self._cache[f"{tool}-{input}"] = output
|
||||
|
||||
def read(self, tool, input) -> Optional[str]:
|
||||
with self._get_lock():
|
||||
return self._cache.get(f"{tool}-{input}")
|
||||
return self._cache.get(f"{tool}-{input}")
|
||||
|
||||
@@ -88,7 +88,7 @@ class Crew(BaseModel):
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr()
|
||||
_file_handler: FileHandler = PrivateAttr()
|
||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr()
|
||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default=CacheHandler())
|
||||
_short_term_memory: Optional[InstanceOf[ShortTermMemory]] = PrivateAttr()
|
||||
_long_term_memory: Optional[InstanceOf[LongTermMemory]] = PrivateAttr()
|
||||
_entity_memory: Optional[InstanceOf[EntityMemory]] = PrivateAttr()
|
||||
|
||||
@@ -7,12 +7,17 @@ from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm import Router as LiteLLMRouter
|
||||
from litellm import get_supported_openai_params
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from crewai.utilities.logger import Logger
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
)
|
||||
|
||||
logger = Logger(verbose=True)
|
||||
|
||||
|
||||
class FilteredStream:
|
||||
def __init__(self, original_stream):
|
||||
@@ -113,6 +118,8 @@ class LLM:
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
model_list: Optional[List[Dict[str, Any]]] = None,
|
||||
routing_strategy: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
@@ -136,11 +143,50 @@ class LLM:
|
||||
self.callbacks = callbacks
|
||||
self.context_window_size = 0
|
||||
self.kwargs = kwargs
|
||||
self.model_list = model_list
|
||||
self.routing_strategy = routing_strategy
|
||||
self.router = None
|
||||
|
||||
litellm.drop_params = True
|
||||
litellm.set_verbose = False
|
||||
self.set_callbacks(callbacks)
|
||||
self.set_env_callbacks()
|
||||
|
||||
if self.model_list:
|
||||
self._initialize_router()
|
||||
|
||||
def _initialize_router(self):
|
||||
"""
|
||||
Initialize the litellm Router with the provided model_list and routing_strategy.
|
||||
"""
|
||||
try:
|
||||
router_kwargs = {}
|
||||
if self.routing_strategy:
|
||||
valid_strategies = ["simple-shuffle", "least-busy", "usage-based", "latency-based", "cost-based"]
|
||||
if self.routing_strategy not in valid_strategies:
|
||||
raise ValueError(f"Invalid routing strategy: {self.routing_strategy}. Valid options are: {', '.join(valid_strategies)}")
|
||||
router_kwargs["routing_strategy"] = self.routing_strategy
|
||||
|
||||
self.router = LiteLLMRouter(
|
||||
model_list=self.model_list,
|
||||
**router_kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.log("error", f"Failed to initialize router: {str(e)}")
|
||||
raise RuntimeError(f"Router initialization failed: {str(e)}")
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
|
||||
def _execute_router_call(self, params):
|
||||
"""
|
||||
Execute a call to the router with retry logic for handling transient issues.
|
||||
|
||||
Args:
|
||||
params: Parameters to pass to the router completion method
|
||||
|
||||
Returns:
|
||||
The response from the router
|
||||
"""
|
||||
return self.router.completion(model=self.model, **params)
|
||||
|
||||
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
|
||||
with suppress_warnings():
|
||||
@@ -149,7 +195,6 @@ class LLM:
|
||||
|
||||
try:
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"timeout": self.timeout,
|
||||
"temperature": self.temperature,
|
||||
@@ -164,9 +209,6 @@ class LLM:
|
||||
"seed": self.seed,
|
||||
"logprobs": self.logprobs,
|
||||
"top_logprobs": self.top_logprobs,
|
||||
"api_base": self.base_url,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
"stream": False,
|
||||
**self.kwargs,
|
||||
}
|
||||
@@ -174,7 +216,17 @@ class LLM:
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
response = litellm.completion(**params)
|
||||
if self.router:
|
||||
response = self._execute_router_call(params)
|
||||
else:
|
||||
params.update({
|
||||
"model": self.model,
|
||||
"api_base": self.base_url,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
})
|
||||
response = litellm.completion(**params)
|
||||
|
||||
return response["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
if not LLMContextLengthExceededException(
|
||||
|
||||
@@ -4,15 +4,11 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import threading
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from importlib.metadata import version
|
||||
from threading import local
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
_thread_local = local()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def suppress_warnings():
|
||||
@@ -80,20 +76,12 @@ class Telemetry:
|
||||
raise # Re-raise the exception to not interfere with system signals
|
||||
self.ready = False
|
||||
|
||||
def _get_lock(self):
|
||||
"""Get a thread-local lock to avoid pickling issues."""
|
||||
if not hasattr(_thread_local, "telemetry_lock"):
|
||||
_thread_local.telemetry_lock = threading.Lock()
|
||||
return _thread_local.telemetry_lock
|
||||
|
||||
def set_tracer(self):
|
||||
if self.ready and not self.trace_set:
|
||||
try:
|
||||
with self._get_lock():
|
||||
if not self.trace_set: # Double-check to avoid race condition
|
||||
with suppress_warnings():
|
||||
trace.set_tracer_provider(self.provider)
|
||||
self.trace_set = True
|
||||
with suppress_warnings():
|
||||
trace.set_tracer_provider(self.provider)
|
||||
self.trace_set = True
|
||||
except Exception:
|
||||
self.ready = False
|
||||
self.trace_set = False
|
||||
@@ -102,8 +90,7 @@ class Telemetry:
|
||||
if not self.ready:
|
||||
return
|
||||
try:
|
||||
with self._get_lock():
|
||||
operation()
|
||||
operation()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
|
||||
class MockLLM:
|
||||
"""Mock LLM for testing."""
|
||||
def __init__(self, model="gpt-3.5-turbo", **kwargs):
|
||||
self.model = model
|
||||
self.stop = None
|
||||
self.timeout = None
|
||||
self.temperature = None
|
||||
self.top_p = None
|
||||
self.n = None
|
||||
self.max_completion_tokens = None
|
||||
self.max_tokens = None
|
||||
self.presence_penalty = None
|
||||
self.frequency_penalty = None
|
||||
self.logit_bias = None
|
||||
self.response_format = None
|
||||
self.seed = None
|
||||
self.logprobs = None
|
||||
self.top_logprobs = None
|
||||
self.base_url = None
|
||||
self.api_version = None
|
||||
self.api_key = None
|
||||
self.callbacks = []
|
||||
self.context_window_size = 8192
|
||||
self.kwargs = {}
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def complete(self, prompt, **kwargs):
|
||||
"""Mock completion method."""
|
||||
return f"Mock response for: {prompt[:20]}..."
|
||||
|
||||
def chat_completion(self, messages, **kwargs):
|
||||
"""Mock chat completion method."""
|
||||
return {"choices": [{"message": {"content": "Mock response"}}]}
|
||||
|
||||
def function_call(self, messages, functions, **kwargs):
|
||||
"""Mock function call method."""
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"content": "Mock response",
|
||||
"function_call": {
|
||||
"name": "test_function",
|
||||
"arguments": '{"arg1": "value1"}'
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
def supports_stop_words(self):
|
||||
"""Mock supports_stop_words method."""
|
||||
return False
|
||||
|
||||
def supports_function_calling(self):
|
||||
"""Mock supports_function_calling method."""
|
||||
return True
|
||||
|
||||
def get_context_window_size(self):
|
||||
"""Mock get_context_window_size method."""
|
||||
return self.context_window_size
|
||||
|
||||
def call(self, messages, callbacks=None):
|
||||
"""Mock call method."""
|
||||
return "Mock response from call method"
|
||||
|
||||
def set_callbacks(self, callbacks):
|
||||
"""Mock set_callbacks method."""
|
||||
self.callbacks = callbacks
|
||||
|
||||
def set_env_callbacks(self):
|
||||
"""Mock set_env_callbacks method."""
|
||||
pass
|
||||
|
||||
|
||||
def create_test_crew():
|
||||
"""Create a simple test crew for concurrency testing."""
|
||||
with patch("crewai.agent.LLM", MockLLM):
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test concurrent execution",
|
||||
backstory="I am a test agent for concurrent execution",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Test task for concurrent execution",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(
|
||||
agents=[agent],
|
||||
tasks=[task],
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
return crew
|
||||
|
||||
|
||||
def test_threading_concurrency():
|
||||
"""Test concurrent execution using ThreadPoolExecutor."""
|
||||
num_threads = 5
|
||||
results = []
|
||||
|
||||
def generate_response(idx):
|
||||
try:
|
||||
crew = create_test_crew()
|
||||
with patch("crewai.agent.LLM", MockLLM):
|
||||
output = crew.kickoff(inputs={"test_input": f"input_{idx}"})
|
||||
return output
|
||||
except Exception as e:
|
||||
pytest.fail(f"Exception in thread {idx}: {e}")
|
||||
return None
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(generate_response, i) for i in range(num_threads)]
|
||||
|
||||
for future in as_completed(futures):
|
||||
result = future.result()
|
||||
assert result is not None
|
||||
results.append(result)
|
||||
|
||||
assert len(results) == num_threads
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asyncio_concurrency():
|
||||
"""Test concurrent execution using asyncio."""
|
||||
num_tasks = 5
|
||||
sem = asyncio.Semaphore(num_tasks)
|
||||
|
||||
async def generate_response_async(idx):
|
||||
async with sem:
|
||||
try:
|
||||
crew = create_test_crew()
|
||||
with patch("crewai.agent.LLM", MockLLM):
|
||||
output = await crew.kickoff_async(inputs={"test_input": f"input_{idx}"})
|
||||
return output
|
||||
except Exception as e:
|
||||
pytest.fail(f"Exception in task {idx}: {e}")
|
||||
return None
|
||||
|
||||
tasks = [generate_response_async(i) for i in range(num_tasks)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(results) == num_tasks
|
||||
assert all(result is not None for result in results)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extended_asyncio_concurrency():
|
||||
"""Extended test for asyncio concurrency with more iterations."""
|
||||
num_tasks = 5 # Reduced from 10 for faster testing
|
||||
iterations = 2 # Reduced from 3 for faster testing
|
||||
sem = asyncio.Semaphore(num_tasks)
|
||||
|
||||
async def generate_response_async(idx):
|
||||
async with sem:
|
||||
crew = create_test_crew()
|
||||
for i in range(iterations):
|
||||
try:
|
||||
with patch("crewai.agent.LLM", MockLLM):
|
||||
output = await crew.kickoff_async(
|
||||
inputs={"test_input": f"input_{idx}_{i}"}
|
||||
)
|
||||
assert output is not None
|
||||
except Exception as e:
|
||||
pytest.fail(f"Exception in task {idx}, iteration {i}: {e}")
|
||||
return False
|
||||
return True
|
||||
|
||||
tasks = [generate_response_async(i) for i in range(num_tasks)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert all(results)
|
||||
246
tests/multiple_model_config_test.py
Normal file
246
tests/multiple_model_config_test.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.agent import Agent
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@patch("litellm.Router")
|
||||
@patch.object(LLM, '_initialize_router')
|
||||
def test_llm_with_model_list(mock_initialize_router, mock_router):
|
||||
"""Test that LLM can be initialized with a model_list for multiple model configurations."""
|
||||
mock_initialize_router.return_value = None
|
||||
|
||||
mock_router_instance = MagicMock()
|
||||
mock_router.return_value = mock_router_instance
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini",
|
||||
"api_key": "test-key-1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "test-key-2"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
llm = LLM(model="gpt-4o-mini", model_list=model_list)
|
||||
llm.router = mock_router_instance
|
||||
|
||||
assert llm.model == "gpt-4o-mini"
|
||||
assert llm.model_list == model_list
|
||||
assert llm.router is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@patch("litellm.Router")
|
||||
@patch.object(LLM, '_initialize_router')
|
||||
def test_llm_with_routing_strategy(mock_initialize_router, mock_router):
|
||||
"""Test that LLM can be initialized with a routing strategy."""
|
||||
mock_initialize_router.return_value = None
|
||||
|
||||
mock_router_instance = MagicMock()
|
||||
mock_router.return_value = mock_router_instance
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini",
|
||||
"api_key": "test-key-1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "test-key-2"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
llm = LLM(
|
||||
model="gpt-4o-mini",
|
||||
model_list=model_list,
|
||||
routing_strategy="simple-shuffle"
|
||||
)
|
||||
llm.router = mock_router_instance
|
||||
|
||||
assert llm.routing_strategy == "simple-shuffle"
|
||||
assert llm.router is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@patch("litellm.Router")
|
||||
@patch.object(LLM, '_initialize_router')
|
||||
def test_agent_with_model_list(mock_initialize_router, mock_router):
|
||||
"""Test that Agent can be initialized with a model_list for multiple model configurations."""
|
||||
mock_initialize_router.return_value = None
|
||||
|
||||
mock_router_instance = MagicMock()
|
||||
mock_router.return_value = mock_router_instance
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini",
|
||||
"api_key": "test-key-1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": "test-key-2"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(Agent, 'post_init_setup', wraps=Agent.post_init_setup) as mock_post_init:
|
||||
agent = Agent(
|
||||
role="test",
|
||||
goal="test",
|
||||
backstory="test",
|
||||
model_list=model_list
|
||||
)
|
||||
|
||||
agent.llm.router = mock_router_instance
|
||||
|
||||
assert agent.model_list == model_list
|
||||
assert agent.llm.model_list == model_list
|
||||
assert agent.llm.router is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@patch("litellm.Router")
|
||||
@patch.object(LLM, '_initialize_router')
|
||||
def test_llm_call_with_router(mock_initialize_router, mock_router):
|
||||
"""Test that LLM.call uses the router when model_list is provided."""
|
||||
mock_initialize_router.return_value = None
|
||||
|
||||
mock_router_instance = MagicMock()
|
||||
mock_router.return_value = mock_router_instance
|
||||
|
||||
mock_response = {
|
||||
"choices": [{"message": {"content": "Test response"}}]
|
||||
}
|
||||
mock_router_instance.completion.return_value = mock_response
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini",
|
||||
"api_key": "test-key-1"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Create LLM with model_list
|
||||
llm = LLM(model="gpt-4o-mini", model_list=model_list)
|
||||
|
||||
llm.router = mock_router_instance
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = llm.call(messages)
|
||||
|
||||
mock_router_instance.completion.assert_called_once()
|
||||
assert response == "Test response"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@patch("litellm.completion")
|
||||
def test_llm_call_without_router(mock_completion):
|
||||
"""Test that LLM.call uses litellm.completion when no model_list is provided."""
|
||||
mock_response = {
|
||||
"choices": [{"message": {"content": "Test response"}}]
|
||||
}
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
response = llm.call(messages)
|
||||
|
||||
mock_completion.assert_called_once()
|
||||
assert response == "Test response"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_llm_with_invalid_routing_strategy():
|
||||
"""Test that LLM initialization raises an error with an invalid routing strategy."""
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini",
|
||||
"api_key": "test-key-1"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
LLM(
|
||||
model="gpt-4o-mini",
|
||||
model_list=model_list,
|
||||
routing_strategy="invalid-strategy"
|
||||
)
|
||||
|
||||
assert "Invalid routing strategy" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_with_invalid_routing_strategy():
|
||||
"""Test that Agent initialization raises an error with an invalid routing strategy."""
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o-mini",
|
||||
"api_key": "test-key-1"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
Agent(
|
||||
role="test",
|
||||
goal="test",
|
||||
backstory="test",
|
||||
model_list=model_list,
|
||||
routing_strategy="invalid-strategy"
|
||||
)
|
||||
|
||||
assert "Input should be" in str(exc_info.value)
|
||||
assert "simple-shuffle" in str(exc_info.value)
|
||||
assert "least-busy" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
@patch.object(LLM, '_initialize_router')
|
||||
def test_llm_with_missing_model_in_litellm_params(mock_initialize_router):
|
||||
"""Test that LLM initialization raises an error when model is missing in litellm_params."""
|
||||
mock_initialize_router.side_effect = RuntimeError("Router initialization failed: Missing required 'model' in litellm_params")
|
||||
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-4o-mini",
|
||||
"litellm_params": {
|
||||
"api_key": "test-key-1"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
LLM(model="gpt-4o-mini", model_list=model_list)
|
||||
|
||||
assert "Router initialization failed" in str(exc_info.value)
|
||||
Reference in New Issue
Block a user