Compare commits

...

2 Commits

Author     SHA1        Message                                                                                         Date
Devin AI   5b03d0e0db  Address PR review feedback: move constants to class level, add error handling, enhance logging  2025-05-10 11:00:51 +00:00
                       Co-Authored-By: Joe Moura <joao@crewai.com>
Devin AI   ee308ed322  Fix Gemini model integration issues (#2803)                                                     2025-05-10 10:51:46 +00:00
                       Co-Authored-By: Joe Moura <joao@crewai.com>
2 changed files with 175 additions and 4 deletions

View File

@@ -246,6 +246,9 @@ class AccumulatedToolArgs(BaseModel):
 
 
 class LLM(BaseLLM):
+    ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
+    GEMINI_IDENTIFIERS = ("gemini", "gemma-")
+
     def __init__(
         self,
         model: str,
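Note that hoisting the prefix tuples to class attributes (the review-feedback commit) also makes them extension points: the helpers below read them through self, so a subclass can widen provider detection without copying any logic. A hypothetical sketch, not part of this diff:

    class MyGeminiLLM(LLM):
        # Hypothetical subclass: ids containing "my-tuned-gemini" are now
        # also treated as Gemini models by _is_gemini_model().
        GEMINI_IDENTIFIERS = LLM.GEMINI_IDENTIFIERS + ("my-tuned-gemini",)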
@@ -319,8 +322,55 @@ class LLM(BaseLLM):
         Returns:
             bool: True if the model is from Anthropic, False otherwise.
         """
-        ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
-        return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
+        if not isinstance(model, str):
+            return False
+        return any(prefix in model.lower() for prefix in self.ANTHROPIC_PREFIXES)
+
+    def _is_gemini_model(self, model: str) -> bool:
+        """Determine whether the model is from the Google Gemini provider.
+
+        Args:
+            model: The model identifier string.
+
+        Returns:
+            bool: True if the model is from Gemini, False otherwise.
+        """
+        if not isinstance(model, str):
+            return False
+        return any(identifier in model.lower() for identifier in self.GEMINI_IDENTIFIERS)
+
+    def _normalize_gemini_model(self, model: str) -> str:
+        """Normalize a Gemini model name to the format expected by LiteLLM.
+
+        Handles formats like "models/gemini-pro" or "gemini-pro" and converts
+        them to the "gemini/gemini-pro" format.
+
+        Args:
+            model: The model identifier string.
+
+        Returns:
+            str: The normalized model name.
+
+        Raises:
+            ValueError: If model is not a string or is empty.
+        """
+        if not isinstance(model, str):
+            raise ValueError(f"Model must be a string, got {type(model)}")
+        if not model.strip():
+            raise ValueError("Model name cannot be empty")
+        if model.startswith("gemini/"):
+            return model
+        if model.startswith("models/"):
+            model_name = model.split("/", 1)[1]
+            return f"gemini/{model_name}"
+        if self._is_gemini_model(model) and "/" not in model:
+            return f"gemini/{model}"
+        return model
 
     def _prepare_completion_params(
         self,
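The branches of _normalize_gemini_model apply in a fixed order: names already carrying the gemini/ prefix pass through, the "models/..." spelling from issue #2803 is re-prefixed, bare Gemini names gain the prefix, and anything else is returned unchanged. A standalone sketch of that precedence (hypothetical free function mirroring the method above, input validation omitted):

    GEMINI_IDENTIFIERS = ("gemini", "gemma-")

    def normalize_gemini_model(model: str) -> str:
        # Mirrors the branch order of LLM._normalize_gemini_model.
        if model.startswith("gemini/"):        # 1) already normalized
            return model
        if model.startswith("models/"):        # 2) Google AI Studio spelling
            return f"gemini/{model.split('/', 1)[1]}"
        if any(i in model.lower() for i in GEMINI_IDENTIFIERS) and "/" not in model:
            return f"gemini/{model}"           # 3) bare Gemini name
        return model                           # 4) non-Gemini: pass through

    assert normalize_gemini_model("gemini/gemini-1.5-pro") == "gemini/gemini-1.5-pro"
    assert normalize_gemini_model("models/gemini-pro") == "gemini/gemini-pro"
    assert normalize_gemini_model("gemini-pro") == "gemini/gemini-pro"
    assert normalize_gemini_model("gpt-4") == "gpt-4"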
@@ -343,9 +393,23 @@ class LLM(BaseLLM):
             messages = [{"role": "user", "content": messages}]
 
         formatted_messages = self._format_messages_for_provider(messages)
         # --- 2) Prepare the parameters for the completion call
+        model = self.model
+        if self._is_gemini_model(model):
+            try:
+                model = self._normalize_gemini_model(model)
+                logging.info(f"Normalized Gemini model name from '{self.model}' to '{model}'")
+
+                # --- 2.1) Map GOOGLE_API_KEY to GEMINI_API_KEY if needed
+                if not os.environ.get("GEMINI_API_KEY") and os.environ.get("GOOGLE_API_KEY"):
+                    os.environ["GEMINI_API_KEY"] = os.environ["GOOGLE_API_KEY"]
+                    logging.info("Mapped GOOGLE_API_KEY to GEMINI_API_KEY for Gemini model")
+            except ValueError as e:
+                logging.error(f"Error normalizing Gemini model: {str(e)}")
+                model = self.model
+
+        # --- 3) Prepare the parameters for the completion call
         params = {
-            "model": self.model,
+            "model": model,
             "messages": formatted_messages,
             "timeout": self.timeout,
             "temperature": self.temperature,

View File

@@ -220,6 +220,37 @@ def test_get_custom_llm_provider_gemini():
     assert llm._get_custom_llm_provider() == "gemini"
 
 
+def test_is_gemini_model():
+    """Test the _is_gemini_model method with various model names."""
+    llm = LLM(model="gpt-4")  # Model doesn't matter for this test
+
+    # Should identify as Gemini models
+    assert llm._is_gemini_model("gemini-pro") == True
+    assert llm._is_gemini_model("gemini/gemini-1.5-pro") == True
+    assert llm._is_gemini_model("models/gemini-pro") == True
+    assert llm._is_gemini_model("gemma-7b") == True
+
+    # Should not identify as Gemini models
+    assert llm._is_gemini_model("gpt-4") == False
+    assert llm._is_gemini_model("claude-3") == False
+    assert llm._is_gemini_model("mistral-7b") == False
+
+
+def test_normalize_gemini_model():
+    """Test the _normalize_gemini_model method with various model formats."""
+    llm = LLM(model="gpt-4")  # Model doesn't matter for this test
+
+    # Already normalized: passes through unchanged
+    assert llm._normalize_gemini_model("gemini/gemini-1.5-pro") == "gemini/gemini-1.5-pro"
+
+    # "models/" prefix is rewritten to "gemini/"
+    assert llm._normalize_gemini_model("models/gemini-pro") == "gemini/gemini-pro"
+    assert llm._normalize_gemini_model("models/gemini-1.5-flash") == "gemini/gemini-1.5-flash"
+
+    # Bare Gemini names gain the provider prefix
+    assert llm._normalize_gemini_model("gemini-pro") == "gemini/gemini-pro"
+    assert llm._normalize_gemini_model("gemini-1.5-flash") == "gemini/gemini-1.5-flash"
+
+    # Non-Gemini models are untouched
+    assert llm._normalize_gemini_model("gpt-4") == "gpt-4"
+    assert llm._normalize_gemini_model("claude-3") == "claude-3"
+
+
 def test_get_custom_llm_provider_openai():
     llm = LLM(model="gpt-4")
     assert llm._get_custom_llm_provider() == None
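Style note on the two new tests above: pytest's bare assert handles booleans directly, so the == True / == False comparisons could be tightened into a parametrized form. A sketch with the same coverage, not part of this diff:

    @pytest.mark.parametrize(
        ("model", "expected"),
        [
            ("gemini-pro", True),
            ("gemini/gemini-1.5-pro", True),
            ("models/gemini-pro", True),
            ("gemma-7b", True),
            ("gpt-4", False),
            ("claude-3", False),
            ("mistral-7b", False),
        ],
    )
    def test_is_gemini_model_parametrized(model, expected):
        llm = LLM(model="gpt-4")  # the instance's own model is irrelevant here
        assert llm._is_gemini_model(model) is expected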
@@ -274,6 +305,82 @@ def test_gemini_models(model):
     assert "Paris" in result
 
 
+@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
+@pytest.mark.parametrize(
+    "model",
+    [
+        "models/gemini-pro",  # Format from issue #2803
+        "gemini-pro",  # Format without provider prefix
+    ],
+)
+def test_gemini_model_normalization(model):
+    """Test that different Gemini model formats are normalized correctly."""
+    llm = LLM(model=model)
+
+    with patch("litellm.completion") as mock_completion:
+        # Create mocks for the response structure
+        mock_message = MagicMock()
+        mock_message.content = "Paris"
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+
+        # Set up the mocked completion to return the mock response
+        mock_completion.return_value = mock_response
+
+        llm.call("What is the capital of France?")
+
+        # Check that the model was normalized correctly in the call to litellm
+        args, kwargs = mock_completion.call_args
+        assert kwargs["model"].startswith("gemini/")
+        assert "gemini-pro" in kwargs["model"]
+
+
+@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
+def test_gemini_api_key_mapping():
+    """Test that GOOGLE_API_KEY is mapped to GEMINI_API_KEY for Gemini models."""
+    original_google_api_key = os.environ.get("GOOGLE_API_KEY")
+    original_gemini_api_key = os.environ.get("GEMINI_API_KEY")
+
+    try:
+        # Set up the test environment
+        test_api_key = "test_google_api_key"
+        os.environ["GOOGLE_API_KEY"] = test_api_key
+        if "GEMINI_API_KEY" in os.environ:
+            del os.environ["GEMINI_API_KEY"]
+
+        llm = LLM(model="gemini-pro")
+
+        with patch("litellm.completion") as mock_completion:
+            # Create mocks for the response structure
+            mock_message = MagicMock()
+            mock_message.content = "Paris"
+            mock_choice = MagicMock()
+            mock_choice.message = mock_message
+            mock_response = MagicMock()
+            mock_response.choices = [mock_choice]
+
+            # Set up the mocked completion to return the mock response
+            mock_completion.return_value = mock_response
+
+            llm.call("What is the capital of France?")
+
+            # Check that GEMINI_API_KEY was set from GOOGLE_API_KEY
+            assert os.environ.get("GEMINI_API_KEY") == test_api_key
+    finally:
+        # Restore the original environment
+        if original_google_api_key is not None:
+            os.environ["GOOGLE_API_KEY"] = original_google_api_key
+        else:
+            os.environ.pop("GOOGLE_API_KEY", None)
+        if original_gemini_api_key is not None:
+            os.environ["GEMINI_API_KEY"] = original_gemini_api_key
+        else:
+            os.environ.pop("GEMINI_API_KEY", None)
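The try/finally bookkeeping above restores the environment correctly; pytest's built-in monkeypatch fixture would do the same save-and-restore automatically. An equivalent sketch of the key-mapping check, not part of this diff:

    def test_gemini_api_key_mapping_monkeypatch(monkeypatch):
        """Same assertion as above; monkeypatch undoes env changes on teardown."""
        monkeypatch.setenv("GOOGLE_API_KEY", "test_google_api_key")
        monkeypatch.delenv("GEMINI_API_KEY", raising=False)

        llm = LLM(model="gemini-pro")
        with patch("litellm.completion") as mock_completion:
            mock_completion.return_value = MagicMock(
                choices=[MagicMock(message=MagicMock(content="Paris"))]
            )
            llm.call("What is the capital of France?")

        assert os.environ.get("GEMINI_API_KEY") == "test_google_api_key"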
 @pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
 @pytest.mark.parametrize(
     "model",