diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index e574435aa..c23df15dc 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -92,9 +92,43 @@ def suppress_warnings():
 
 
 class LLM:
+    """
+    A wrapper class for language model interactions using litellm.
+
+    This class provides a unified interface for interacting with various language models
+    through litellm. It handles model configuration, context window sizing, and callback
+    management.
+
+    Args:
+        model (str): The identifier for the language model to use. Must be a valid model ID
+            with a provider prefix (e.g., 'openai/gpt-4'). Cannot be a numeric value without
+            a provider prefix.
+        timeout (Optional[Union[float, int]]): The timeout for API calls in seconds.
+        temperature (Optional[float]): Controls randomness in the model's output.
+        top_p (Optional[float]): Controls diversity via nucleus sampling.
+        n (Optional[int]): Number of completions to generate.
+        stop (Optional[Union[str, List[str]]]): Sequences where the model should stop generating.
+        max_completion_tokens (Optional[int]): Maximum number of tokens to generate.
+        max_tokens (Optional[int]): Alias for max_completion_tokens.
+        presence_penalty (Optional[float]): Penalizes repeated tokens.
+        frequency_penalty (Optional[float]): Penalizes frequent tokens.
+        logit_bias (Optional[Dict[int, float]]): Modifies likelihood of specific tokens.
+        response_format (Optional[Dict[str, Any]]): Specifies the format for the model's response.
+        seed (Optional[int]): Seed for deterministic outputs.
+        logprobs (Optional[bool]): Whether to return log probabilities.
+        top_logprobs (Optional[int]): Number of most likely tokens to return probabilities for.
+        base_url (Optional[str]): Base URL for API calls.
+        api_version (Optional[str]): API version to use.
+        api_key (Optional[str]): API key for authentication.
+        callbacks (List[Any]): List of callback functions.
+        **kwargs: Additional keyword arguments to pass to the model.
+
+    Raises:
+        ValueError: If the model ID is empty, whitespace, or a numeric value without a provider prefix.
+    """
     def __init__(
         self,
-        model: str,
+        model: Union[str, Any],
         timeout: Optional[Union[float, int]] = None,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
@@ -115,6 +149,19 @@ class LLM:
         callbacks: List[Any] = [],
         **kwargs,
     ):
+        # Validate model ID to ensure it's not empty, whitespace, or a numeric value without a provider prefix
+        if not model or (isinstance(model, str) and model.strip() == ""):
+            raise ValueError(
+                f"Invalid model ID: '{model}'. Model ID cannot be empty or whitespace. "
+                "Please specify a valid model ID with a provider prefix, e.g., 'openai/gpt-4'."
+            )
+
+        if isinstance(model, (int, float)) or (isinstance(model, str) and model.isdigit()):
+            raise ValueError(
+                f"Invalid model ID: {model}. Model ID cannot be a numeric value without a provider prefix. "
+                "Please specify a valid model ID with a provider prefix, e.g., 'openai/gpt-4'."
+            )
+
         self.model = model
         self.timeout = timeout
         self.temperature = temperature
@@ -137,13 +184,6 @@ class LLM:
         self.context_window_size = 0
         self.kwargs = kwargs
 
-        # Validate model ID to ensure it's not a numeric value without a provider prefix
-        if isinstance(self.model, (int, float)) or (isinstance(self.model, str) and self.model.isdigit()):
-            raise ValueError(
-                f"Invalid model ID: {self.model}. Model ID cannot be a numeric value without a provider prefix. "
-                "Please specify a valid model ID with a provider prefix, e.g., 'openai/gpt-4'."
-            )
-
         litellm.drop_params = True
         litellm.set_verbose = False
         self.set_callbacks(callbacks)
@@ -215,8 +255,10 @@ class LLM:
         self.context_window_size = int(
             DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
         )
+        # Ensure model is a string before calling startswith
+        model_str = str(self.model) if not isinstance(self.model, str) else self.model
         for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
-            if self.model.startswith(key):
+            if model_str.startswith(key):
                 self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
         return self.context_window_size
 
diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py
index 73f215bc3..870d4a57b 100644
--- a/tests/unit/test_llm.py
+++ b/tests/unit/test_llm.py
@@ -1,21 +1,42 @@
 import pytest
-
 from crewai.llm import LLM
 
 
-def test_numeric_model_id_validation():
-    # Test with integer model ID
-    with pytest.raises(ValueError, match="Invalid model ID: 3420. Model ID cannot be a numeric value without a provider prefix."):
-        LLM(model=3420)
-
-    # Test with string numeric model ID
-    with pytest.raises(ValueError, match="Invalid model ID: 3420. Model ID cannot be a numeric value without a provider prefix."):
-        LLM(model="3420")
-
-    # Test with valid model ID
-    llm = LLM(model="openai/gpt-4")
-    assert llm.model == "openai/gpt-4"
-
-    # Test with valid model ID that contains numbers
-    llm = LLM(model="gpt-3.5-turbo")
-    assert llm.model == "gpt-3.5-turbo"
+@pytest.mark.parametrize(
+    "invalid_model,error_message",
+    [
+        (3420, "Invalid model ID: 3420. Model ID cannot be a numeric value without a provider prefix."),
+        ("3420", "Invalid model ID: 3420. Model ID cannot be a numeric value without a provider prefix."),
+        (3.14, "Invalid model ID: 3.14. Model ID cannot be a numeric value without a provider prefix."),
+    ],
+)
+def test_invalid_numeric_model_ids(invalid_model, error_message):
+    """Test that numeric model IDs are rejected."""
+    with pytest.raises(ValueError, match=error_message):
+        LLM(model=invalid_model)
+
+
+@pytest.mark.parametrize(
+    "valid_model",
+    [
+        "openai/gpt-4",
+        "gpt-3.5-turbo",
+        "anthropic/claude-2",
+    ],
+)
+def test_valid_model_ids(valid_model):
+    """Test that valid model IDs are accepted."""
+    llm = LLM(model=valid_model)
+    assert llm.model == valid_model
+
+
+def test_empty_model_id():
+    """Test that empty model IDs are rejected."""
+    with pytest.raises(ValueError, match="Invalid model ID: ''. Model ID cannot be empty or whitespace."):
+        LLM(model="")
+
+
+def test_whitespace_model_id():
+    """Test that whitespace model IDs are rejected."""
+    with pytest.raises(ValueError, match="Invalid model ID: ' '. Model ID cannot be empty or whitespace."):
+        LLM(model=" ")
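A minimal usage sketch of the new constructor validation, assuming the patched class is importable as crewai.llm.LLM; the model names below are illustrative, and pytest is used only to assert on the raised errors:

import pytest

from crewai.llm import LLM

# Valid IDs pass through unchanged, including names that merely contain digits.
llm = LLM(model="openai/gpt-4")
assert llm.model == "openai/gpt-4"

# Purely numeric IDs (int, float, or digit-only strings) now fail fast,
# before any litellm configuration happens.
with pytest.raises(ValueError, match="numeric value without a provider prefix"):
    LLM(model=3420)

# Empty and whitespace-only IDs are caught by the new first check.
with pytest.raises(ValueError, match="empty or whitespace"):
    LLM(model="   ")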
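And a standalone sketch of the context-window lookup that the model_str guard hardens; the dictionary and constants here are illustrative stand-ins for crewai's LLM_CONTEXT_WINDOW_SIZES, DEFAULT_CONTEXT_WINDOW_SIZE, and CONTEXT_WINDOW_USAGE_RATIO, not their real values:

# Illustrative stand-ins; the real constants live in crewai's codebase.
LLM_CONTEXT_WINDOW_SIZES = {"gpt-4": 8192, "gpt-3.5-turbo": 16385}
DEFAULT_CONTEXT_WINDOW_SIZE = 8192
CONTEXT_WINDOW_USAGE_RATIO = 0.75


def get_context_window_size(model) -> int:
    # Start from the default window, scaled by the usage ratio.
    size = int(DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO)
    # Coerce non-string models to str so .startswith() cannot raise,
    # mirroring the model_str guard added in this patch.
    model_str = str(model) if not isinstance(model, str) else model
    for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
        if model_str.startswith(key):
            size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
    return size


assert get_context_window_size("gpt-4-0613") == int(8192 * 0.75)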