mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-21 22:08:21 +00:00
Fix Gemini model integration issues (#2803)
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -220,6 +220,37 @@ def test_get_custom_llm_provider_gemini():
|
||||
assert llm._get_custom_llm_provider() == "gemini"
|
||||
|
||||
|
||||
def test_is_gemini_model():
|
||||
"""Test the _is_gemini_model method with various model names."""
|
||||
llm = LLM(model="gpt-4") # Model doesn't matter for this test
|
||||
|
||||
assert llm._is_gemini_model("gemini-pro") == True
|
||||
assert llm._is_gemini_model("gemini/gemini-1.5-pro") == True
|
||||
assert llm._is_gemini_model("models/gemini-pro") == True
|
||||
assert llm._is_gemini_model("gemma-7b") == True
|
||||
|
||||
# Should not identify as Gemini models
|
||||
assert llm._is_gemini_model("gpt-4") == False
|
||||
assert llm._is_gemini_model("claude-3") == False
|
||||
assert llm._is_gemini_model("mistral-7b") == False
|
||||
|
||||
|
||||
def test_normalize_gemini_model():
|
||||
"""Test the _normalize_gemini_model method with various model formats."""
|
||||
llm = LLM(model="gpt-4") # Model doesn't matter for this test
|
||||
|
||||
assert llm._normalize_gemini_model("gemini/gemini-1.5-pro") == "gemini/gemini-1.5-pro"
|
||||
|
||||
assert llm._normalize_gemini_model("models/gemini-pro") == "gemini/gemini-pro"
|
||||
assert llm._normalize_gemini_model("models/gemini-1.5-flash") == "gemini/gemini-1.5-flash"
|
||||
|
||||
assert llm._normalize_gemini_model("gemini-pro") == "gemini/gemini-pro"
|
||||
assert llm._normalize_gemini_model("gemini-1.5-flash") == "gemini/gemini-1.5-flash"
|
||||
|
||||
assert llm._normalize_gemini_model("gpt-4") == "gpt-4"
|
||||
assert llm._normalize_gemini_model("claude-3") == "claude-3"
|
||||
|
||||
|
||||
def test_get_custom_llm_provider_openai():
|
||||
llm = LLM(model="gpt-4")
|
||||
assert llm._get_custom_llm_provider() == None
|
||||
@@ -274,6 +305,82 @@ def test_gemini_models(model):
|
||||
assert "Paris" in result
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"models/gemini-pro", # Format from issue #2803
|
||||
"gemini-pro", # Format without provider prefix
|
||||
],
|
||||
)
|
||||
def test_gemini_model_normalization(model):
|
||||
"""Test that different Gemini model formats are normalized correctly."""
|
||||
llm = LLM(model=model)
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
# Create mocks for response structure
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "Paris"
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message = mock_message
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
# Set up the mocked completion to return the mock response
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
llm.call("What is the capital of France?")
|
||||
|
||||
# Check that the model was normalized correctly in the call to litellm
|
||||
args, kwargs = mock_completion.call_args
|
||||
assert kwargs["model"].startswith("gemini/")
|
||||
assert "gemini-pro" in kwargs["model"]
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
|
||||
def test_gemini_api_key_mapping():
|
||||
"""Test that GOOGLE_API_KEY is mapped to GEMINI_API_KEY for Gemini models."""
|
||||
original_google_api_key = os.environ.get("GOOGLE_API_KEY")
|
||||
original_gemini_api_key = os.environ.get("GEMINI_API_KEY")
|
||||
|
||||
try:
|
||||
# Set up test environment
|
||||
test_api_key = "test_google_api_key"
|
||||
os.environ["GOOGLE_API_KEY"] = test_api_key
|
||||
if "GEMINI_API_KEY" in os.environ:
|
||||
del os.environ["GEMINI_API_KEY"]
|
||||
|
||||
llm = LLM(model="gemini-pro")
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
# Create mocks for response structure
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "Paris"
|
||||
mock_choice = MagicMock()
|
||||
mock_choice.message = mock_message
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [mock_choice]
|
||||
|
||||
# Set up the mocked completion to return the mock response
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
llm.call("What is the capital of France?")
|
||||
|
||||
# Check that GEMINI_API_KEY was set from GOOGLE_API_KEY
|
||||
assert os.environ.get("GEMINI_API_KEY") == test_api_key
|
||||
|
||||
finally:
|
||||
if original_google_api_key is not None:
|
||||
os.environ["GOOGLE_API_KEY"] = original_google_api_key
|
||||
else:
|
||||
os.environ.pop("GOOGLE_API_KEY", None)
|
||||
|
||||
if original_gemini_api_key is not None:
|
||||
os.environ["GEMINI_API_KEY"] = original_gemini_api_key
|
||||
else:
|
||||
os.environ.pop("GEMINI_API_KEY", None)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
|
||||
Reference in New Issue
Block a user