refactor: Improve model name validation

- Add better error messages with model name context
- Add type hints and docstrings
- Add constants for model name validation
- Organize tests into a class
- Add edge case tests

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-12 11:10:33 +00:00
parent 34faf609f4
commit bff64ae823
2 changed files with 35 additions and 11 deletions

View File

@@ -119,7 +119,16 @@ def suppress_warnings():
class LLM: class LLM:
def __init__( def __init__(
self, self,
model: str, model: Union[str, 'BaseLanguageModel'],
"""Initialize LLM instance.
Args:
model: The model identifier; should not start with 'models/'.
Examples: 'gemini/gemini-1.5-pro', 'anthropic/claude-3'
Raises:
ValueError: If the model name starts with 'models/'.
"""
timeout: Optional[Union[float, int]] = None, timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None, temperature: Optional[float] = None,
top_p: Optional[float] = None, top_p: Optional[float] = None,
@@ -142,10 +151,13 @@ class LLM:
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None, reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
**kwargs, **kwargs,
): ):
# Constants for model name validation
INVALID_MODEL_PREFIX = "models/"
# Validate model name # Validate model name
if isinstance(model, str) and model.startswith('models/'): if isinstance(model, str) and model.startswith(INVALID_MODEL_PREFIX):
raise ValueError( raise ValueError(
'Model name should not start with "models/". ' f'Invalid model name "{model}": Model names should not start with "{INVALID_MODEL_PREFIX}". '
'Use the provider prefix instead (e.g., "gemini/model-name").' 'Use the provider prefix instead (e.g., "gemini/model-name").'
) )

View File

@@ -252,15 +252,27 @@ def test_validate_call_params_no_response_format():
llm._validate_call_params() llm._validate_call_params()
def test_model_name_validation(): class TestModelNameValidation:
"""Test that model names with 'models/' prefix are rejected.""" """Tests for model name validation in LLM class."""
with pytest.raises(ValueError, match="should not start with \"models/\""):
LLM(model="models/gemini/gemini-1.5-pro") def test_models_prefix_rejection(self):
"""Test that model names with 'models/' prefix are rejected."""
with pytest.raises(ValueError, match="should not start with \"models/\""):
LLM(model="models/gemini/gemini-1.5-pro")
# Valid model names should work def test_valid_model_names(self):
LLM(model="gemini/gemini-1.5-pro") """Test that valid model names are accepted."""
LLM(model="anthropic/claude-3-opus-20240229-v1:0") LLM(model="gemini/gemini-1.5-pro")
LLM(model="openai/gpt-4") LLM(model="anthropic/claude-3-opus-20240229-v1:0")
LLM(model="openai/gpt-4")
LLM(model="openai/gpt-4 turbo") # Space in model name should work
def test_edge_cases(self):
"""Test edge cases for model name validation."""
with pytest.raises(ValueError):
LLM(model="") # Empty string
with pytest.raises(TypeError):
LLM(model=None) # None value
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])