diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index ada5c9bf3..649de381a 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -117,9 +117,19 @@ def suppress_warnings():
 
 
 class LLM:
+    """A wrapper around LiteLLM providing a unified interface for various LLM providers.
+
+    Args:
+        model (str): The identifier of the LLM model to use
+        location (Optional[str]): Optional location for provider-specific settings (e.g., Vertex AI region)
+        timeout (Optional[Union[float, int]]): Maximum time to wait for the model response
+        temperature (Optional[float]): Controls randomness in the model's output
+        top_p (Optional[float]): Controls diversity of the model's output
+    """
     def __init__(
         self,
         model: str,
+        location: Optional[str] = None,
         timeout: Optional[Union[float, int]] = None,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
@@ -143,6 +153,7 @@ class LLM:
         **kwargs,
     ):
         self.model = model
+        self.location = location
         self.timeout = timeout
         self.temperature = temperature
         self.top_p = top_p
@@ -166,6 +177,10 @@ class LLM:
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
 
+        # Set vertex location if provided for vertex models
+        if self.location and ("vertex" in self.model.lower() or self.model.startswith("gemini-")):
+            litellm.vertex_location = self.location
+
         litellm.drop_params = True
 
         # Normalize self.stop to always be a List[str]
diff --git a/tests/llm_test.py b/tests/llm_test.py
index 2e5faf774..fe2c5cc30 100644
--- a/tests/llm_test.py
+++ b/tests/llm_test.py
@@ -3,6 +3,7 @@ from time import sleep
 from unittest.mock import MagicMock, patch
 
 import pytest
+import litellm
 from pydantic import BaseModel
 
 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
@@ -12,6 +13,21 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
 
 
 # TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_vertex_ai_location():
+    """Test that Vertex AI location setting is respected."""
+    location = "europe-west4"
+    llm = LLM(
+        model="vertex_ai/gemini-2.0-flash",
+        location=location,
+    )
+
+    # Verify location is set correctly
+    assert litellm.vertex_location == location
+
+    # Reset location after test
+    litellm.vertex_location = None
+
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_llm_callback_replacement():
     llm1 = LLM(model="gpt-4o-mini")
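
For context, here is a minimal usage sketch of the new `location` parameter as wired up in this diff. The model id is taken from the test above; the temperature value and surrounding setup are illustrative assumptions, not part of the change:

```python
# Usage sketch: routing a Vertex AI / Gemini model to a specific region.
from crewai.llm import LLM

llm = LLM(
    model="vertex_ai/gemini-2.0-flash",  # Vertex AI model id, as in the test
    location="europe-west4",             # forwarded to litellm.vertex_location
    temperature=0.2,                     # illustrative, unrelated to the diff
)

# Constructing the LLM sets the module-level LiteLLM region, so subsequent
# completions for this model are issued against europe-west4.
```

Note that `litellm.vertex_location` is a module-level setting in LiteLLM, not per-instance state: if several `LLM` objects are created with different locations, the most recently constructed one wins, which is also why the test resets it to `None` afterwards.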