mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 17:18:29 +00:00
refactor: Improve model name validation
- Add better error messages with model name context - Add type hints and docstrings - Add constants for model name validation - Organize tests into a class - Add edge case tests Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -119,7 +119,16 @@ def suppress_warnings():
|
||||
class LLM:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
model: Union[str, 'BaseLanguageModel'],
|
||||
"""Initialize LLM instance.
|
||||
|
||||
Args:
|
||||
model: The model identifier; should not start with 'models/'.
|
||||
Examples: 'gemini/gemini-1.5-pro', 'anthropic/claude-3'
|
||||
|
||||
Raises:
|
||||
ValueError: If the model name starts with 'models/'.
|
||||
"""
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
@@ -142,10 +151,13 @@ class LLM:
|
||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Constants for model name validation
|
||||
INVALID_MODEL_PREFIX = "models/"
|
||||
|
||||
# Validate model name
|
||||
if isinstance(model, str) and model.startswith('models/'):
|
||||
if isinstance(model, str) and model.startswith(INVALID_MODEL_PREFIX):
|
||||
raise ValueError(
|
||||
'Model name should not start with "models/". '
|
||||
f'Invalid model name "{model}": Model names should not start with "{INVALID_MODEL_PREFIX}". '
|
||||
'Use the provider prefix instead (e.g., "gemini/model-name").'
|
||||
)
|
||||
|
||||
|
||||
@@ -252,15 +252,27 @@ def test_validate_call_params_no_response_format():
|
||||
llm._validate_call_params()
|
||||
|
||||
|
||||
def test_model_name_validation():
|
||||
"""Test that model names with 'models/' prefix are rejected."""
|
||||
with pytest.raises(ValueError, match="should not start with \"models/\""):
|
||||
LLM(model="models/gemini/gemini-1.5-pro")
|
||||
class TestModelNameValidation:
|
||||
"""Tests for model name validation in LLM class."""
|
||||
|
||||
def test_models_prefix_rejection(self):
|
||||
"""Test that model names with 'models/' prefix are rejected."""
|
||||
with pytest.raises(ValueError, match="should not start with \"models/\""):
|
||||
LLM(model="models/gemini/gemini-1.5-pro")
|
||||
|
||||
# Valid model names should work
|
||||
LLM(model="gemini/gemini-1.5-pro")
|
||||
LLM(model="anthropic/claude-3-opus-20240229-v1:0")
|
||||
LLM(model="openai/gpt-4")
|
||||
def test_valid_model_names(self):
|
||||
"""Test that valid model names are accepted."""
|
||||
LLM(model="gemini/gemini-1.5-pro")
|
||||
LLM(model="anthropic/claude-3-opus-20240229-v1:0")
|
||||
LLM(model="openai/gpt-4")
|
||||
LLM(model="openai/gpt-4 turbo") # Space in model name should work
|
||||
|
||||
def test_edge_cases(self):
|
||||
"""Test edge cases for model name validation."""
|
||||
with pytest.raises(ValueError):
|
||||
LLM(model="") # Empty string
|
||||
with pytest.raises(TypeError):
|
||||
LLM(model=None) # None value
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
|
||||
Reference in New Issue
Block a user