Address PR feedback: Improve documentation, refactor validation logic, add logging, and expand test coverage

Co-Authored-By: Joe Moura <joao@crewai.com>
Devin AI
2025-04-30 12:02:13 +00:00
parent 9b0fbd24ee
commit d636593359
2 changed files with 83 additions and 7 deletions


@@ -204,6 +204,8 @@ LLM_CONTEXT_WINDOW_SIZES = {
 DEFAULT_CONTEXT_WINDOW_SIZE = 8192
 CONTEXT_WINDOW_USAGE_RATIO = 0.75
+OPENROUTER_PROVIDER = "openrouter"
+
 @contextmanager
 def suppress_warnings():
@@ -273,6 +275,38 @@ class LLM(BaseLLM):
         force_structured_output: bool = False,
         **kwargs,
     ):
+        """Initialize an LLM instance.
+
+        Args:
+            model: The language model to use.
+            timeout: The request timeout in seconds.
+            temperature: The temperature to use for sampling.
+            top_p: The cumulative probability for top-p sampling.
+            n: The number of completions to generate.
+            stop: A list of strings to stop generation when encountered.
+            max_completion_tokens: The maximum number of tokens to generate.
+            max_tokens: Alias for max_completion_tokens.
+            presence_penalty: The presence penalty to use.
+            frequency_penalty: The frequency penalty to use.
+            logit_bias: The logit bias to use.
+            response_format: The format to return the response in.
+            seed: The random seed to use.
+            logprobs: Whether to return log probabilities.
+            top_logprobs: Whether to return the top log probabilities.
+            base_url: The base URL to use.
+            api_base: Alias for base_url.
+            api_version: The API version to use.
+            api_key: The API key to use.
+            callbacks: A list of callbacks to use.
+            reasoning_effort: The reasoning effort to use (e.g., "low", "medium", "high").
+            stream: Whether to stream the response.
+            force_structured_output: When True and using the OpenRouter provider, bypasses
+                response schema validation. Use with caution as it may lead to runtime
+                errors if the model doesn't actually support structured outputs.
+                Only use this if you're certain the model supports the expected format.
+                Defaults to False.
+            **kwargs: Additional parameters to pass to the LLM.
+        """
         self.model = model
         self.timeout = timeout
         self.temperature = temperature
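
As a usage note (not part of the diff): a minimal sketch of how the new flag could be exercised, assuming the top-level crewai LLM export and a placeholder OpenRouter model id. With force_structured_output=True, _validate_call_params() is expected to log a warning rather than raise when litellm does not report response-schema support for the model.

from pydantic import BaseModel

from crewai import LLM


class Answer(BaseModel):
    text: str


llm = LLM(
    model="openrouter/deepseek/deepseek-chat",  # placeholder OpenRouter model id
    response_format=Answer,
    force_structured_output=True,  # the bypass only takes effect for the "openrouter" provider
)
llm._validate_call_params()  # logs a warning instead of raising ValueError
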
@@ -993,16 +1027,34 @@ class LLM(BaseLLM):
- "gemini/gemini-1.5-pro" yields "gemini"
- If no slash is present, "openai" is assumed.
"""
if self.response_format is None:
return
provider = self._get_custom_llm_provider()
if self.response_format is not None and not (
supports_response_schema(
model=self.model,
custom_llm_provider=provider,
) or (provider == "openrouter" and self.force_structured_output)
):
provider_lower = provider.lower() if provider else ""
# Check if we're bypassing validation for OpenRouter
is_openrouter_bypass = (
provider_lower == OPENROUTER_PROVIDER.lower() and self.force_structured_output
)
if is_openrouter_bypass:
logging.warning(
f"Forcing structured output for OpenRouter model {self.model}. "
"Please ensure the model supports the expected response format."
)
# Check if the model supports response schema
is_schema_supported = supports_response_schema(
model=self.model,
custom_llm_provider=provider,
)
if not (is_schema_supported or is_openrouter_bypass):
raise ValueError(
f"The model {self.model} does not support response_format for provider '{provider}'. "
"Please remove response_format or use a supported model."
f"Please remove response_format, use a supported model, or if you're using an "
f"OpenRouter model that supports structured output, set force_structured_output=True."
)
def supports_function_calling(self) -> bool:
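
The prefix rule quoted in the docstring above can be summarised with a rough stand-in (illustrative only; this is not crewai's actual _get_custom_llm_provider):

def derive_provider(model: str) -> str:
    # Illustrative sketch of the documented rule:
    #   "openrouter/deepseek/deepseek-chat" -> "openrouter"
    #   "gemini/gemini-1.5-pro"             -> "gemini"
    #   "gpt-4o" (no slash)                 -> "openai"
    return model.split("/", 1)[0] if "/" in model else "openai"
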


@@ -257,6 +257,10 @@ def test_validate_call_params_no_response_format():
 def test_validate_call_params_openrouter_force_structured_output():
+    """
+    Test that force_structured_output parameter allows bypassing response schema
+    validation for OpenRouter models.
+    """
     class DummyResponse(BaseModel):
         a: int
@@ -282,6 +286,26 @@ def test_validate_call_params_openrouter_force_structured_output():
assert "does not support response_format" in str(excinfo.value)
def test_force_structured_output_bypasses_only_openrouter():
"""
Test that force_structured_output parameter only bypasses validation for
OpenRouter models and not for other providers.
"""
class DummyResponse(BaseModel):
a: int
# Test with non-OpenRouter provider and force_structured_output=True
with patch("crewai.llm.supports_response_schema", return_value=False):
llm = LLM(
model="otherprovider/model-name",
response_format=DummyResponse,
force_structured_output=True
)
with pytest.raises(ValueError) as excinfo:
llm._validate_call_params()
assert "does not support response_format" in str(excinfo.value)
 @pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
 @pytest.mark.parametrize(
     "model",