mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Fix response_format parameter to support Pydantic BaseModel classes
- Add conversion of Pydantic BaseModel classes to json_schema format in _prepare_completion_params - Add parsing of JSON responses back into Pydantic models in _handle_non_streaming_response - Ensure response_model parameter takes precedence over response_format - Add three comprehensive tests covering Pydantic model conversion, dict passthrough, and precedence - Fix test fixture decorator issue (removed @pytest.mark.vcr from anthropic_llm fixture) Fixes #3959 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -589,12 +589,14 @@ class LLM(BaseLLM):
|
|||||||
self,
|
self,
|
||||||
messages: str | list[LLMMessage],
|
messages: str | list[LLMMessage],
|
||||||
tools: list[dict[str, BaseTool]] | None = None,
|
tools: list[dict[str, BaseTool]] | None = None,
|
||||||
|
response_model: type[BaseModel] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Prepare parameters for the completion call.
|
"""Prepare parameters for the completion call.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: Input messages for the LLM
|
messages: Input messages for the LLM
|
||||||
tools: Optional list of tool schemas
|
tools: Optional list of tool schemas
|
||||||
|
response_model: Optional response model that overrides self.response_format
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: Parameters for the completion call
|
Dict[str, Any]: Parameters for the completion call
|
||||||
@@ -604,7 +606,25 @@ class LLM(BaseLLM):
|
|||||||
messages = [{"role": "user", "content": messages}]
|
messages = [{"role": "user", "content": messages}]
|
||||||
formatted_messages = self._format_messages_for_provider(messages)
|
formatted_messages = self._format_messages_for_provider(messages)
|
||||||
|
|
||||||
# --- 2) Prepare the parameters for the completion call
|
# --- 2) Handle response_format conversion for Pydantic models
|
||||||
|
# If response_model is passed to call(), it takes precedence over self.response_format
|
||||||
|
response_format_param = None
|
||||||
|
if response_model is None and self.response_format is not None:
|
||||||
|
if isinstance(self.response_format, type) and issubclass(
|
||||||
|
self.response_format, BaseModel
|
||||||
|
):
|
||||||
|
# Convert Pydantic model to json_schema format for LiteLLM
|
||||||
|
response_format_param = {
|
||||||
|
"type": "json_schema",
|
||||||
|
"json_schema": {
|
||||||
|
"name": self.response_format.__name__,
|
||||||
|
"schema": self.response_format.model_json_schema(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
response_format_param = self.response_format
|
||||||
|
|
||||||
|
# --- 3) Prepare the parameters for the completion call
|
||||||
params = {
|
params = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": formatted_messages,
|
"messages": formatted_messages,
|
||||||
@@ -617,7 +637,7 @@ class LLM(BaseLLM):
|
|||||||
"presence_penalty": self.presence_penalty,
|
"presence_penalty": self.presence_penalty,
|
||||||
"frequency_penalty": self.frequency_penalty,
|
"frequency_penalty": self.frequency_penalty,
|
||||||
"logit_bias": self.logit_bias,
|
"logit_bias": self.logit_bias,
|
||||||
"response_format": self.response_format,
|
"response_format": response_format_param,
|
||||||
"seed": self.seed,
|
"seed": self.seed,
|
||||||
"logprobs": self.logprobs,
|
"logprobs": self.logprobs,
|
||||||
"top_logprobs": self.top_logprobs,
|
"top_logprobs": self.top_logprobs,
|
||||||
@@ -1115,8 +1135,32 @@ class LLM(BaseLLM):
|
|||||||
# --- 4) Check for tool calls
|
# --- 4) Check for tool calls
|
||||||
tool_calls = getattr(response_message, "tool_calls", [])
|
tool_calls = getattr(response_message, "tool_calls", [])
|
||||||
|
|
||||||
# --- 5) If no tool calls or no available functions, return the text response directly as long as there is a text response
|
# --- 5) If no tool calls or no available functions, handle text response
|
||||||
if (not tool_calls or not available_functions) and text_response:
|
if (not tool_calls or not available_functions) and text_response:
|
||||||
|
# If self.response_format is a Pydantic class, parse the response
|
||||||
|
if (
|
||||||
|
self.response_format is not None
|
||||||
|
and isinstance(self.response_format, type)
|
||||||
|
and issubclass(self.response_format, BaseModel)
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
parsed_model = self.response_format.model_validate_json(
|
||||||
|
text_response
|
||||||
|
)
|
||||||
|
structured_response = parsed_model.model_dump_json()
|
||||||
|
self._handle_emit_call_events(
|
||||||
|
response=structured_response,
|
||||||
|
call_type=LLMCallType.LLM_CALL,
|
||||||
|
from_task=from_task,
|
||||||
|
from_agent=from_agent,
|
||||||
|
messages=params["messages"],
|
||||||
|
)
|
||||||
|
return structured_response
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(
|
||||||
|
f"Failed to parse response into {self.response_format.__name__}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
self._handle_emit_call_events(
|
self._handle_emit_call_events(
|
||||||
response=text_response,
|
response=text_response,
|
||||||
call_type=LLMCallType.LLM_CALL,
|
call_type=LLMCallType.LLM_CALL,
|
||||||
@@ -1302,7 +1346,7 @@ class LLM(BaseLLM):
|
|||||||
self.set_callbacks(callbacks)
|
self.set_callbacks(callbacks)
|
||||||
try:
|
try:
|
||||||
# --- 6) Prepare parameters for the completion call
|
# --- 6) Prepare parameters for the completion call
|
||||||
params = self._prepare_completion_params(messages, tools)
|
params = self._prepare_completion_params(messages, tools, response_model)
|
||||||
# --- 7) Make the completion call and handle response
|
# --- 7) Make the completion call and handle response
|
||||||
if self.stream:
|
if self.stream:
|
||||||
return self._handle_streaming_response(
|
return self._handle_streaming_response(
|
||||||
|
|||||||
@@ -255,6 +255,114 @@ def test_validate_call_params_no_response_format():
|
|||||||
llm._validate_call_params()
|
llm._validate_call_params()
|
||||||
|
|
||||||
|
|
||||||
|
def test_response_format_pydantic_model_conversion():
|
||||||
|
"""Test that response_format with Pydantic model is converted to json_schema format."""
|
||||||
|
class TestResponse(BaseModel):
|
||||||
|
answer: str
|
||||||
|
confidence: float
|
||||||
|
|
||||||
|
llm = LLM(model="gpt-4o-mini", response_format=TestResponse, is_litellm=True)
|
||||||
|
|
||||||
|
with patch("litellm.completion") as mocked_completion:
|
||||||
|
mock_message = MagicMock()
|
||||||
|
mock_message.content = '{"answer": "Paris", "confidence": 0.95}'
|
||||||
|
mock_message.tool_calls = []
|
||||||
|
mock_choice = MagicMock()
|
||||||
|
mock_choice.message = mock_message
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [mock_choice]
|
||||||
|
mock_response.usage = {
|
||||||
|
"prompt_tokens": 10,
|
||||||
|
"completion_tokens": 10,
|
||||||
|
"total_tokens": 20,
|
||||||
|
}
|
||||||
|
|
||||||
|
mocked_completion.return_value = mock_response
|
||||||
|
|
||||||
|
result = llm.call("What is the capital of France?")
|
||||||
|
|
||||||
|
mocked_completion.assert_called_once()
|
||||||
|
_, kwargs = mocked_completion.call_args
|
||||||
|
|
||||||
|
assert "response_format" in kwargs
|
||||||
|
assert isinstance(kwargs["response_format"], dict)
|
||||||
|
assert kwargs["response_format"]["type"] == "json_schema"
|
||||||
|
assert "json_schema" in kwargs["response_format"]
|
||||||
|
assert kwargs["response_format"]["json_schema"]["name"] == "TestResponse"
|
||||||
|
assert "schema" in kwargs["response_format"]["json_schema"]
|
||||||
|
|
||||||
|
import json
|
||||||
|
result_dict = json.loads(result)
|
||||||
|
assert result_dict["answer"] == "Paris"
|
||||||
|
assert result_dict["confidence"] == 0.95
|
||||||
|
|
||||||
|
|
||||||
|
def test_response_format_dict_passthrough():
|
||||||
|
"""Test that response_format with dict is passed through unchanged."""
|
||||||
|
response_format_dict = {"type": "json_object"}
|
||||||
|
|
||||||
|
llm = LLM(model="gpt-4o-mini", response_format=response_format_dict, is_litellm=True)
|
||||||
|
|
||||||
|
with patch("litellm.completion") as mocked_completion:
|
||||||
|
mock_message = MagicMock()
|
||||||
|
mock_message.content = '{"result": "test"}'
|
||||||
|
mock_message.tool_calls = []
|
||||||
|
mock_choice = MagicMock()
|
||||||
|
mock_choice.message = mock_message
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [mock_choice]
|
||||||
|
mock_response.usage = {
|
||||||
|
"prompt_tokens": 5,
|
||||||
|
"completion_tokens": 5,
|
||||||
|
"total_tokens": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
mocked_completion.return_value = mock_response
|
||||||
|
|
||||||
|
llm.call("Test message")
|
||||||
|
|
||||||
|
mocked_completion.assert_called_once()
|
||||||
|
_, kwargs = mocked_completion.call_args
|
||||||
|
|
||||||
|
assert kwargs["response_format"] == response_format_dict
|
||||||
|
|
||||||
|
|
||||||
|
def test_response_model_overrides_response_format():
|
||||||
|
"""Test that response_model passed to call() overrides response_format from init."""
|
||||||
|
class InitResponse(BaseModel):
|
||||||
|
init_field: str
|
||||||
|
|
||||||
|
class CallResponse(BaseModel):
|
||||||
|
call_field: str
|
||||||
|
|
||||||
|
llm = LLM(model="gpt-4o-mini", response_format=InitResponse, is_litellm=True)
|
||||||
|
|
||||||
|
with patch("litellm.completion") as mocked_completion:
|
||||||
|
mock_message = MagicMock()
|
||||||
|
mock_message.content = '{"init_field": "value"}'
|
||||||
|
mock_message.tool_calls = []
|
||||||
|
mock_choice = MagicMock()
|
||||||
|
mock_choice.message = mock_message
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.choices = [mock_choice]
|
||||||
|
mock_response.usage = {
|
||||||
|
"prompt_tokens": 5,
|
||||||
|
"completion_tokens": 5,
|
||||||
|
"total_tokens": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
mocked_completion.return_value = mock_response
|
||||||
|
|
||||||
|
result = llm.call("Test message")
|
||||||
|
|
||||||
|
mocked_completion.assert_called_once()
|
||||||
|
_, kwargs = mocked_completion.call_args
|
||||||
|
|
||||||
|
assert "response_format" in kwargs
|
||||||
|
assert kwargs["response_format"]["type"] == "json_schema"
|
||||||
|
assert kwargs["response_format"]["json_schema"]["name"] == "InitResponse"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
|
@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model",
|
"model",
|
||||||
@@ -411,7 +519,6 @@ def test_context_window_exceeded_error_handling():
|
|||||||
assert "8192 tokens" in str(excinfo.value)
|
assert "8192 tokens" in str(excinfo.value)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def anthropic_llm():
|
def anthropic_llm():
|
||||||
"""Fixture providing an Anthropic LLM instance."""
|
"""Fixture providing an Anthropic LLM instance."""
|
||||||
|
|||||||
Reference in New Issue
Block a user