Compare commits

...

3 Commits

Author SHA1 Message Date
Devin AI
7ec451dc20 feat: enhance vertex ai location validation
- Add region validation
- Add dedicated vertex model detection
- Expand test coverage
- Improve documentation

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-15 19:00:26 +00:00
Devin AI
d4acbf8adf style: fix import sorting in llm_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-15 18:54:31 +00:00
Devin AI
5c171166d4 fix: respect vertex ai location settings
- Add location parameter to LLM class
- Add test for vertex ai location setting
- Update documentation

Fixes #2141

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-15 18:53:31 +00:00
2 changed files with 77 additions and 0 deletions

View File

@@ -92,6 +92,19 @@ LLM_CONTEXT_WINDOW_SIZES = {
"Meta-Llama-3.2-1B-Instruct": 16384,
}
# Regions where Vertex AI is available, used to validate the `location`
# parameter passed to LLM(). The original list omitted several generally
# available regions (e.g. us-east4, asia-northeast1), which made the
# constructor reject legitimate locations with a ValueError.
# NOTE(review): Google adds regions over time — keep this in sync with the
# official Vertex AI locations documentation.
VERTEX_AI_REGIONS = [
    "us-central1",              # Iowa
    "us-east1",                 # South Carolina
    "us-east4",                 # Northern Virginia
    "us-west1",                 # Oregon
    "us-west4",                 # Las Vegas
    "northamerica-northeast1",  # Montreal
    "southamerica-east1",       # Sao Paulo
    "europe-west1",             # Belgium
    "europe-west2",             # London
    "europe-west3",             # Frankfurt
    "europe-west4",             # Netherlands
    "europe-west6",             # Zurich
    "asia-east1",               # Taiwan
    "asia-northeast1",          # Tokyo
    "asia-south1",              # Mumbai
    "asia-southeast1",          # Singapore
    "australia-southeast1",     # Sydney
]
DEFAULT_CONTEXT_WINDOW_SIZE = 8192
CONTEXT_WINDOW_USAGE_RATIO = 0.75
@@ -117,9 +130,20 @@ def suppress_warnings():
class LLM:
"""A wrapper around LiteLLM providing a unified interface for various LLM providers.
Args:
model (str): The identifier of the LLM model to use
location (Optional[str]): The GCP region for Vertex AI models (e.g., 'us-central1', 'europe-west4').
Only applicable for Vertex AI models.
timeout (Optional[Union[float, int]]): Maximum time to wait for the model response
temperature (Optional[float]): Controls randomness in the model's output
top_p (Optional[float]): Controls diversity of the model's output
"""
def __init__(
self,
model: str,
location: Optional[str] = None,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
@@ -143,6 +167,18 @@ class LLM:
**kwargs,
):
self.model = model
# Validate location parameter
if location is not None:
if not isinstance(location, str):
raise ValueError("Location must be a string when provided")
if self._is_vertex_model(model) and location not in VERTEX_AI_REGIONS:
raise ValueError(
f"Invalid Vertex AI region: {location}. "
f"Supported regions: {', '.join(VERTEX_AI_REGIONS)}"
)
self.location = location
self.timeout = timeout
self.temperature = temperature
self.top_p = top_p
@@ -166,6 +202,10 @@ class LLM:
self.additional_params = kwargs
self.is_anthropic = self._is_anthropic_model(model)
# Set vertex location if provided for vertex models
if self.location and self._is_vertex_model(model):
litellm.vertex_location = self.location
litellm.drop_params = True
# Normalize self.stop to always be a List[str]
@@ -179,6 +219,17 @@ class LLM:
self.set_callbacks(callbacks)
self.set_env_callbacks()
def _is_vertex_model(self, model: str) -> bool:
    """Report whether *model* identifies a Vertex AI model.

    A model counts as Vertex AI when its identifier mentions "vertex"
    (case-insensitively) or carries the bare "gemini-" prefix.

    Args:
        model: The model identifier string.

    Returns:
        bool: True if the model is from Vertex AI, False otherwise.
    """
    # Case-insensitive provider check first; fall back to the bare
    # (case-sensitive) "gemini-" prefix used by Gemini model names.
    if "vertex" in model.lower():
        return True
    return model.startswith("gemini-")
def _is_anthropic_model(self, model: str) -> bool:
"""Determine if the model is from Anthropic provider.

View File

@@ -2,6 +2,7 @@ import os
from time import sleep
from unittest.mock import MagicMock, patch
import litellm
import pytest
from pydantic import BaseModel
@@ -12,6 +13,31 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
# TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
# Cases: (model, location passed to LLM(), expected litellm.vertex_location).
_LOCATION_CASES = [
    ("vertex_ai/gemini-2.0-flash", "europe-west4", "europe-west4"),
    ("gpt-4", "europe-west4", None),  # Non-vertex model ignores location
    ("vertex_ai/gemini-2.0-flash", None, None),  # No location provided
]


@pytest.mark.parametrize("model,location,expected", _LOCATION_CASES)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_vertex_ai_location_setting(model, location, expected):
    """Test Vertex AI location setting behavior.

    Constructing an LLM with a location must set (or leave unset) the
    module-level ``litellm.vertex_location`` as expected.
    """
    try:
        LLM(model=model, location=location)
        assert litellm.vertex_location == expected
    finally:
        # Reset the global even when the assertion fails; the original
        # test only reset on success, leaking state into later tests.
        litellm.vertex_location = None
@pytest.mark.vcr(filter_headers=["authorization"])
def test_vertex_ai_location_validation():
    """Test Vertex AI location validation.

    The LLM constructor must reject a non-string location and a string
    that is not a recognized Vertex AI region.
    """
    # (invalid location value, expected ValueError message fragment)
    invalid_cases = [
        (123, "Location must be a string"),
        ("invalid-region", "Invalid Vertex AI region"),
    ]
    for bad_location, expected_message in invalid_cases:
        with pytest.raises(ValueError, match=expected_message):
            LLM(model="vertex_ai/gemini-2.0-flash", location=bad_location)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_callback_replacement():
llm1 = LLM(model="gpt-4o-mini")