Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-16 12:28:30 +00:00)

Compare commits: devin/1761 ... devin/1746 (3 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 0173a3ceaf |  |
|  | d636593359 |  |
|  | 9b0fbd24ee |  |
@@ -204,6 +204,8 @@ LLM_CONTEXT_WINDOW_SIZES = {
 DEFAULT_CONTEXT_WINDOW_SIZE = 8192
 CONTEXT_WINDOW_USAGE_RATIO = 0.75
 
+OPENROUTER_PROVIDER = "openrouter"
+
 
 @contextmanager
 def suppress_warnings():
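The new module-level constant exists so the validation code further down can compare the detected provider prefix against "openrouter" without repeating the string literal. A minimal standalone illustration of that comparison (not library code; the `split("/")` prefix extraction is an assumption based on the examples given in the `_validate_call_params` docstring below):

```python
OPENROUTER_PROVIDER = "openrouter"

# Per the docstring examples, the provider is the text before the first "/",
# e.g. "openrouter/deepseek/deepseek-chat" -> "openrouter"; no slash -> "openai".
model = "openrouter/deepseek/deepseek-chat"
provider = model.split("/", 1)[0] if "/" in model else "openai"

# Case-insensitive match, mirroring the check added in _validate_call_params.
is_openrouter = bool(provider) and provider.lower() == OPENROUTER_PROVIDER.lower()
print(is_openrouter)  # True
```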
@@ -270,8 +272,41 @@ class LLM(BaseLLM):
         callbacks: List[Any] = [],
         reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
         stream: bool = False,
+        force_structured_output: bool = False,
         **kwargs,
     ):
+        """Initialize an LLM instance.
+
+        Args:
+            model: The language model to use.
+            timeout: The request timeout in seconds.
+            temperature: The temperature to use for sampling.
+            top_p: The cumulative probability for top-p sampling.
+            n: The number of completions to generate.
+            stop: A list of strings to stop generation when encountered.
+            max_completion_tokens: The maximum number of tokens to generate.
+            max_tokens: Alias for max_completion_tokens.
+            presence_penalty: The presence penalty to use.
+            frequency_penalty: The frequency penalty to use.
+            logit_bias: The logit bias to use.
+            response_format: The format to return the response in.
+            seed: The random seed to use.
+            logprobs: Whether to return log probabilities.
+            top_logprobs: Whether to return the top log probabilities.
+            base_url: The base URL to use.
+            api_base: Alias for base_url.
+            api_version: The API version to use.
+            api_key: The API key to use.
+            callbacks: A list of callbacks to use.
+            reasoning_effort: The reasoning effort to use (e.g., "low", "medium", "high").
+            stream: Whether to stream the response.
+            force_structured_output: When True and using OpenRouter provider, bypasses
+                response schema validation. Use with caution as it may lead to runtime
+                errors if the model doesn't actually support structured outputs.
+                Only use this if you're certain the model supports the expected format.
+                Defaults to False.
+            **kwargs: Additional parameters to pass to the LLM.
+        """
         self.model = model
         self.timeout = timeout
         self.temperature = temperature
@@ -296,6 +331,7 @@ class LLM(BaseLLM):
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
         self.stream = stream
+        self.force_structured_output = force_structured_output
 
         litellm.drop_params = True
 
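Taken together, the constructor changes add a single opt-in flag that is stored on the instance. A minimal usage sketch based on the signature above (the model name is taken from the new tests, `Answer` is a placeholder schema, and whether validation would otherwise fail for this model depends on litellm's `supports_response_schema`):

```python
from pydantic import BaseModel

from crewai import LLM


class Answer(BaseModel):
    # Placeholder response schema used only for this sketch.
    text: str


# Opt in to structured output for an OpenRouter-hosted model.
llm = LLM(
    model="openrouter/deepseek/deepseek-chat",
    response_format=Answer,
    force_structured_output=True,  # bypasses the schema check for OpenRouter only
)

# With the flag set, validation warns instead of raising for OpenRouter models.
llm._validate_call_params()
```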
@@ -991,14 +1027,32 @@ class LLM(BaseLLM):
           - "gemini/gemini-1.5-pro" yields "gemini"
           - If no slash is present, "openai" is assumed.
         """
+        if self.response_format is None:
+            return
+
         provider = self._get_custom_llm_provider()
-        if self.response_format is not None and not supports_response_schema(
+
+        # Check if we're bypassing validation for OpenRouter
+        is_openrouter = provider and provider.lower() == OPENROUTER_PROVIDER.lower()
+        is_openrouter_bypass = is_openrouter and self.force_structured_output
+
+        # Check if the model supports response schema
+        is_schema_supported = supports_response_schema(
             model=self.model,
             custom_llm_provider=provider,
-        ):
+        )
+
+        if is_openrouter_bypass:
+            print(
+                f"Warning: Forcing structured output for OpenRouter model {self.model}. "
+                "Please ensure the model supports the expected response format."
+            )
+
+        if not (is_schema_supported or is_openrouter_bypass):
             raise ValueError(
                 f"The model {self.model} does not support response_format for provider '{provider}'. "
-                "Please remove response_format or use a supported model."
+                f"Please remove response_format, use a supported model, or if you're using an "
+                f"OpenRouter model that supports structured output, set force_structured_output=True."
             )
 
     def supports_function_calling(self) -> bool:
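The decision the new validation code makes can be restated in isolation: the error is raised only when the model neither supports response schemas nor qualifies for the OpenRouter bypass. A simplified, standalone restatement of that rule (not the library code):

```python
def should_raise(provider: str, schema_supported: bool, force_structured_output: bool) -> bool:
    """Return True when _validate_call_params would raise, per the diff above."""
    is_openrouter_bypass = provider.lower() == "openrouter" and force_structured_output
    return not (schema_supported or is_openrouter_bypass)


# OpenRouter model without schema support: bypassed only when the flag is set.
assert should_raise("openrouter", schema_supported=False, force_structured_output=True) is False
assert should_raise("openrouter", schema_supported=False, force_structured_output=False) is True

# Any other provider ignores the flag entirely.
assert should_raise("otherprovider", schema_supported=False, force_structured_output=True) is True
```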
@@ -256,6 +256,56 @@ def test_validate_call_params_no_response_format():
     llm._validate_call_params()
 
 
+def test_validate_call_params_openrouter_force_structured_output():
+    """
+    Test that force_structured_output parameter allows bypassing response schema
+    validation for OpenRouter models.
+    """
+    class DummyResponse(BaseModel):
+        a: int
+
+    # Test with OpenRouter and force_structured_output=True
+    llm = LLM(
+        model="openrouter/deepseek/deepseek-chat",
+        response_format=DummyResponse,
+        force_structured_output=True
+    )
+    # Should not raise any error with force_structured_output=True
+    llm._validate_call_params()
+
+    # Test with OpenRouter and force_structured_output=False (default)
+    # Patch supports_response_schema to simulate an unsupported model.
+    with patch("crewai.llm.supports_response_schema", return_value=False):
+        llm = LLM(
+            model="openrouter/deepseek/deepseek-chat",
+            response_format=DummyResponse,
+            force_structured_output=False
+        )
+        with pytest.raises(ValueError) as excinfo:
+            llm._validate_call_params()
+        assert "does not support response_format" in str(excinfo.value)
+
+
+def test_force_structured_output_bypasses_only_openrouter():
+    """
+    Test that force_structured_output parameter only bypasses validation for
+    OpenRouter models and not for other providers.
+    """
+    class DummyResponse(BaseModel):
+        a: int
+
+    # Test with non-OpenRouter provider and force_structured_output=True
+    with patch("crewai.llm.supports_response_schema", return_value=False):
+        llm = LLM(
+            model="otherprovider/model-name",
+            response_format=DummyResponse,
+            force_structured_output=True
+        )
+        with pytest.raises(ValueError) as excinfo:
+            llm._validate_call_params()
+        assert "does not support response_format" in str(excinfo.value)
+
+
 @pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
 @pytest.mark.parametrize(
     "model",