Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-09 08:08:32 +00:00
Fix response_format parameter to support Pydantic BaseModel classes
- Add conversion of Pydantic BaseModel classes to json_schema format in _prepare_completion_params
- Add parsing of JSON responses back into Pydantic models in _handle_non_streaming_response
- Ensure response_model parameter takes precedence over response_format
- Add three comprehensive tests covering Pydantic model conversion, dict passthrough, and precedence
- Fix test fixture decorator issue (removed @pytest.mark.vcr from anthropic_llm fixture)

Fixes #3959

Co-Authored-By: João <joao@crewai.com>
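The tests in the diff below only assert on the shape of the converted response_format, so as a rough illustration of what the _prepare_completion_params conversion might produce — a minimal sketch assuming Pydantic v2's model_json_schema(); the helper name pydantic_to_response_format is hypothetical and not taken from the codebase:

from pydantic import BaseModel

def pydantic_to_response_format(model_cls: type[BaseModel]) -> dict:
    # Hypothetical helper, not crewAI's actual code: the real conversion lives
    # inside _prepare_completion_params. Assumes Pydantic v2's model_json_schema().
    return {
        "type": "json_schema",
        "json_schema": {
            "name": model_cls.__name__,
            "schema": model_cls.model_json_schema(),
        },
    }

class TestResponse(BaseModel):
    answer: str
    confidence: float

# Produces the shape the new tests assert on:
# {"type": "json_schema", "json_schema": {"name": "TestResponse", "schema": {...}}}
assert pydantic_to_response_format(TestResponse)["json_schema"]["name"] == "TestResponse"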
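Likewise, a hedged sketch of the other two behaviours the commit message lists: validating the provider's JSON content back into the requested model (the actual logic lives in _handle_non_streaming_response) and letting a response_model passed to call() take precedence over the response_format set at construction. The helper names below are illustrative only, not crewAI APIs:

from pydantic import BaseModel

class InitResponse(BaseModel):
    init_field: str

class CallResponse(BaseModel):
    call_field: str

def pick_response_model(init_format, call_model):
    # Precedence rule from the commit message: a response_model given at call time
    # overrides the response_format configured on the LLM instance.
    return call_model if call_model is not None else init_format

def parse_structured_response(content: str, model_cls):
    # Validate the provider's JSON text against the requested Pydantic class;
    # fall back to the raw string when no structured output was requested.
    if model_cls is None:
        return content
    return model_cls.model_validate_json(content)

assert pick_response_model(InitResponse, CallResponse) is CallResponse
parsed = parse_structured_response('{"call_field": "value"}', CallResponse)
assert parsed.call_field == "value"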
@@ -255,6 +255,114 @@ def test_validate_call_params_no_response_format():
    llm._validate_call_params()


def test_response_format_pydantic_model_conversion():
    """Test that response_format with Pydantic model is converted to json_schema format."""
    class TestResponse(BaseModel):
        answer: str
        confidence: float

    llm = LLM(model="gpt-4o-mini", response_format=TestResponse, is_litellm=True)

    with patch("litellm.completion") as mocked_completion:
        mock_message = MagicMock()
        mock_message.content = '{"answer": "Paris", "confidence": 0.95}'
        mock_message.tool_calls = []
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = {
            "prompt_tokens": 10,
            "completion_tokens": 10,
            "total_tokens": 20,
        }

        mocked_completion.return_value = mock_response

        result = llm.call("What is the capital of France?")

        mocked_completion.assert_called_once()
        _, kwargs = mocked_completion.call_args

        assert "response_format" in kwargs
        assert isinstance(kwargs["response_format"], dict)
        assert kwargs["response_format"]["type"] == "json_schema"
        assert "json_schema" in kwargs["response_format"]
        assert kwargs["response_format"]["json_schema"]["name"] == "TestResponse"
        assert "schema" in kwargs["response_format"]["json_schema"]

        import json
        result_dict = json.loads(result)
        assert result_dict["answer"] == "Paris"
        assert result_dict["confidence"] == 0.95


def test_response_format_dict_passthrough():
    """Test that response_format with dict is passed through unchanged."""
    response_format_dict = {"type": "json_object"}

    llm = LLM(model="gpt-4o-mini", response_format=response_format_dict, is_litellm=True)

    with patch("litellm.completion") as mocked_completion:
        mock_message = MagicMock()
        mock_message.content = '{"result": "test"}'
        mock_message.tool_calls = []
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = {
            "prompt_tokens": 5,
            "completion_tokens": 5,
            "total_tokens": 10,
        }

        mocked_completion.return_value = mock_response

        llm.call("Test message")

        mocked_completion.assert_called_once()
        _, kwargs = mocked_completion.call_args

        assert kwargs["response_format"] == response_format_dict


def test_response_model_overrides_response_format():
    """Test that response_model passed to call() overrides response_format from init."""
    class InitResponse(BaseModel):
        init_field: str

    class CallResponse(BaseModel):
        call_field: str

    llm = LLM(model="gpt-4o-mini", response_format=InitResponse, is_litellm=True)

    with patch("litellm.completion") as mocked_completion:
        mock_message = MagicMock()
        mock_message.content = '{"init_field": "value"}'
        mock_message.tool_calls = []
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = {
            "prompt_tokens": 5,
            "completion_tokens": 5,
            "total_tokens": 10,
        }

        mocked_completion.return_value = mock_response

        result = llm.call("Test message")

        mocked_completion.assert_called_once()
        _, kwargs = mocked_completion.call_args

        assert "response_format" in kwargs
        assert kwargs["response_format"]["type"] == "json_schema"
        assert kwargs["response_format"]["json_schema"]["name"] == "InitResponse"


@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
@pytest.mark.parametrize(
    "model",

@@ -411,7 +519,6 @@ def test_context_window_exceeded_error_handling():
    assert "8192 tokens" in str(excinfo.value)


@pytest.mark.vcr(filter_headers=["authorization"])
@pytest.fixture
def anthropic_llm():
    """Fixture providing an Anthropic LLM instance."""