diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index d62694f14..47295303b 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -119,7 +119,7 @@ def suppress_warnings():
 class LLM:
     def __init__(
         self,
-        model: str,
+        model: Union[str, 'BaseLanguageModel'],
         timeout: Optional[Union[float, int]] = None,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
@@ -142,10 +142,22 @@ class LLM:
         reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
         **kwargs,
     ):
+        """Initialize LLM instance.
+
+        Args:
+            model: The model identifier; should not start with 'models/'.
+                Examples: 'gemini/gemini-1.5-pro', 'anthropic/claude-3'
+
+        Raises:
+            ValueError: If the model name starts with 'models/'.
+        """
+        # Constants for model name validation
+        INVALID_MODEL_PREFIX = "models/"
+
         # Validate model name
-        if isinstance(model, str) and model.startswith('models/'):
+        if isinstance(model, str) and model.startswith(INVALID_MODEL_PREFIX):
             raise ValueError(
-                'Model name should not start with "models/". '
+                f'Invalid model name "{model}": Model names should not start with "{INVALID_MODEL_PREFIX}". '
                 'Use the provider prefix instead (e.g., "gemini/model-name").'
             )
 
diff --git a/tests/llm_test.py b/tests/llm_test.py
index 876e98e0e..44d8d6a52 100644
--- a/tests/llm_test.py
+++ b/tests/llm_test.py
@@ -252,15 +252,27 @@ def test_validate_call_params_no_response_format():
     llm._validate_call_params()
 
 
-def test_model_name_validation():
-    """Test that model names with 'models/' prefix are rejected."""
-    with pytest.raises(ValueError, match="should not start with \"models/\""):
-        LLM(model="models/gemini/gemini-1.5-pro")
+class TestModelNameValidation:
+    """Tests for model name validation in LLM class."""
+
+    def test_models_prefix_rejection(self):
+        """Test that model names with 'models/' prefix are rejected."""
+        with pytest.raises(ValueError, match="should not start with \"models/\""):
+            LLM(model="models/gemini/gemini-1.5-pro")
 
-    # Valid model names should work
-    LLM(model="gemini/gemini-1.5-pro")
-    LLM(model="anthropic/claude-3-opus-20240229-v1:0")
-    LLM(model="openai/gpt-4")
+    def test_valid_model_names(self):
+        """Test that valid model names are accepted."""
+        LLM(model="gemini/gemini-1.5-pro")
+        LLM(model="anthropic/claude-3-opus-20240229-v1:0")
+        LLM(model="openai/gpt-4")
+        LLM(model="openai/gpt-4 turbo")  # Space in model name should work
+
+    def test_edge_cases(self):
+        """Test edge cases for model name validation."""
+        with pytest.raises(ValueError):
+            LLM(model="")  # Empty string
+        with pytest.raises(TypeError):
+            LLM(model=None)  # None value
 
 
 @pytest.mark.vcr(filter_headers=["authorization"])