mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
Refactor Azure OpenAI response_format validation based on PR feedback
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -127,6 +127,9 @@ def suppress_warnings():
|
||||
|
||||
|
||||
class LLM:
|
||||
# Azure OpenAI models that support JSON mode
|
||||
AZURE_JSON_SUPPORTED_MODELS = ["gpt-35-turbo", "gpt-4-turbo", "gpt-4o"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
@@ -446,21 +449,40 @@ class LLM:
|
||||
if "/" in self.model:
|
||||
return self.model.split("/")[0]
|
||||
return "openai"
|
||||
|
||||
def _is_azure_json_supported_model(self) -> bool:
|
||||
"""
|
||||
Check if the current model is an Azure OpenAI model that supports JSON mode.
|
||||
|
||||
Returns:
|
||||
bool: True if the model is an Azure OpenAI model that supports JSON mode, False otherwise.
|
||||
"""
|
||||
return any(prefix in self.model for prefix in self.AZURE_JSON_SUPPORTED_MODELS)
|
||||
|
||||
def _validate_call_params(self) -> None:
|
||||
"""
|
||||
Validate parameters before making a call. Currently this only checks if
|
||||
a response_format is provided and whether the model supports it.
|
||||
|
||||
Special handling for Azure OpenAI models that support JSON mode:
|
||||
- gpt-35-turbo
|
||||
- gpt-4-turbo
|
||||
- gpt-4o
|
||||
|
||||
The custom_llm_provider is dynamically determined from the model:
|
||||
- E.g., "openrouter/deepseek/deepseek-chat" yields "openrouter"
|
||||
- "gemini/gemini-1.5-pro" yields "gemini"
|
||||
- If no slash is present, "openai" is assumed.
|
||||
|
||||
Raises:
|
||||
ValueError: If response_format is used with unsupported models
|
||||
"""
|
||||
provider = self._get_custom_llm_provider()
|
||||
|
||||
# Special case for Azure OpenAI models that support JSON mode
|
||||
if (provider == "azure" and self.response_format is not None and
|
||||
any(model_prefix in self.model for model_prefix in ["gpt-35-turbo", "gpt-4-turbo", "gpt-4o"])):
|
||||
if (provider == "azure" and
|
||||
self.response_format is not None and
|
||||
self._is_azure_json_supported_model()):
|
||||
return # Skip validation for Azure OpenAI models that support JSON mode
|
||||
|
||||
if self.response_format is not None and not supports_response_schema(
|
||||
|
||||
@@ -253,32 +253,43 @@ def test_validate_call_params_no_response_format():
|
||||
llm._validate_call_params()
|
||||
|
||||
|
||||
def test_validate_call_params_azure_openai_supported():
|
||||
"""Test that Azure OpenAI models that should support JSON mode can use response_format."""
|
||||
# Test with gpt-4o
|
||||
llm = LLM(model="azure/gpt-4o", response_format={"type": "json_object"})
|
||||
# Should not raise any error
|
||||
llm._validate_call_params()
|
||||
class TestAzureOpenAIResponseFormat:
|
||||
@pytest.mark.parametrize("model_name", [
    "azure/gpt-4o",
    "azure/gpt-35-turbo",
    "azure/gpt-4-turbo",
])
def test_supported_models(self, model_name):
    """Test that Azure OpenAI models that should support JSON mode can use response_format."""
    json_mode = {"type": "json_object"}
    llm = LLM(model=model_name, response_format=json_mode)
    # Validation must complete without raising for JSON-capable Azure models.
    llm._validate_call_params()
|
||||
|
||||
# Test with gpt-35-turbo
|
||||
llm = LLM(model="azure/gpt-35-turbo", response_format={"type": "json_object"})
|
||||
# Should not raise any error
|
||||
llm._validate_call_params()
|
||||
def test_unsupported_model(self):
    """Azure models without JSON-mode support must reject response_format."""
    # Force the capability probe to report "unsupported".
    with patch("crewai.llm.supports_response_schema", return_value=False):
        llm = LLM(
            model="azure/text-davinci-003",
            response_format={"type": "json_object"},
        )
        with pytest.raises(ValueError) as exc_info:
            llm._validate_call_params()
        message = str(exc_info.value)
        assert "does not support response_format" in message
|
||||
|
||||
# Test with gpt-4-turbo
|
||||
llm = LLM(model="azure/gpt-4-turbo", response_format={"type": "json_object"})
|
||||
# Should not raise any error
|
||||
llm._validate_call_params()
|
||||
|
||||
|
||||
def test_validate_call_params_azure_openai_unsupported():
|
||||
"""Test that Azure OpenAI models that don't support JSON mode cannot use response_format."""
|
||||
# Test with a model that doesn't support JSON mode
|
||||
with patch("crewai.llm.supports_response_schema", return_value=False):
|
||||
llm = LLM(model="azure/text-davinci-003", response_format={"type": "json_object"})
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
llm._validate_call_params()
|
||||
assert "does not support response_format" in str(excinfo.value)
|
||||
def test_validate_call_params_azure_invalid_response_format(self):
    """Test that Azure OpenAI models validate the response_format type.

    Placeholder: the current implementation does not inspect the
    response_format payload itself, so a well-formed dict passes.
    Update this test if payload validation is ever added.
    """
    azure_llm = LLM(
        model="azure/gpt-4-turbo",
        response_format={"type": "json_object"},
    )
    # Expected to succeed silently.
    azure_llm._validate_call_params()
|
||||
|
||||
def test_validate_call_params_azure_none_provider(self):
    """Test that non-Azure models with Azure model names don't skip validation."""
    # The Azure JSON-mode shortcut must only fire when the provider is
    # "azure"; here the provider is forced to "openai" so validation runs.
    with patch("crewai.llm.LLM._get_custom_llm_provider", return_value="openai"), \
         patch("crewai.llm.supports_response_schema", return_value=False):
        llm = LLM(model="gpt-4-turbo", response_format={"type": "json_object"})
        with pytest.raises(ValueError) as exc_info:
            llm._validate_call_params()
        assert "does not support response_format" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
|
||||
Reference in New Issue
Block a user