diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index bbf8e35d9..4a932979e 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -5,15 +5,17 @@ import sys
 import threading
 import warnings
 from contextlib import contextmanager
-from typing import Any, Dict, List, Optional, Union, cast
+from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
 
 from dotenv import load_dotenv
+from pydantic import BaseModel
 
 with warnings.catch_warnings():
     warnings.simplefilter("ignore", UserWarning)
     import litellm
     from litellm import Choices, get_supported_openai_params
     from litellm.types.utils import ModelResponse
+    from litellm.utils import supports_response_schema
 
 
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -128,7 +130,7 @@ class LLM:
         presence_penalty: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         logit_bias: Optional[Dict[int, float]] = None,
-        response_format: Optional[Dict[str, Any]] = None,
+        response_format: Optional[Type[BaseModel]] = None,
         seed: Optional[int] = None,
         logprobs: Optional[int] = None,
         top_logprobs: Optional[int] = None,
@@ -211,6 +213,9 @@ class LLM:
             response = llm.call(messages)
             print(response)
         """
+        # Validate parameters before proceeding with the call.
+        self._validate_call_params()
+
         if isinstance(messages, str):
             messages = [{"role": "user", "content": messages}]
 
@@ -309,6 +314,36 @@ class LLM:
             logging.error(f"LiteLLM call failed: {str(e)}")
             raise
 
+    def _get_custom_llm_provider(self) -> str:
+        """
+        Derives the custom_llm_provider from the model string.
+        - For example, if the model is "openrouter/deepseek/deepseek-chat", returns "openrouter".
+        - If the model is "gemini/gemini-1.5-pro", returns "gemini".
+        - If there is no '/', defaults to "openai".
+        """
+        if "/" in self.model:
+            return self.model.split("/")[0]
+        return "openai"
+
+    def _validate_call_params(self) -> None:
+        """
+        Validate parameters before making a call. Currently this only checks if
+        a response_format is provided and whether the model supports it.
+        The custom_llm_provider is dynamically determined from the model:
+        - E.g., "openrouter/deepseek/deepseek-chat" yields "openrouter"
+        - "gemini/gemini-1.5-pro" yields "gemini"
+        - If no slash is present, "openai" is assumed.
+        """
+        provider = self._get_custom_llm_provider()
+        if self.response_format is not None and not supports_response_schema(
+            model=self.model,
+            custom_llm_provider=provider,
+        ):
+            raise ValueError(
+                f"The model {self.model} does not support response_format for provider '{provider}'. "
+                "Please remove response_format or use a supported model."
+            )
+
     def supports_function_calling(self) -> bool:
         try:
             params = get_supported_openai_params(model=self.model)
diff --git a/tests/llm_test.py b/tests/llm_test.py
index 8db8726d0..15c8e7c51 100644
--- a/tests/llm_test.py
+++ b/tests/llm_test.py
@@ -2,6 +2,7 @@ from time import sleep
 from unittest.mock import MagicMock, patch
 
 import pytest
+from pydantic import BaseModel
 
 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.llm import LLM
@@ -202,3 +203,49 @@ def test_llm_passes_additional_params():
 
     # Check the result from llm.call
     assert result == "Test response"
+
+
+def test_get_custom_llm_provider_openrouter():
+    llm = LLM(model="openrouter/deepseek/deepseek-chat")
+    assert llm._get_custom_llm_provider() == "openrouter"
+
+
+def test_get_custom_llm_provider_gemini():
+    llm = LLM(model="gemini/gemini-1.5-pro")
+    assert llm._get_custom_llm_provider() == "gemini"
+
+
+def test_get_custom_llm_provider_openai():
+    llm = LLM(model="gpt-4")
+    assert llm._get_custom_llm_provider() == "openai"
+
+
+def test_validate_call_params_supported():
+    class DummyResponse(BaseModel):
+        a: int
+
+    # Patch supports_response_schema to simulate a supported model.
+    with patch("crewai.llm.supports_response_schema", return_value=True):
+        llm = LLM(
+            model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse
+        )
+        # Should not raise any error.
+        llm._validate_call_params()
+
+
+def test_validate_call_params_not_supported():
+    class DummyResponse(BaseModel):
+        a: int
+
+    # Patch supports_response_schema to simulate an unsupported model.
+    with patch("crewai.llm.supports_response_schema", return_value=False):
+        llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse)
+        with pytest.raises(ValueError) as excinfo:
+            llm._validate_call_params()
+        assert "does not support response_format" in str(excinfo.value)
+
+
+def test_validate_call_params_no_response_format():
+    # When no response_format is provided, no validation error should occur.
+    llm = LLM(model="gemini/gemini-1.5-pro", response_format=None)
+    llm._validate_call_params()
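
Usage sketch (not part of the patch): with this change, response_format takes a Pydantic model class, and LLM.call() runs _validate_call_params() before issuing the request, so an unsupported model/provider combination fails fast with a ValueError instead of an API error. The Outline schema, the "gpt-4o" model name, and the prompt below are illustrative assumptions, not taken from this diff.

from typing import List

from pydantic import BaseModel

from crewai.llm import LLM


class Outline(BaseModel):
    # Illustrative schema; any Pydantic model class can be passed as response_format.
    title: str
    sections: List[str]


# response_format is now a Pydantic model class rather than a dict.
# If litellm's supports_response_schema() reports the derived provider as
# unsupported for this model, LLM.call() raises ValueError before any request
# is sent; otherwise the call proceeds as usual.
llm = LLM(model="gpt-4o", response_format=Outline)
print(llm.call("Draft a short outline about structured LLM outputs."))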