From ee308ed3227aacbf244cd75a73f7f4b14495187f Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 10 May 2025 10:51:46 +0000 Subject: [PATCH] Fix Gemini model integration issues (#2803) Co-Authored-By: Joe Moura --- src/crewai/llm.py | 49 ++++++++++++++++++++- tests/llm_test.py | 107 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+), 2 deletions(-) diff --git a/src/crewai/llm.py b/src/crewai/llm.py index c8c456297..820ffa073 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -322,6 +322,42 @@ class LLM(BaseLLM): ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/") return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES) + def _is_gemini_model(self, model: str) -> bool: + """Determine if the model is from Google Gemini provider. + + Args: + model: The model identifier string. + + Returns: + bool: True if the model is from Gemini, False otherwise. + """ + GEMINI_IDENTIFIERS = ("gemini", "gemma-") + return any(identifier in model.lower() for identifier in GEMINI_IDENTIFIERS) + + def _normalize_gemini_model(self, model: str) -> str: + """Normalize Gemini model name to the format expected by LiteLLM. + + Handles formats like "models/gemini-pro" or "gemini-pro" and converts + them to "gemini/gemini-pro" format. + + Args: + model: The model identifier string. + + Returns: + str: Normalized model name. 
+ """ + if model.startswith("gemini/"): + return model + + if model.startswith("models/"): + model_name = model.split("/", 1)[1] + return f"gemini/{model_name}" + + if self._is_gemini_model(model) and "/" not in model: + return f"gemini/{model}" + + return model + def _prepare_completion_params( self, messages: Union[str, List[Dict[str, str]]], @@ -343,9 +379,18 @@ class LLM(BaseLLM): messages = [{"role": "user", "content": messages}] formatted_messages = self._format_messages_for_provider(messages) - # --- 2) Prepare the parameters for the completion call + model = self.model + if self._is_gemini_model(model): + model = self._normalize_gemini_model(model) + + # --- 2.1) Map GOOGLE_API_KEY to GEMINI_API_KEY if needed + if not os.environ.get("GEMINI_API_KEY") and os.environ.get("GOOGLE_API_KEY"): + os.environ["GEMINI_API_KEY"] = os.environ["GOOGLE_API_KEY"] + logging.info("Mapped GOOGLE_API_KEY to GEMINI_API_KEY for Gemini model") + + # --- 3) Prepare the parameters for the completion call params = { - "model": self.model, + "model": model, "messages": formatted_messages, "timeout": self.timeout, "temperature": self.temperature, diff --git a/tests/llm_test.py b/tests/llm_test.py index f80637c60..b882d392c 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -220,6 +220,37 @@ def test_get_custom_llm_provider_gemini(): assert llm._get_custom_llm_provider() == "gemini" +def test_is_gemini_model(): + """Test the _is_gemini_model method with various model names.""" + llm = LLM(model="gpt-4") # Model doesn't matter for this test + + assert llm._is_gemini_model("gemini-pro") == True + assert llm._is_gemini_model("gemini/gemini-1.5-pro") == True + assert llm._is_gemini_model("models/gemini-pro") == True + assert llm._is_gemini_model("gemma-7b") == True + + # Should not identify as Gemini models + assert llm._is_gemini_model("gpt-4") == False + assert llm._is_gemini_model("claude-3") == False + assert llm._is_gemini_model("mistral-7b") == False + + +def 
def test_normalize_gemini_model():
    """Test the _normalize_gemini_model method with various model formats."""
    llm = LLM(model="gpt-4")  # Model doesn't matter for this test

    # Already-normalized names pass through unchanged.
    assert llm._normalize_gemini_model("gemini/gemini-1.5-pro") == "gemini/gemini-1.5-pro"

    # Google API style "models/<name>" is rewritten to "gemini/<name>".
    assert llm._normalize_gemini_model("models/gemini-pro") == "gemini/gemini-pro"
    assert llm._normalize_gemini_model("models/gemini-1.5-flash") == "gemini/gemini-1.5-flash"

    # Bare Gemini names gain the provider prefix.
    assert llm._normalize_gemini_model("gemini-pro") == "gemini/gemini-pro"
    assert llm._normalize_gemini_model("gemini-1.5-flash") == "gemini/gemini-1.5-flash"

    # Non-Gemini models are left untouched.
    assert llm._normalize_gemini_model("gpt-4") == "gpt-4"
    assert llm._normalize_gemini_model("claude-3") == "claude-3"


def test_get_custom_llm_provider_openai():
    llm = LLM(model="gpt-4")
    # Compare against None with `is`, not `==` (PEP 8).
    assert llm._get_custom_llm_provider() is None


@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
@pytest.mark.parametrize(
    "model",
    [
        "models/gemini-pro",  # Format from issue #2803
        "gemini-pro",  # Format without provider prefix
    ],
)
def test_gemini_model_normalization(model):
    """Test that different Gemini model formats are normalized correctly."""
    llm = LLM(model=model)

    with patch("litellm.completion") as mock_completion:
        # Build the minimal mocked LiteLLM response shape the code reads:
        # response.choices[0].message.content.
        mock_message = MagicMock()
        mock_message.content = "Paris"
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_completion.return_value = mock_response

        llm.call("What is the capital of France?")

        # The model handed to litellm must carry the "gemini/" provider prefix
        # regardless of which input spelling was used.
        args, kwargs = mock_completion.call_args
        assert kwargs["model"].startswith("gemini/")
        assert "gemini-pro" in kwargs["model"]
@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
def test_gemini_api_key_mapping():
    """Test that GOOGLE_API_KEY is mapped to GEMINI_API_KEY for Gemini models."""
    test_api_key = "test_google_api_key"

    # Start from the real environment, but force GOOGLE_API_KEY to be set and
    # GEMINI_API_KEY to be absent. patch.dict restores the original environment
    # on exit — even when an assertion fails — replacing the original manual
    # save/try/finally/restore bookkeeping.
    env = {k: v for k, v in os.environ.items() if k != "GEMINI_API_KEY"}
    env["GOOGLE_API_KEY"] = test_api_key

    with patch.dict(os.environ, env, clear=True):
        llm = LLM(model="gemini-pro")

        with patch("litellm.completion") as mock_completion:
            # Minimal mocked LiteLLM response shape the code reads:
            # response.choices[0].message.content.
            mock_message = MagicMock()
            mock_message.content = "Paris"
            mock_choice = MagicMock()
            mock_choice.message = mock_message
            mock_response = MagicMock()
            mock_response.choices = [mock_choice]
            mock_completion.return_value = mock_response

            llm.call("What is the capital of France?")

            # The call path must have copied GOOGLE_API_KEY into GEMINI_API_KEY.
            assert os.environ.get("GEMINI_API_KEY") == test_api_key