diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index 649de381a..4427996b2 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -92,6 +92,19 @@ LLM_CONTEXT_WINDOW_SIZES = {
     "Meta-Llama-3.2-1B-Instruct": 16384,
 }
 
+# Common Vertex AI regions
+VERTEX_AI_REGIONS = [
+    "us-central1",      # Iowa
+    "us-east1",         # South Carolina
+    "us-west1",         # Oregon
+    "europe-west1",     # Belgium
+    "europe-west2",     # London
+    "europe-west3",     # Frankfurt
+    "europe-west4",     # Netherlands
+    "asia-east1",       # Taiwan
+    "asia-southeast1",  # Singapore
+]
+
 DEFAULT_CONTEXT_WINDOW_SIZE = 8192
 CONTEXT_WINDOW_USAGE_RATIO = 0.75
 
@@ -121,7 +134,8 @@ class LLM:
 
     Args:
         model (str): The identifier of the LLM model to use
-        location (Optional[str]): Optional location for provider-specific settings (e.g., Vertex AI region)
+        location (Optional[str]): The GCP region for Vertex AI models (e.g., 'us-central1', 'europe-west4').
+            Only applicable to Vertex AI models.
         timeout (Optional[Union[float, int]]): Maximum time to wait for the model response
         temperature (Optional[float]): Controls randomness in the model's output
         top_p (Optional[float]): Controls diversity of the model's output
@@ -153,7 +167,18 @@ class LLM:
         **kwargs,
     ):
         self.model = model
+
+        # Validate location parameter
+        if location is not None:
+            if not isinstance(location, str):
+                raise ValueError("Location must be a string when provided")
+            if self._is_vertex_model(model) and location not in VERTEX_AI_REGIONS:
+                raise ValueError(
+                    f"Invalid Vertex AI region: {location}. "
+                    f"Supported regions: {', '.join(VERTEX_AI_REGIONS)}"
+                )
         self.location = location
+
         self.timeout = timeout
         self.temperature = temperature
         self.top_p = top_p
@@ -178,7 +203,7 @@ class LLM:
         self.is_anthropic = self._is_anthropic_model(model)
 
         # Set vertex location if provided for vertex models
-        if self.location and ("vertex" in self.model.lower() or self.model.startswith("gemini-")):
+        if self.location and self._is_vertex_model(model):
             litellm.vertex_location = self.location
 
         litellm.drop_params = True
@@ -194,6 +219,17 @@ class LLM:
         self.set_callbacks(callbacks)
         self.set_env_callbacks()
 
+    def _is_vertex_model(self, model: str) -> bool:
+        """Determine if the model is from the Vertex AI provider.
+
+        Args:
+            model: The model identifier string.
+
+        Returns:
+            bool: True if the model is from Vertex AI, False otherwise.
+        """
+        return "vertex" in model.lower() or model.startswith("gemini-")
+
     def _is_anthropic_model(self, model: str) -> bool:
         """Determine if the model is from Anthropic provider.
 
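For reviewers, a quick usage sketch of the new validation behavior (assumes GCP credentials are configured for the Vertex call; the invalid region name is made up for illustration):

    from crewai.llm import LLM

    # A supported region passes validation and is forwarded to litellm.vertex_location
    llm = LLM(model="vertex_ai/gemini-2.0-flash", location="europe-west4")

    # An unsupported region now fails fast at construction time instead of
    # surfacing later as a request error
    try:
        LLM(model="vertex_ai/gemini-2.0-flash", location="mars-north1")  # made-up region
    except ValueError as err:
        print(err)  # Invalid Vertex AI region: mars-north1. Supported regions: us-central1, ...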
diff --git a/tests/llm_test.py b/tests/llm_test.py
index d185ddd0b..7f54619ef 100644
--- a/tests/llm_test.py
+++ b/tests/llm_test.py
@@ -13,21 +13,31 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
 
 
 # TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
+@pytest.mark.parametrize("model,location,expected", [
+    ("vertex_ai/gemini-2.0-flash", "europe-west4", "europe-west4"),
+    ("gpt-4", "europe-west4", None),  # Non-vertex model ignores location
+    ("vertex_ai/gemini-2.0-flash", None, None),  # No location provided
+])
 @pytest.mark.vcr(filter_headers=["authorization"])
-def test_vertex_ai_location():
-    """Test that Vertex AI location setting is respected."""
-    location = "europe-west4"
-    llm = LLM(
-        model="vertex_ai/gemini-2.0-flash",
-        location=location,
-    )
-
-    # Verify location is set correctly
-    assert litellm.vertex_location == location
+def test_vertex_ai_location_setting(model, location, expected):
+    """Test Vertex AI location setting behavior."""
+    llm = LLM(model=model, location=location)
+    assert litellm.vertex_location == expected
 
     # Reset location after test
     litellm.vertex_location = None
 
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_vertex_ai_location_validation():
+    """Test Vertex AI location validation."""
+    # Test invalid location type
+    with pytest.raises(ValueError, match="Location must be a string"):
+        LLM(model="vertex_ai/gemini-2.0-flash", location=123)
+
+    # Test invalid region
+    with pytest.raises(ValueError, match="Invalid Vertex AI region"):
+        LLM(model="vertex_ai/gemini-2.0-flash", location="invalid-region")
+
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_llm_callback_replacement():
     llm1 = LLM(model="gpt-4o-mini")
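Side note for the test changes: both tests mutate the module-global litellm.vertex_location and reset it inline, so a failing assertion would leak state into later tests. An autouse fixture is one way to guarantee cleanup; a minimal sketch, not part of this PR (the fixture name is illustrative):

    import litellm
    import pytest

    @pytest.fixture(autouse=True)
    def _reset_vertex_location():
        # Snapshot and restore litellm's global Vertex location around each test,
        # even when the test body raises or an assertion fails
        original = getattr(litellm, "vertex_location", None)
        yield
        litellm.vertex_location = original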