Fix structured output leaks in tool-calling loops

This commit is contained in:
lorenzejay
2026-05-21 15:33:22 -07:00
parent c3ef622ec6
commit 4b08a8308d
5 changed files with 118 additions and 43 deletions

View File

@@ -28,6 +28,7 @@ from pydantic import (
ConfigDict,
Field,
PrivateAttr,
ValidationError,
model_validator,
)
from pydantic.functional_serializers import PlainSerializer
@@ -1691,24 +1692,30 @@ class Agent(BaseAgent):
elif response_format:
raw_output = str(output) if not isinstance(output, str) else output
try:
model_schema = generate_model_description(response_format)
schema = json.dumps(model_schema, indent=2)
instructions = I18N_DEFAULT.slice("formatted_task_instructions").format(
output_format=schema
)
converter = Converter(
llm=cast(BaseLLM, self.llm),
text=raw_output,
model=response_format,
instructions=instructions,
)
conversion_result = converter.to_pydantic()
if isinstance(conversion_result, BaseModel):
formatted_result = conversion_result
except ConverterError:
formatted_result = response_format.model_validate_json(raw_output)
except ValidationError:
pass
if formatted_result is None:
try:
model_schema = generate_model_description(response_format)
schema = json.dumps(model_schema, indent=2)
instructions = I18N_DEFAULT.slice(
"formatted_task_instructions"
).format(output_format=schema)
converter = Converter(
llm=cast(BaseLLM, self.llm),
text=raw_output,
model=response_format,
instructions=instructions,
)
conversion_result = converter.to_pydantic()
if isinstance(conversion_result, BaseModel):
formatted_result = conversion_result
except ConverterError:
pass
else:
raw_output = str(output) if not isinstance(output, str) else output

View File

@@ -350,6 +350,10 @@ class CrewAgentExecutor(BaseAgentExecutor):
enforce_rpm_limit(self.request_within_rpm_limit)
effective_response_model = (
None if self.original_tools else self.response_model
)
answer = get_llm_response(
llm=cast("BaseLLM", self.llm),
messages=self.messages,
@@ -357,11 +361,11 @@ class CrewAgentExecutor(BaseAgentExecutor):
printer=PRINTER,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
response_model=effective_response_model,
executor_context=self,
verbose=self.agent.verbose,
)
if self.response_model is not None:
if effective_response_model is not None:
try:
if isinstance(answer, BaseModel):
output_json = answer.model_dump_json()
@@ -502,7 +506,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
available_functions=None,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
response_model=None,
executor_context=self,
verbose=self.agent.verbose,
)
@@ -1161,6 +1165,10 @@ class CrewAgentExecutor(BaseAgentExecutor):
enforce_rpm_limit(self.request_within_rpm_limit)
effective_response_model = (
None if self.original_tools else self.response_model
)
answer = await aget_llm_response(
llm=cast("BaseLLM", self.llm),
messages=self.messages,
@@ -1168,12 +1176,12 @@ class CrewAgentExecutor(BaseAgentExecutor):
printer=PRINTER,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
response_model=effective_response_model,
executor_context=self,
verbose=self.agent.verbose,
)
if self.response_model is not None:
if effective_response_model is not None:
try:
if isinstance(answer, BaseModel):
output_json = answer.model_dump_json()
@@ -1314,7 +1322,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
available_functions=None,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
response_model=None,
executor_context=self,
verbose=self.agent.verbose,
)

View File

@@ -1224,6 +1224,10 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
try:
enforce_rpm_limit(self.request_within_rpm_limit)
effective_response_model = (
None if self.original_tools else self.response_model
)
answer = get_llm_response(
llm=self.llm,
messages=list(self.state.messages),
@@ -1231,7 +1235,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
printer=PRINTER,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
response_model=effective_response_model,
executor_context=self,
verbose=self.agent.verbose,
)
@@ -1319,7 +1323,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
available_functions=None,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
response_model=None,
executor_context=self,
verbose=self.agent.verbose,
)

