diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index 0c8a46214..bfc6f67e4 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -457,6 +457,12 @@ class LLM:
         - If no slash is present, "openai" is assumed.
         """
         provider = self._get_custom_llm_provider()
+
+        # Special case for Azure OpenAI models that support JSON mode
+        if (provider == "azure" and self.response_format is not None and
+                any(model_prefix in self.model for model_prefix in ["gpt-35-turbo", "gpt-4-turbo", "gpt-4o"])):
+            return  # Skip validation for Azure OpenAI models that support JSON mode
+
         if self.response_format is not None and not supports_response_schema(
             model=self.model,
             custom_llm_provider=provider,
diff --git a/tests/llm_test.py b/tests/llm_test.py
index 61aa1aced..1ee3e4385 100644
--- a/tests/llm_test.py
+++ b/tests/llm_test.py
@@ -253,6 +253,34 @@ def test_validate_call_params_no_response_format():
     llm._validate_call_params()
 
 
+def test_validate_call_params_azure_openai_supported():
+    """Test that Azure OpenAI models that should support JSON mode can use response_format."""
+    # Test with gpt-4o
+    llm = LLM(model="azure/gpt-4o", response_format={"type": "json_object"})
+    # Should not raise any error
+    llm._validate_call_params()
+
+    # Test with gpt-35-turbo
+    llm = LLM(model="azure/gpt-35-turbo", response_format={"type": "json_object"})
+    # Should not raise any error
+    llm._validate_call_params()
+
+    # Test with gpt-4-turbo
+    llm = LLM(model="azure/gpt-4-turbo", response_format={"type": "json_object"})
+    # Should not raise any error
+    llm._validate_call_params()
+
+
+def test_validate_call_params_azure_openai_unsupported():
+    """Test that Azure OpenAI models that don't support JSON mode cannot use response_format."""
+    # Test with a model that doesn't support JSON mode
+    with patch("crewai.llm.supports_response_schema", return_value=False):
+        llm = LLM(model="azure/text-davinci-003", response_format={"type": "json_object"})
+        with pytest.raises(ValueError) as excinfo:
+            llm._validate_call_params()
+        assert "does not support response_format" in str(excinfo.value)
+
+
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_o3_mini_reasoning_effort_high():
     llm = LLM(