diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py
index b0cf42091..1a81fff28 100644
--- a/lib/crewai/src/crewai/llm.py
+++ b/lib/crewai/src/crewai/llm.py
@@ -589,12 +589,14 @@ class LLM(BaseLLM):
         self,
         messages: str | list[LLMMessage],
         tools: list[dict[str, BaseTool]] | None = None,
+        response_model: type[BaseModel] | None = None,
     ) -> dict[str, Any]:
         """Prepare parameters for the completion call.

         Args:
             messages: Input messages for the LLM
             tools: Optional list of tool schemas
+            response_model: Optional response model that overrides self.response_format

         Returns:
             Dict[str, Any]: Parameters for the completion call
@@ -604,7 +606,25 @@ class LLM(BaseLLM):
             messages = [{"role": "user", "content": messages}]

         formatted_messages = self._format_messages_for_provider(messages)
-        # --- 2) Prepare the parameters for the completion call
+        # --- 2) Handle response_format conversion for Pydantic models
+        # If response_model is passed to call(), it takes precedence over self.response_format
+        response_format_param = None
+        if response_model is None and self.response_format is not None:
+            if isinstance(self.response_format, type) and issubclass(
+                self.response_format, BaseModel
+            ):
+                # Convert Pydantic model to json_schema format for LiteLLM
+                response_format_param = {
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": self.response_format.__name__,
+                        "schema": self.response_format.model_json_schema(),
+                    },
+                }
+            else:
+                response_format_param = self.response_format
+
+        # --- 3) Prepare the parameters for the completion call
         params = {
             "model": self.model,
             "messages": formatted_messages,
@@ -617,7 +637,7 @@ class LLM(BaseLLM):
             "presence_penalty": self.presence_penalty,
             "frequency_penalty": self.frequency_penalty,
             "logit_bias": self.logit_bias,
-            "response_format": self.response_format,
+            "response_format": response_format_param,
             "seed": self.seed,
             "logprobs": self.logprobs,
             "top_logprobs": self.top_logprobs,
@@ -1115,8 +1135,32 @@ class LLM(BaseLLM):
         # --- 4) Check for tool calls
         tool_calls = getattr(response_message, "tool_calls", [])

-        # --- 5) If no tool calls or no available functions, return the text response directly as long as there is a text response
+        # --- 5) If no tool calls or no available functions, handle text response
         if (not tool_calls or not available_functions) and text_response:
+            # If self.response_format is a Pydantic class, parse the response
+            if (
+                self.response_format is not None
+                and isinstance(self.response_format, type)
+                and issubclass(self.response_format, BaseModel)
+            ):
+                try:
+                    parsed_model = self.response_format.model_validate_json(
+                        text_response
+                    )
+                    structured_response = parsed_model.model_dump_json()
+                    self._handle_emit_call_events(
+                        response=structured_response,
+                        call_type=LLMCallType.LLM_CALL,
+                        from_task=from_task,
+                        from_agent=from_agent,
+                        messages=params["messages"],
+                    )
+                    return structured_response
+                except Exception as e:
+                    logging.warning(
+                        f"Failed to parse response into {self.response_format.__name__}: {e}"
+                    )
+
             self._handle_emit_call_events(
                 response=text_response,
                 call_type=LLMCallType.LLM_CALL,
@@ -1302,7 +1346,7 @@ class LLM(BaseLLM):
             self.set_callbacks(callbacks)
         try:
             # --- 6) Prepare parameters for the completion call
-            params = self._prepare_completion_params(messages, tools)
+            params = self._prepare_completion_params(messages, tools, response_model)
             # --- 7) Make the completion call and handle response
             if self.stream:
                 return self._handle_streaming_response(
diff --git a/lib/crewai/tests/test_llm.py b/lib/crewai/tests/test_llm.py
index ad3dd9963..7d0c04706 100644
--- a/lib/crewai/tests/test_llm.py
+++ b/lib/crewai/tests/test_llm.py
@@ -255,6 +255,114 @@ def test_validate_call_params_no_response_format():
     llm._validate_call_params()


+def test_response_format_pydantic_model_conversion():
+    """Test that response_format with Pydantic model is converted to json_schema format."""
+    class TestResponse(BaseModel):
+        answer: str
+        confidence: float
+
+    llm = LLM(model="gpt-4o-mini", response_format=TestResponse, is_litellm=True)
+
+    with patch("litellm.completion") as mocked_completion:
+        mock_message = MagicMock()
+        mock_message.content = '{"answer": "Paris", "confidence": 0.95}'
+        mock_message.tool_calls = []
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+        mock_response.usage = {
+            "prompt_tokens": 10,
+            "completion_tokens": 10,
+            "total_tokens": 20,
+        }
+
+        mocked_completion.return_value = mock_response
+
+        result = llm.call("What is the capital of France?")
+
+        mocked_completion.assert_called_once()
+        _, kwargs = mocked_completion.call_args
+
+        assert "response_format" in kwargs
+        assert isinstance(kwargs["response_format"], dict)
+        assert kwargs["response_format"]["type"] == "json_schema"
+        assert "json_schema" in kwargs["response_format"]
+        assert kwargs["response_format"]["json_schema"]["name"] == "TestResponse"
+        assert "schema" in kwargs["response_format"]["json_schema"]
+
+        import json
+        result_dict = json.loads(result)
+        assert result_dict["answer"] == "Paris"
+        assert result_dict["confidence"] == 0.95
+
+
+def test_response_format_dict_passthrough():
+    """Test that response_format with dict is passed through unchanged."""
+    response_format_dict = {"type": "json_object"}
+
+    llm = LLM(model="gpt-4o-mini", response_format=response_format_dict, is_litellm=True)
+
+    with patch("litellm.completion") as mocked_completion:
+        mock_message = MagicMock()
+        mock_message.content = '{"result": "test"}'
+        mock_message.tool_calls = []
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+        mock_response.usage = {
+            "prompt_tokens": 5,
+            "completion_tokens": 5,
+            "total_tokens": 10,
+        }
+
+        mocked_completion.return_value = mock_response
+
+        llm.call("Test message")
+
+        mocked_completion.assert_called_once()
+        _, kwargs = mocked_completion.call_args
+
+        assert kwargs["response_format"] == response_format_dict
+
+
+def test_response_model_overrides_response_format():
+    """Test that response_model passed to call() overrides response_format from init."""
+    class InitResponse(BaseModel):
+        init_field: str
+
+    class CallResponse(BaseModel):
+        call_field: str
+
+    llm = LLM(model="gpt-4o-mini", response_format=InitResponse, is_litellm=True)
+
+    with patch("litellm.completion") as mocked_completion:
+        mock_message = MagicMock()
+        mock_message.content = '{"init_field": "value"}'
+        mock_message.tool_calls = []
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+        mock_response.usage = {
+            "prompt_tokens": 5,
+            "completion_tokens": 5,
+            "total_tokens": 10,
+        }
+
+        mocked_completion.return_value = mock_response
+
+        result = llm.call("Test message")
+
+        mocked_completion.assert_called_once()
+        _, kwargs = mocked_completion.call_args
+
+        assert "response_format" in kwargs
+        assert kwargs["response_format"]["type"] == "json_schema"
+        assert kwargs["response_format"]["json_schema"]["name"] == "InitResponse"
+
+
 @pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
 @pytest.mark.parametrize(
     "model",
@@ -411,7 +519,6 @@ def test_context_window_exceeded_error_handling():
     assert "8192 tokens" in str(excinfo.value)


-@pytest.mark.vcr(filter_headers=["authorization"])
 @pytest.fixture
 def anthropic_llm():
     """Fixture providing an Anthropic LLM instance."""