Fix Gemini model integration issues (#2803)

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-05-10 10:51:46 +00:00
parent cb1a98cabf
commit ee308ed322
2 changed files with 154 additions and 2 deletions

View File

@@ -322,6 +322,42 @@ class LLM(BaseLLM):
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
def _is_gemini_model(self, model: str) -> bool:
"""Determine if the model is from Google Gemini provider.
Args:
model: The model identifier string.
Returns:
bool: True if the model is from Gemini, False otherwise.
"""
GEMINI_IDENTIFIERS = ("gemini", "gemma-")
return any(identifier in model.lower() for identifier in GEMINI_IDENTIFIERS)
def _normalize_gemini_model(self, model: str) -> str:
"""Normalize Gemini model name to the format expected by LiteLLM.
Handles formats like "models/gemini-pro" or "gemini-pro" and converts
them to "gemini/gemini-pro" format.
Args:
model: The model identifier string.
Returns:
str: Normalized model name.
"""
if model.startswith("gemini/"):
return model
if model.startswith("models/"):
model_name = model.split("/", 1)[1]
return f"gemini/{model_name}"
if self._is_gemini_model(model) and "/" not in model:
return f"gemini/{model}"
return model
def _prepare_completion_params(
self,
messages: Union[str, List[Dict[str, str]]],
@@ -343,9 +379,18 @@ class LLM(BaseLLM):
messages = [{"role": "user", "content": messages}]
formatted_messages = self._format_messages_for_provider(messages)
# --- 2) Prepare the parameters for the completion call
model = self.model
if self._is_gemini_model(model):
model = self._normalize_gemini_model(model)
# --- 2.1) Map GOOGLE_API_KEY to GEMINI_API_KEY if needed
if not os.environ.get("GEMINI_API_KEY") and os.environ.get("GOOGLE_API_KEY"):
os.environ["GEMINI_API_KEY"] = os.environ["GOOGLE_API_KEY"]
logging.info("Mapped GOOGLE_API_KEY to GEMINI_API_KEY for Gemini model")
# --- 3) Prepare the parameters for the completion call
params = {
"model": self.model,
"model": model,
"messages": formatted_messages,
"timeout": self.timeout,
"temperature": self.temperature,