mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
Fix langchain-google-genai integration by stripping 'models/' prefix
Fixes #3702 When using langchain-google-genai with CrewAI, the model name would include a 'models/' prefix (e.g., 'models/gemini/gemini-pro') which is added by Google's API internally. However, LiteLLM does not recognize this format, causing a 'LLM Provider NOT provided' error. This fix strips the 'models/' prefix from model names when extracting them from langchain model objects, ensuring compatibility with LiteLLM while maintaining backward compatibility with models that don't have this prefix. Changes: - Modified create_llm() in llm_utils.py to strip 'models/' prefix - Added comprehensive tests covering various scenarios: - Model with 'models/' prefix in model attribute - Model with 'models/' prefix in model_name attribute - Model without prefix (no change) - Case-sensitive prefix handling (only lowercase 'models/' is stripped) Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -41,6 +41,10 @@ def create_llm(
|
|||||||
or getattr(llm_value, "deployment_name", None)
|
or getattr(llm_value, "deployment_name", None)
|
||||||
or str(llm_value)
|
or str(llm_value)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(model, str) and model.startswith("models/"):
|
||||||
|
model = model[len("models/"):]
|
||||||
|
|
||||||
temperature: float | None = getattr(llm_value, "temperature", None)
|
temperature: float | None = getattr(llm_value, "temperature", None)
|
||||||
max_tokens: int | None = getattr(llm_value, "max_tokens", None)
|
max_tokens: int | None = getattr(llm_value, "max_tokens", None)
|
||||||
logprobs: int | None = getattr(llm_value, "logprobs", None)
|
logprobs: int | None = getattr(llm_value, "logprobs", None)
|
||||||
|
|||||||
@@ -94,3 +94,49 @@ def test_create_llm_with_invalid_type():
|
|||||||
with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
|
with pytest.raises(BadRequestError, match="LLM Provider NOT provided"):
|
||||||
llm = create_llm(llm_value=42)
|
llm = create_llm(llm_value=42)
|
||||||
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
|
llm.call(messages=[{"role": "user", "content": "Hello, world!"}])
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_strips_models_prefix_from_model_attribute():
|
||||||
|
"""Test that 'models/' prefix is stripped from langchain model names."""
|
||||||
|
class LangChainLikeModel:
|
||||||
|
model = "models/gemini/gemini-pro"
|
||||||
|
temperature = 0.7
|
||||||
|
|
||||||
|
obj = LangChainLikeModel()
|
||||||
|
llm = create_llm(llm_value=obj)
|
||||||
|
assert isinstance(llm, LLM)
|
||||||
|
assert llm.model == "gemini/gemini-pro" # 'models/' prefix should be stripped
|
||||||
|
assert llm.temperature == 0.7
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_strips_models_prefix_from_model_name_attribute():
|
||||||
|
"""Test that 'models/' prefix is stripped from model_name attribute."""
|
||||||
|
class LangChainLikeModelWithModelName:
|
||||||
|
model_name = "models/gemini/gemini-2.0-flash"
|
||||||
|
|
||||||
|
obj = LangChainLikeModelWithModelName()
|
||||||
|
llm = create_llm(llm_value=obj)
|
||||||
|
assert isinstance(llm, LLM)
|
||||||
|
assert llm.model == "gemini/gemini-2.0-flash" # 'models/' prefix should be stripped
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_handles_model_without_prefix():
|
||||||
|
"""Test that models without 'models/' prefix are handled correctly."""
|
||||||
|
class RegularModel:
|
||||||
|
model = "gemini/gemini-pro"
|
||||||
|
|
||||||
|
obj = RegularModel()
|
||||||
|
llm = create_llm(llm_value=obj)
|
||||||
|
assert isinstance(llm, LLM)
|
||||||
|
assert llm.model == "gemini/gemini-pro" # No change when prefix not present
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_llm_strips_models_prefix_case_sensitive():
|
||||||
|
"""Test that only lowercase 'models/' prefix is stripped."""
|
||||||
|
class UpperCaseModel:
|
||||||
|
model = "Models/gemini/gemini-pro" # Uppercase M
|
||||||
|
|
||||||
|
obj = UpperCaseModel()
|
||||||
|
llm = create_llm(llm_value=obj)
|
||||||
|
assert isinstance(llm, LLM)
|
||||||
|
assert llm.model == "Models/gemini/gemini-pro"
|
||||||
|
|||||||
Reference in New Issue
Block a user