diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index a80a9afd9..d14f6710c 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -204,6 +204,8 @@ LLM_CONTEXT_WINDOW_SIZES = {
 DEFAULT_CONTEXT_WINDOW_SIZE = 8192
 CONTEXT_WINDOW_USAGE_RATIO = 0.75
 
+OPENROUTER_PROVIDER = "openrouter"
+
 
 @contextmanager
 def suppress_warnings():
@@ -273,6 +275,38 @@ class LLM(BaseLLM):
         force_structured_output: bool = False,
         **kwargs,
     ):
+        """Initialize an LLM instance.
+
+        Args:
+            model: The language model to use.
+            timeout: The request timeout in seconds.
+            temperature: The temperature to use for sampling.
+            top_p: The cumulative probability for top-p sampling.
+            n: The number of completions to generate.
+            stop: A list of strings that stop generation when encountered.
+            max_completion_tokens: The maximum number of tokens to generate.
+            max_tokens: Alias for max_completion_tokens.
+            presence_penalty: The presence penalty to use.
+            frequency_penalty: The frequency penalty to use.
+            logit_bias: The logit bias to use.
+            response_format: The format to return the response in.
+            seed: The random seed to use.
+            logprobs: Whether to return log probabilities.
+            top_logprobs: The number of top log probabilities to return.
+            base_url: The base URL to use.
+            api_base: Alias for base_url.
+            api_version: The API version to use.
+            api_key: The API key to use.
+            callbacks: A list of callbacks to use.
+            reasoning_effort: The reasoning effort to use (e.g., "low", "medium", "high").
+            stream: Whether to stream the response.
+            force_structured_output: When True and using the OpenRouter provider,
+                bypasses response schema validation. Use with caution, as it may
+                lead to runtime errors if the model doesn't actually support
+                structured outputs. Only use this if you're certain the model
+                supports the expected format. Defaults to False.
+            **kwargs: Additional parameters to pass to the LLM.
+        """
         self.model = model
         self.timeout = timeout
         self.temperature = temperature
@@ -993,16 +1027,34 @@ class LLM(BaseLLM):
         - "gemini/gemini-1.5-pro" yields "gemini"
         - If no slash is present, "openai" is assumed.
         """
+        if self.response_format is None:
+            return
+
         provider = self._get_custom_llm_provider()
-        if self.response_format is not None and not (
-            supports_response_schema(
-                model=self.model,
-                custom_llm_provider=provider,
-            ) or (provider == "openrouter" and self.force_structured_output)
-        ):
+        provider_lower = provider.lower() if provider else ""
+
+        # Bypass validation only for OpenRouter, and only when explicitly requested
+        is_openrouter_bypass = (
+            provider_lower == OPENROUTER_PROVIDER and self.force_structured_output
+        )
+
+        if is_openrouter_bypass:
+            logging.warning(
+                f"Forcing structured output for OpenRouter model {self.model}. "
+                "Please ensure the model supports the expected response format."
+            )
+
+        # Check whether the model supports a response schema
+        is_schema_supported = supports_response_schema(
+            model=self.model,
+            custom_llm_provider=provider,
+        )
+
+        if not (is_schema_supported or is_openrouter_bypass):
             raise ValueError(
                 f"The model {self.model} does not support response_format for provider '{provider}'. "
-                "Please remove response_format or use a supported model."
+                "Please remove response_format, use a supported model, or, if you're using an "
+                "OpenRouter model that supports structured output, set force_structured_output=True."
             )
 
     def supports_function_calling(self) -> bool:
diff --git a/tests/llm_test.py b/tests/llm_test.py
index adea8da58..d9d98933d 100644
--- a/tests/llm_test.py
+++ b/tests/llm_test.py
@@ -257,6 +257,10 @@ def test_validate_call_params_no_response_format():
 
 
 def test_validate_call_params_openrouter_force_structured_output():
+    """
+    Test that the force_structured_output parameter allows bypassing response
+    schema validation for OpenRouter models.
+    """
     class DummyResponse(BaseModel):
         a: int
 
@@ -282,6 +286,26 @@ def test_validate_call_params_openrouter_force_structured_output():
     assert "does not support response_format" in str(excinfo.value)
 
 
+def test_force_structured_output_bypasses_only_openrouter():
+    """
+    Test that the force_structured_output parameter bypasses validation only
+    for OpenRouter models, not for other providers.
+    """
+    class DummyResponse(BaseModel):
+        a: int
+
+    # A non-OpenRouter provider must still fail, even with force_structured_output=True
+    with patch("crewai.llm.supports_response_schema", return_value=False):
+        llm = LLM(
+            model="otherprovider/model-name",
+            response_format=DummyResponse,
+            force_structured_output=True,
+        )
+        with pytest.raises(ValueError) as excinfo:
+            llm._validate_call_params()
+        assert "does not support response_format" in str(excinfo.value)
+
+
 @pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
 @pytest.mark.parametrize(
     "model",
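
For context, a minimal sketch of how the new flag would be used from the caller's side. The OpenRouter model slug and the response model below are illustrative placeholders, not taken from this PR:

    from pydantic import BaseModel

    from crewai import LLM


    class Answer(BaseModel):
        text: str
        confidence: float


    # Hypothetical OpenRouter slug; substitute a model you know supports
    # structured outputs. force_structured_output=True skips the local
    # supports_response_schema() check, so an unsupported model will fail
    # at request time instead of at construction time.
    llm = LLM(
        model="openrouter/some-provider/some-model",
        response_format=Answer,
        force_structured_output=True,
    )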