View File

@@ -23,6 +23,7 @@ from pydantic import (
BaseModel,
Field,
PrivateAttr,
ValidationError,
field_validator,
model_validator,
)
@@ -639,29 +640,37 @@ class LiteAgent(FlowTrackable, BaseModel):
formatted_result = agent_finish.output
elif active_response_format:
try:
model_schema = generate_model_description(active_response_format)
schema = json.dumps(model_schema, indent=2)
instructions = I18N_DEFAULT.slice("formatted_task_instructions").format(
output_format=schema
formatted_result = active_response_format.model_validate_json(
str(agent_finish.output)
)
except ValidationError:
pass
converter = Converter(
llm=self.llm,
text=agent_finish.output,
model=active_response_format,
instructions=instructions,
)
if formatted_result is None:
try:
model_schema = generate_model_description(active_response_format)
schema = json.dumps(model_schema, indent=2)
instructions = I18N_DEFAULT.slice(
"formatted_task_instructions"
).format(output_format=schema)
result = converter.to_pydantic()
if isinstance(result, BaseModel):
formatted_result = result
except ConverterError as e:
if self.verbose:
PRINTER.print(
content=f"Failed to parse output into response format after retries: {e.message}",
color="yellow",
converter = Converter(
llm=self.llm,
text=agent_finish.output,
model=active_response_format,
instructions=instructions,
)
result = converter.to_pydantic()
if isinstance(result, BaseModel):
formatted_result = result
except ConverterError as e:
if self.verbose:
PRINTER.print(
content=f"Failed to parse output into response format after retries: {e.message}",
color="yellow",
)
# Calculate token usage metrics
if isinstance(self.llm, BaseLLM):
usage_metrics = self.llm.get_token_usage_summary()

View File

@@ -12,6 +12,7 @@ from typing import Any
from unittest.mock import AsyncMock, Mock, patch
import pytest
from pydantic import BaseModel
from crewai.agents.tools_handler import ToolsHandler as _ToolsHandler
from crewai.agents.step_executor import StepExecutor
@@ -108,6 +109,9 @@ class TestAgentExecutorState:
class TestAgentExecutor:
"""Test AgentExecutor class."""
class StructuredResult(BaseModel):
value: str
@pytest.fixture
def mock_dependencies(self):
"""Create mock dependencies for executor."""
@@ -215,6 +219,49 @@ class TestAgentExecutor:
assert result == "check_iteration"
def test_call_llm_and_parse_does_not_pass_response_model_with_tools(
self, mock_dependencies
):
"""Structured output should not be requested during ReAct tool loops."""
executor = _build_executor(
**mock_dependencies,
original_tools=[Mock()],
response_model=self.StructuredResult,
callbacks=[],
)
executor.state.messages = [{"role": "user", "content": "Use a tool"}]
with patch(
"crewai.experimental.agent_executor.get_llm_response",
return_value="Thought: done\nFinal Answer: complete",
) as get_llm_response_mock:
result = executor.call_llm_and_parse()
assert result == "parsed"
assert get_llm_response_mock.call_args.kwargs["response_model"] is None
def test_call_llm_native_tools_does_not_pass_response_model_with_tools(
self, mock_dependencies
):
"""Structured output should not be requested during native tool calls."""
executor = _build_executor(
**mock_dependencies,
original_tools=[Mock()],
response_model=self.StructuredResult,
callbacks=[],
)
executor._openai_tools = [{"type": "function", "function": {"name": "lookup"}}]
executor.state.messages = [{"role": "user", "content": "Use a tool"}]
with patch(
"crewai.experimental.agent_executor.get_llm_response",
return_value="complete",
) as get_llm_response_mock:
result = executor.call_llm_native_tools()
assert result == "native_finished"
assert get_llm_response_mock.call_args.kwargs["response_model"] is None
def test_finalize_success(self, mock_dependencies):
"""Test finalize with valid AgentFinish."""
with patch.object(AgentExecutor, "_show_logs") as mock_show_logs: