mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 14:09:24 +00:00
Fix structured output leaks in tool-calling loops
This commit is contained in:
@@ -28,6 +28,7 @@ from pydantic import (
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
ValidationError,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic.functional_serializers import PlainSerializer
|
||||
@@ -1691,24 +1692,30 @@ class Agent(BaseAgent):
|
||||
elif response_format:
|
||||
raw_output = str(output) if not isinstance(output, str) else output
|
||||
try:
|
||||
model_schema = generate_model_description(response_format)
|
||||
schema = json.dumps(model_schema, indent=2)
|
||||
instructions = I18N_DEFAULT.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
)
|
||||
|
||||
converter = Converter(
|
||||
llm=cast(BaseLLM, self.llm),
|
||||
text=raw_output,
|
||||
model=response_format,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
conversion_result = converter.to_pydantic()
|
||||
if isinstance(conversion_result, BaseModel):
|
||||
formatted_result = conversion_result
|
||||
except ConverterError:
|
||||
formatted_result = response_format.model_validate_json(raw_output)
|
||||
except ValidationError:
|
||||
pass
|
||||
|
||||
if formatted_result is None:
|
||||
try:
|
||||
model_schema = generate_model_description(response_format)
|
||||
schema = json.dumps(model_schema, indent=2)
|
||||
instructions = I18N_DEFAULT.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
|
||||
converter = Converter(
|
||||
llm=cast(BaseLLM, self.llm),
|
||||
text=raw_output,
|
||||
model=response_format,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
conversion_result = converter.to_pydantic()
|
||||
if isinstance(conversion_result, BaseModel):
|
||||
formatted_result = conversion_result
|
||||
except ConverterError:
|
||||
pass
|
||||
else:
|
||||
raw_output = str(output) if not isinstance(output, str) else output
|
||||
|
||||
|
||||
@@ -350,6 +350,10 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
effective_response_model = (
|
||||
None if self.original_tools else self.response_model
|
||||
)
|
||||
|
||||
answer = get_llm_response(
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
messages=self.messages,
|
||||
@@ -357,11 +361,11 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
printer=PRINTER,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
response_model=effective_response_model,
|
||||
executor_context=self,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
if self.response_model is not None:
|
||||
if effective_response_model is not None:
|
||||
try:
|
||||
if isinstance(answer, BaseModel):
|
||||
output_json = answer.model_dump_json()
|
||||
@@ -502,7 +506,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
available_functions=None,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
response_model=None,
|
||||
executor_context=self,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
@@ -1161,6 +1165,10 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
effective_response_model = (
|
||||
None if self.original_tools else self.response_model
|
||||
)
|
||||
|
||||
answer = await aget_llm_response(
|
||||
llm=cast("BaseLLM", self.llm),
|
||||
messages=self.messages,
|
||||
@@ -1168,12 +1176,12 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
printer=PRINTER,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
response_model=effective_response_model,
|
||||
executor_context=self,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
|
||||
if self.response_model is not None:
|
||||
if effective_response_model is not None:
|
||||
try:
|
||||
if isinstance(answer, BaseModel):
|
||||
output_json = answer.model_dump_json()
|
||||
@@ -1314,7 +1322,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
available_functions=None,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
response_model=None,
|
||||
executor_context=self,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
|
||||
@@ -1224,6 +1224,10 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
try:
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
effective_response_model = (
|
||||
None if self.original_tools else self.response_model
|
||||
)
|
||||
|
||||
answer = get_llm_response(
|
||||
llm=self.llm,
|
||||
messages=list(self.state.messages),
|
||||
@@ -1231,7 +1235,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
printer=PRINTER,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
response_model=effective_response_model,
|
||||
executor_context=self,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
@@ -1319,7 +1323,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
available_functions=None,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
response_model=None,
|
||||
executor_context=self,
|
||||
verbose=self.agent.verbose,
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
@@ -639,29 +640,37 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
formatted_result = agent_finish.output
|
||||
elif active_response_format:
|
||||
try:
|
||||
model_schema = generate_model_description(active_response_format)
|
||||
schema = json.dumps(model_schema, indent=2)
|
||||
instructions = I18N_DEFAULT.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
formatted_result = active_response_format.model_validate_json(
|
||||
str(agent_finish.output)
|
||||
)
|
||||
except ValidationError:
|
||||
pass
|
||||
|
||||
converter = Converter(
|
||||
llm=self.llm,
|
||||
text=agent_finish.output,
|
||||
model=active_response_format,
|
||||
instructions=instructions,
|
||||
)
|
||||
if formatted_result is None:
|
||||
try:
|
||||
model_schema = generate_model_description(active_response_format)
|
||||
schema = json.dumps(model_schema, indent=2)
|
||||
instructions = I18N_DEFAULT.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
|
||||
result = converter.to_pydantic()
|
||||
if isinstance(result, BaseModel):
|
||||
formatted_result = result
|
||||
except ConverterError as e:
|
||||
if self.verbose:
|
||||
PRINTER.print(
|
||||
content=f"Failed to parse output into response format after retries: {e.message}",
|
||||
color="yellow",
|
||||
converter = Converter(
|
||||
llm=self.llm,
|
||||
text=agent_finish.output,
|
||||
model=active_response_format,
|
||||
instructions=instructions,
|
||||
)
|
||||
|
||||
result = converter.to_pydantic()
|
||||
if isinstance(result, BaseModel):
|
||||
formatted_result = result
|
||||
except ConverterError as e:
|
||||
if self.verbose:
|
||||
PRINTER.print(
|
||||
content=f"Failed to parse output into response format after retries: {e.message}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
# Calculate token usage metrics
|
||||
if isinstance(self.llm, BaseLLM):
|
||||
usage_metrics = self.llm.get_token_usage_summary()
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import Any
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.agents.tools_handler import ToolsHandler as _ToolsHandler
|
||||
from crewai.agents.step_executor import StepExecutor
|
||||
@@ -108,6 +109,9 @@ class TestAgentExecutorState:
|
||||
class TestAgentExecutor:
|
||||
"""Test AgentExecutor class."""
|
||||
|
||||
class StructuredResult(BaseModel):
|
||||
value: str
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Create mock dependencies for executor."""
|
||||
@@ -215,6 +219,49 @@ class TestAgentExecutor:
|
||||
|
||||
assert result == "check_iteration"
|
||||
|
||||
def test_call_llm_and_parse_does_not_pass_response_model_with_tools(
|
||||
self, mock_dependencies
|
||||
):
|
||||
"""Structured output should not be requested during ReAct tool loops."""
|
||||
executor = _build_executor(
|
||||
**mock_dependencies,
|
||||
original_tools=[Mock()],
|
||||
response_model=self.StructuredResult,
|
||||
callbacks=[],
|
||||
)
|
||||
executor.state.messages = [{"role": "user", "content": "Use a tool"}]
|
||||
|
||||
with patch(
|
||||
"crewai.experimental.agent_executor.get_llm_response",
|
||||
return_value="Thought: done\nFinal Answer: complete",
|
||||
) as get_llm_response_mock:
|
||||
result = executor.call_llm_and_parse()
|
||||
|
||||
assert result == "parsed"
|
||||
assert get_llm_response_mock.call_args.kwargs["response_model"] is None
|
||||
|
||||
def test_call_llm_native_tools_does_not_pass_response_model_with_tools(
|
||||
self, mock_dependencies
|
||||
):
|
||||
"""Structured output should not be requested during native tool calls."""
|
||||
executor = _build_executor(
|
||||
**mock_dependencies,
|
||||
original_tools=[Mock()],
|
||||
response_model=self.StructuredResult,
|
||||
callbacks=[],
|
||||
)
|
||||
executor._openai_tools = [{"type": "function", "function": {"name": "lookup"}}]
|
||||
executor.state.messages = [{"role": "user", "content": "Use a tool"}]
|
||||
|
||||
with patch(
|
||||
"crewai.experimental.agent_executor.get_llm_response",
|
||||
return_value="complete",
|
||||
) as get_llm_response_mock:
|
||||
result = executor.call_llm_native_tools()
|
||||
|
||||
assert result == "native_finished"
|
||||
assert get_llm_response_mock.call_args.kwargs["response_model"] is None
|
||||
|
||||
def test_finalize_success(self, mock_dependencies):
|
||||
"""Test finalize with valid AgentFinish."""
|
||||
with patch.object(AgentExecutor, "_show_logs") as mock_show_logs:
|
||||
|
||||
Reference in New Issue
Block a user