mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-15 11:58:31 +00:00
refactor: enhance model validation and provider inference in LLM class (#3976)
* refactor: enhance model validation and provider inference in LLM class - Updated the model validation logic to support pattern matching for new models and "latest" versions, improving flexibility for various providers. - Refactored the `_validate_model_in_constants` method to first check hardcoded constants and then fall back to pattern matching. - Introduced `_matches_provider_pattern` to streamline provider-specific model checks. - Enhanced the `_infer_provider_from_model` method to utilize pattern matching for better provider inference. This refactor aims to improve the extensibility of the LLM class, allowing it to accommodate new models without requiring constant updates to the hardcoded lists. * feat: add new Anthropic model versions to constants - Introduced "claude-opus-4-5-20251101" and "claude-opus-4-5" to the AnthropicModels and ANTHROPIC_MODELS lists for enhanced model support. - Added "anthropic.claude-opus-4-5-20251101-v1:0" to BedrockModels and BEDROCK_MODELS to ensure compatibility with the latest model offerings. - Updated test cases to ensure proper environment variable handling for model validation, improving robustness in testing scenarios. * dont infer this way - dropped
This commit is contained in:
@@ -406,46 +406,100 @@ class LLM(BaseLLM):
|
||||
instance.is_litellm = True
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
|
||||
"""Check if a model name matches provider-specific patterns.
|
||||
|
||||
This allows supporting models that aren't in the hardcoded constants list,
|
||||
including "latest" versions and new models that follow provider naming conventions.
|
||||
|
||||
Args:
|
||||
model: The model name to check
|
||||
provider: The provider to check against (canonical name)
|
||||
|
||||
Returns:
|
||||
True if the model matches the provider's naming pattern, False otherwise
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
|
||||
if provider == "openai":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
|
||||
)
|
||||
|
||||
if provider == "anthropic" or provider == "claude":
|
||||
return any(
|
||||
model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
|
||||
)
|
||||
|
||||
if provider == "gemini" or provider == "google":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gemini-", "gemma-", "learnlm-"]
|
||||
)
|
||||
|
||||
if provider == "bedrock":
|
||||
return "." in model_lower
|
||||
|
||||
if provider == "azure":
|
||||
return any(
|
||||
model_lower.startswith(prefix)
|
||||
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
|
||||
"""Validate if a model name exists in the provider's constants.
|
||||
"""Validate if a model name exists in the provider's constants or matches provider patterns.
|
||||
|
||||
This method first checks the hardcoded constants list for known models.
|
||||
If not found, it falls back to pattern matching to support new models,
|
||||
"latest" versions, and models that follow provider naming conventions.
|
||||
|
||||
Args:
|
||||
model: The model name to validate
|
||||
provider: The provider to check against (canonical name)
|
||||
|
||||
Returns:
|
||||
True if the model exists in the provider's constants, False otherwise
|
||||
True if the model exists in constants or matches provider patterns, False otherwise
|
||||
"""
|
||||
if provider == "openai":
|
||||
return model in OPENAI_MODELS
|
||||
if provider == "openai" and model in OPENAI_MODELS:
|
||||
return True
|
||||
|
||||
if provider == "anthropic" or provider == "claude":
|
||||
return model in ANTHROPIC_MODELS
|
||||
if (
|
||||
provider == "anthropic" or provider == "claude"
|
||||
) and model in ANTHROPIC_MODELS:
|
||||
return True
|
||||
|
||||
if provider == "gemini":
|
||||
return model in GEMINI_MODELS
|
||||
if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
|
||||
return True
|
||||
|
||||
if provider == "bedrock":
|
||||
return model in BEDROCK_MODELS
|
||||
if provider == "bedrock" and model in BEDROCK_MODELS:
|
||||
return True
|
||||
|
||||
if provider == "azure":
|
||||
# azure does not provide a list of available models, determine a better way to handle this
|
||||
return True
|
||||
|
||||
return False
|
||||
# Fallback to pattern matching for models not in constants
|
||||
return cls._matches_provider_pattern(model, provider)
|
||||
|
||||
@classmethod
|
||||
def _infer_provider_from_model(cls, model: str) -> str:
|
||||
"""Infer the provider from the model name.
|
||||
|
||||
This method first checks the hardcoded constants list for known models.
|
||||
If not found, it uses pattern matching to infer the provider from model name patterns.
|
||||
This allows supporting new models and "latest" versions without hardcoding.
|
||||
|
||||
Args:
|
||||
model: The model name without provider prefix
|
||||
|
||||
Returns:
|
||||
The inferred provider name, defaults to "openai"
|
||||
"""
|
||||
|
||||
if model in OPENAI_MODELS:
|
||||
return "openai"
|
||||
|
||||
@@ -1699,12 +1753,14 @@ class LLM(BaseLLM):
|
||||
max_tokens=self.max_tokens,
|
||||
presence_penalty=self.presence_penalty,
|
||||
frequency_penalty=self.frequency_penalty,
|
||||
logit_bias=copy.deepcopy(self.logit_bias, memo)
|
||||
if self.logit_bias
|
||||
else None,
|
||||
response_format=copy.deepcopy(self.response_format, memo)
|
||||
if self.response_format
|
||||
else None,
|
||||
logit_bias=(
|
||||
copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None
|
||||
),
|
||||
response_format=(
|
||||
copy.deepcopy(self.response_format, memo)
|
||||
if self.response_format
|
||||
else None
|
||||
),
|
||||
seed=self.seed,
|
||||
logprobs=self.logprobs,
|
||||
top_logprobs=self.top_logprobs,
|
||||
|
||||
@@ -182,6 +182,8 @@ OPENAI_MODELS: list[OpenAIModels] = [
|
||||
|
||||
|
||||
AnthropicModels: TypeAlias = Literal[
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-5-haiku-latest",
|
||||
@@ -208,6 +210,8 @@ AnthropicModels: TypeAlias = Literal[
|
||||
"claude-3-haiku-20240307",
|
||||
]
|
||||
ANTHROPIC_MODELS: list[AnthropicModels] = [
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-opus-4-5",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
"claude-3-5-haiku-latest",
|
||||
@@ -452,6 +456,7 @@ BedrockModels: TypeAlias = Literal[
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
|
||||
"anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"anthropic.claude-instant-v1:2:100k",
|
||||
"anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"anthropic.claude-opus-4-20250514-v1:0",
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
@@ -524,6 +529,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
|
||||
"anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
"anthropic.claude-instant-v1:2:100k",
|
||||
"anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
"anthropic.claude-opus-4-20250514-v1:0",
|
||||
"anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
|
||||
@@ -243,7 +243,11 @@ def test_validate_call_params_not_supported():
|
||||
|
||||
# Patch supports_response_schema to simulate an unsupported model.
|
||||
with patch("crewai.llm.supports_response_schema", return_value=False):
|
||||
llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True)
|
||||
llm = LLM(
|
||||
model="gemini/gemini-1.5-pro",
|
||||
response_format=DummyResponse,
|
||||
is_litellm=True,
|
||||
)
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
llm._validate_call_params()
|
||||
assert "does not support response_format" in str(excinfo.value)
|
||||
@@ -702,13 +706,16 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
|
||||
|
||||
assert formatted == original_messages
|
||||
|
||||
|
||||
def test_native_provider_raises_error_when_supported_but_fails():
|
||||
"""Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
|
||||
with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
|
||||
with patch("crewai.llm.LLM._get_native_provider") as mock_get_native:
|
||||
# Mock that provider exists but throws an error when instantiated
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.side_effect = ValueError("Native provider initialization failed")
|
||||
mock_provider.side_effect = ValueError(
|
||||
"Native provider initialization failed"
|
||||
)
|
||||
mock_get_native.return_value = mock_provider
|
||||
|
||||
with pytest.raises(ImportError) as excinfo:
|
||||
@@ -751,16 +758,16 @@ def test_prefixed_models_with_valid_constants_use_native_sdk():
|
||||
|
||||
|
||||
def test_prefixed_models_with_invalid_constants_use_litellm():
|
||||
"""Test that models with native provider prefixes use LiteLLM when model is NOT in constants."""
|
||||
"""Test that models with native provider prefixes use LiteLLM when model is NOT in constants and does NOT match patterns."""
|
||||
# Test openai/ prefix with non-OpenAI model (not in OPENAI_MODELS) → LiteLLM
|
||||
llm = LLM(model="openai/gemini-2.5-flash", is_litellm=False)
|
||||
assert llm.is_litellm is True
|
||||
assert llm.model == "openai/gemini-2.5-flash"
|
||||
|
||||
# Test openai/ prefix with unknown future model → LiteLLM
|
||||
llm2 = LLM(model="openai/gpt-future-6", is_litellm=False)
|
||||
# Test openai/ prefix with model that doesn't match patterns (e.g. no gpt- prefix) → LiteLLM
|
||||
llm2 = LLM(model="openai/custom-finetune-model", is_litellm=False)
|
||||
assert llm2.is_litellm is True
|
||||
assert llm2.model == "openai/gpt-future-6"
|
||||
assert llm2.model == "openai/custom-finetune-model"
|
||||
|
||||
# Test anthropic/ prefix with non-Anthropic model → LiteLLM
|
||||
llm3 = LLM(model="anthropic/gpt-4o", is_litellm=False)
|
||||
@@ -768,6 +775,21 @@ def test_prefixed_models_with_invalid_constants_use_litellm():
|
||||
assert llm3.model == "anthropic/gpt-4o"
|
||||
|
||||
|
||||
def test_prefixed_models_with_valid_patterns_use_native_sdk():
|
||||
"""Test that models matching provider patterns use native SDK even if not in constants."""
|
||||
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
|
||||
llm = LLM(model="openai/gpt-future-6", is_litellm=False)
|
||||
assert llm.is_litellm is False
|
||||
assert llm.provider == "openai"
|
||||
assert llm.model == "gpt-future-6"
|
||||
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
|
||||
llm2 = LLM(model="anthropic/claude-future-5", is_litellm=False)
|
||||
assert llm2.is_litellm is False
|
||||
assert llm2.provider == "anthropic"
|
||||
assert llm2.model == "claude-future-5"
|
||||
|
||||
|
||||
def test_prefixed_models_with_non_native_providers_use_litellm():
|
||||
"""Test that models with non-native provider prefixes always use LiteLLM."""
|
||||
# Test groq/ prefix (not a native provider) → LiteLLM
|
||||
@@ -821,19 +843,36 @@ def test_validate_model_in_constants():
|
||||
"""Test the _validate_model_in_constants method."""
|
||||
# OpenAI models
|
||||
assert LLM._validate_model_in_constants("gpt-4o", "openai") is True
|
||||
assert LLM._validate_model_in_constants("gpt-future-6", "openai") is False
|
||||
assert LLM._validate_model_in_constants("gpt-future-6", "openai") is True
|
||||
assert LLM._validate_model_in_constants("o1-latest", "openai") is True
|
||||
assert LLM._validate_model_in_constants("unknown-model", "openai") is False
|
||||
|
||||
# Anthropic models
|
||||
assert LLM._validate_model_in_constants("claude-opus-4-0", "claude") is True
|
||||
assert LLM._validate_model_in_constants("claude-future-5", "claude") is False
|
||||
assert LLM._validate_model_in_constants("claude-future-5", "claude") is True
|
||||
assert (
|
||||
LLM._validate_model_in_constants("claude-3-5-sonnet-latest", "claude") is True
|
||||
)
|
||||
assert LLM._validate_model_in_constants("unknown-model", "claude") is False
|
||||
|
||||
# Gemini models
|
||||
assert LLM._validate_model_in_constants("gemini-2.5-pro", "gemini") is True
|
||||
assert LLM._validate_model_in_constants("gemini-future", "gemini") is False
|
||||
assert LLM._validate_model_in_constants("gemini-future", "gemini") is True
|
||||
assert LLM._validate_model_in_constants("gemma-3-latest", "gemini") is True
|
||||
assert LLM._validate_model_in_constants("unknown-model", "gemini") is False
|
||||
|
||||
# Azure models
|
||||
assert LLM._validate_model_in_constants("gpt-4o", "azure") is True
|
||||
assert LLM._validate_model_in_constants("gpt-35-turbo", "azure") is True
|
||||
|
||||
# Bedrock models
|
||||
assert LLM._validate_model_in_constants("anthropic.claude-opus-4-1-20250805-v1:0", "bedrock") is True
|
||||
assert (
|
||||
LLM._validate_model_in_constants(
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0", "bedrock"
|
||||
)
|
||||
is True
|
||||
)
|
||||
assert (
|
||||
LLM._validate_model_in_constants("anthropic.claude-future-v1:0", "bedrock")
|
||||
is True
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user