Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-09 08:08:32 +00:00
Address PR feedback: Improve documentation, refactor validation logic, add logging, and expand test coverage
Co-Authored-By: Joe Moura <joao@crewai.com>
@@ -204,6 +204,8 @@ LLM_CONTEXT_WINDOW_SIZES = {
 DEFAULT_CONTEXT_WINDOW_SIZE = 8192
 CONTEXT_WINDOW_USAGE_RATIO = 0.75
 
+OPENROUTER_PROVIDER = "openrouter"
+
 
 @contextmanager
 def suppress_warnings():
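For orientation, these two constants are typically combined multiplicatively when sizing prompts. A quick worked sketch follows; the multiplicative use of CONTEXT_WINDOW_USAGE_RATIO is an assumption inferred from the constant names, not something shown in this hunk:

# Illustrative only: how the constants above are presumably combined.
DEFAULT_CONTEXT_WINDOW_SIZE = 8192
CONTEXT_WINDOW_USAGE_RATIO = 0.75

# Use 75% of the window, leaving headroom for responses and overhead.
usable_window = int(DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO)
print(usable_window)  # 6144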
@@ -273,6 +275,38 @@ class LLM(BaseLLM):
+        force_structured_output: bool = False,
         **kwargs,
     ):
+        """Initialize an LLM instance.
+
+        Args:
+            model: The language model to use.
+            timeout: The request timeout in seconds.
+            temperature: The temperature to use for sampling.
+            top_p: The cumulative probability for top-p sampling.
+            n: The number of completions to generate.
+            stop: A list of strings to stop generation when encountered.
+            max_completion_tokens: The maximum number of tokens to generate.
+            max_tokens: Alias for max_completion_tokens.
+            presence_penalty: The presence penalty to use.
+            frequency_penalty: The frequency penalty to use.
+            logit_bias: The logit bias to use.
+            response_format: The format to return the response in.
+            seed: The random seed to use.
+            logprobs: Whether to return log probabilities.
+            top_logprobs: Whether to return the top log probabilities.
+            base_url: The base URL to use.
+            api_base: Alias for base_url.
+            api_version: The API version to use.
+            api_key: The API key to use.
+            callbacks: A list of callbacks to use.
+            reasoning_effort: The reasoning effort to use (e.g., "low", "medium", "high").
+            stream: Whether to stream the response.
+            force_structured_output: When True and using OpenRouter provider, bypasses
+                response schema validation. Use with caution as it may lead to runtime
+                errors if the model doesn't actually support structured outputs.
+                Only use this if you're certain the model supports the expected format.
+                Defaults to False.
+            **kwargs: Additional parameters to pass to the LLM.
+        """
         self.model = model
         self.timeout = timeout
         self.temperature = temperature
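A minimal usage sketch of the new parameter, assuming `LLM` is importable from the top-level `crewai` package; the response schema and model id below are hypothetical stand-ins:

from pydantic import BaseModel
from crewai import LLM

class MyOutput(BaseModel):  # hypothetical response schema
    answer: str

# For an OpenRouter model that supports structured output but is not
# recognized by supports_response_schema(), skip the validation check.
llm = LLM(
    model="openrouter/some-vendor/some-model",  # hypothetical model id
    response_format=MyOutput,
    force_structured_output=True,
)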
@@ -993,16 +1027,34 @@ class LLM(BaseLLM):
+        - "gemini/gemini-1.5-pro" yields "gemini"
+        - If no slash is present, "openai" is assumed.
+        """
         if self.response_format is None:
             return
 
         provider = self._get_custom_llm_provider()
-        if self.response_format is not None and not (
-            supports_response_schema(
-                model=self.model,
-                custom_llm_provider=provider,
-            ) or (provider == "openrouter" and self.force_structured_output)
-        ):
+        provider_lower = provider.lower() if provider else ""
+
+        # Check if we're bypassing validation for OpenRouter
+        is_openrouter_bypass = (
+            provider_lower == OPENROUTER_PROVIDER.lower() and self.force_structured_output
+        )
+
+        if is_openrouter_bypass:
+            logging.warning(
+                f"Forcing structured output for OpenRouter model {self.model}. "
+                "Please ensure the model supports the expected response format."
+            )
+
+        # Check if the model supports response schema
+        is_schema_supported = supports_response_schema(
+            model=self.model,
+            custom_llm_provider=provider,
+        )
+
+        if not (is_schema_supported or is_openrouter_bypass):
             raise ValueError(
                 f"The model {self.model} does not support response_format for provider '{provider}'. "
-                "Please remove response_format or use a supported model."
+                f"Please remove response_format, use a supported model, or if you're using an "
+                f"OpenRouter model that supports structured output, set force_structured_output=True."
             )
 
     def supports_function_calling(self) -> bool:
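The refactor reduces the check to a small boolean expression with three outcomes. Below is a pure-Python distillation; the `should_raise` helper is illustrative only and not part of the diff, though the names mirror it:

OPENROUTER_PROVIDER = "openrouter"

def should_raise(schema_supported: bool, provider: str | None,
                 force_structured_output: bool) -> bool:
    """Mirror of the new validation condition, for illustration only."""
    provider_lower = provider.lower() if provider else ""
    is_openrouter_bypass = (
        provider_lower == OPENROUTER_PROVIDER and force_structured_output
    )
    return not (schema_supported or is_openrouter_bypass)

assert should_raise(True, "gemini", False) is False      # schema supported: passes
assert should_raise(False, "openrouter", True) is False  # bypass: warns, passes
assert should_raise(False, "gemini", True) is True       # no bypass: raises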
@@ -257,6 +257,10 @@ def test_validate_call_params_no_response_format():
 
 
 def test_validate_call_params_openrouter_force_structured_output():
+    """
+    Test that force_structured_output parameter allows bypassing response schema
+    validation for OpenRouter models.
+    """
     class DummyResponse(BaseModel):
         a: int
 
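The rendered hunk is cut off after the DummyResponse definition. Judging from the assertion that appears as context in the next hunk, the remainder of this test plausibly exercises both the bypass and non-bypass paths. A reconstructed sketch, not taken verbatim from the diff (imports such as patch, pytest, and LLM are assumed from the test module):

    # Reconstructed continuation; the patch target and model id are
    # inferred from the sibling test in the next hunk, not from this one.
    with patch("crewai.llm.supports_response_schema", return_value=False):
        llm = LLM(
            model="openrouter/some-vendor/some-model",
            response_format=DummyResponse,
            force_structured_output=True,
        )
        llm._validate_call_params()  # bypass: should not raise

        llm_strict = LLM(
            model="openrouter/some-vendor/some-model",
            response_format=DummyResponse,
            force_structured_output=False,
        )
        with pytest.raises(ValueError) as excinfo:
            llm_strict._validate_call_params()
        assert "does not support response_format" in str(excinfo.value)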
@@ -282,6 +286,26 @@ def test_validate_call_params_openrouter_force_structured_output():
         assert "does not support response_format" in str(excinfo.value)
 
 
+def test_force_structured_output_bypasses_only_openrouter():
+    """
+    Test that force_structured_output parameter only bypasses validation for
+    OpenRouter models and not for other providers.
+    """
+    class DummyResponse(BaseModel):
+        a: int
+
+    # Test with non-OpenRouter provider and force_structured_output=True
+    with patch("crewai.llm.supports_response_schema", return_value=False):
+        llm = LLM(
+            model="otherprovider/model-name",
+            response_format=DummyResponse,
+            force_structured_output=True
+        )
+        with pytest.raises(ValueError) as excinfo:
+            llm._validate_call_params()
+        assert "does not support response_format" in str(excinfo.value)
+
+
 @pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
 @pytest.mark.parametrize(
     "model",