mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
Fix agent loop not stopping after max iterations reached
Fixes #3847 The agent execution loop was not stopping after max_iter was reached. After calling handle_max_iterations_exceeded(), the loop continued and made additional LLM calls, overwriting the formatted answer. Changes: - Modified handle_max_iterations_exceeded() to always return AgentFinish instead of AgentAction | AgentFinish, ensuring a final answer is always produced when max iterations are exceeded - Added early return in _invoke_loop() after handle_max_iterations_exceeded() to prevent additional LLM calls - Updated test_agent_custom_max_iterations to expect 2 LLM calls instead of 3 (1 initial + 1 from max_iter handler, no longer 3 due to extra call bug) - Added test_agent_max_iterations_zero to verify no get_llm_response calls when max_iter=0 - Added test_agent_max_iterations_one_stops_after_two_calls to verify exactly 2 LLM calls with max_iter=1 - Fixed type annotations to properly handle AgentAction | AgentFinish | None Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -202,7 +202,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
Returns:
|
Returns:
|
||||||
Final answer from the agent.
|
Final answer from the agent.
|
||||||
"""
|
"""
|
||||||
formatted_answer = None
|
formatted_answer: AgentAction | AgentFinish | None = None
|
||||||
while not isinstance(formatted_answer, AgentFinish):
|
while not isinstance(formatted_answer, AgentFinish):
|
||||||
try:
|
try:
|
||||||
if has_reached_max_iterations(self.iterations, self.max_iter):
|
if has_reached_max_iterations(self.iterations, self.max_iter):
|
||||||
@@ -214,6 +214,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
llm=self.llm,
|
llm=self.llm,
|
||||||
callbacks=self.callbacks,
|
callbacks=self.callbacks,
|
||||||
)
|
)
|
||||||
|
self._show_logs(formatted_answer)
|
||||||
|
return formatted_answer
|
||||||
|
|
||||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||||
|
|
||||||
@@ -257,9 +259,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
formatted_answer = self._handle_agent_action(
|
formatted_answer = self._handle_agent_action(
|
||||||
formatted_answer, tool_result
|
formatted_answer, tool_result
|
||||||
)
|
)
|
||||||
|
|
||||||
self._invoke_step_callback(formatted_answer)
|
self._invoke_step_callback(formatted_answer)
|
||||||
self._append_message(formatted_answer.text)
|
self._append_message(formatted_answer.text)
|
||||||
|
elif isinstance(formatted_answer, AgentFinish):
|
||||||
|
self._invoke_step_callback(formatted_answer)
|
||||||
|
self._append_message(formatted_answer.text)
|
||||||
|
|
||||||
except OutputParserError as e: # noqa: PERF203
|
except OutputParserError as e: # noqa: PERF203
|
||||||
formatted_answer = handle_output_parser_exception(
|
formatted_answer = handle_output_parser_exception(
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ def handle_max_iterations_exceeded(
|
|||||||
messages: list[LLMMessage],
|
messages: list[LLMMessage],
|
||||||
llm: LLM | BaseLLM,
|
llm: LLM | BaseLLM,
|
||||||
callbacks: list[TokenCalcHandler],
|
callbacks: list[TokenCalcHandler],
|
||||||
) -> AgentAction | AgentFinish:
|
) -> AgentFinish:
|
||||||
"""Handles the case when the maximum number of iterations is exceeded. Performs one more LLM call to get the final answer.
|
"""Handles the case when the maximum number of iterations is exceeded. Performs one more LLM call to get the final answer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -139,7 +139,7 @@ def handle_max_iterations_exceeded(
|
|||||||
callbacks: List of callbacks for the LLM call.
|
callbacks: List of callbacks for the LLM call.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The final formatted answer after exceeding max iterations.
|
The final formatted answer after exceeding max iterations (always AgentFinish).
|
||||||
"""
|
"""
|
||||||
printer.print(
|
printer.print(
|
||||||
content="Maximum iterations reached. Requesting final answer.",
|
content="Maximum iterations reached. Requesting final answer.",
|
||||||
@@ -168,8 +168,17 @@ def handle_max_iterations_exceeded(
|
|||||||
)
|
)
|
||||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||||
|
|
||||||
# Return the formatted answer, regardless of its type
|
# Parse the answer and ensure it's always an AgentFinish
|
||||||
return format_answer(answer=answer)
|
parsed_answer = format_answer(answer=answer)
|
||||||
|
|
||||||
|
if isinstance(parsed_answer, AgentAction):
|
||||||
|
return AgentFinish(
|
||||||
|
thought=parsed_answer.thought,
|
||||||
|
output=parsed_answer.text,
|
||||||
|
text=parsed_answer.text,
|
||||||
|
)
|
||||||
|
|
||||||
|
return parsed_answer
|
||||||
|
|
||||||
|
|
||||||
def format_message_for_llm(
|
def format_message_for_llm(
|
||||||
|
|||||||
@@ -508,7 +508,96 @@ def test_agent_custom_max_iterations():
|
|||||||
assert isinstance(result, str)
|
assert isinstance(result, str)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
assert call_count > 0
|
assert call_count > 0
|
||||||
assert call_count == 3
|
assert call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_max_iterations_zero():
|
||||||
|
"""Test that with max_iter=0, get_llm_response is never called but handle_max_iterations_exceeded makes one LLM call."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from crewai.agents import crew_agent_executor
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="test role",
|
||||||
|
goal="test goal",
|
||||||
|
backstory="test backstory",
|
||||||
|
max_iter=0,
|
||||||
|
allow_delegation=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
original_call = agent.llm.call
|
||||||
|
llm_call_count = 0
|
||||||
|
|
||||||
|
def counting_llm_call(*args, **kwargs):
|
||||||
|
nonlocal llm_call_count
|
||||||
|
llm_call_count += 1
|
||||||
|
return original_call(*args, **kwargs)
|
||||||
|
|
||||||
|
agent.llm.call = counting_llm_call
|
||||||
|
|
||||||
|
get_llm_response_call_count = 0
|
||||||
|
original_get_llm_response = crew_agent_executor.get_llm_response
|
||||||
|
|
||||||
|
def counting_get_llm_response(*args, **kwargs):
|
||||||
|
nonlocal get_llm_response_call_count
|
||||||
|
get_llm_response_call_count += 1
|
||||||
|
return original_get_llm_response(*args, **kwargs)
|
||||||
|
|
||||||
|
crew_agent_executor.get_llm_response = counting_get_llm_response
|
||||||
|
|
||||||
|
try:
|
||||||
|
task = Task(
|
||||||
|
description="What is 2+2?",
|
||||||
|
expected_output="The answer",
|
||||||
|
)
|
||||||
|
result = agent.execute_task(task=task)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert len(result) > 0
|
||||||
|
assert get_llm_response_call_count == 0
|
||||||
|
assert llm_call_count == 1
|
||||||
|
finally:
|
||||||
|
crew_agent_executor.get_llm_response = original_get_llm_response
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_max_iterations_one_stops_after_two_calls():
|
||||||
|
"""Test that with max_iter=1, exactly 2 LLM calls are made (initial + max_iter handler)."""
|
||||||
|
@tool
|
||||||
|
def dummy_tool() -> str:
|
||||||
|
"""A dummy tool that returns a value."""
|
||||||
|
return "dummy result"
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="test role",
|
||||||
|
goal="test goal",
|
||||||
|
backstory="test backstory",
|
||||||
|
max_iter=1,
|
||||||
|
allow_delegation=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
original_call = agent.llm.call
|
||||||
|
llm_call_count = 0
|
||||||
|
|
||||||
|
def counting_llm_call(*args, **kwargs):
|
||||||
|
nonlocal llm_call_count
|
||||||
|
llm_call_count += 1
|
||||||
|
return original_call(*args, **kwargs)
|
||||||
|
|
||||||
|
agent.llm.call = counting_llm_call
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Keep using the dummy_tool repeatedly.",
|
||||||
|
expected_output="The final answer",
|
||||||
|
)
|
||||||
|
result = agent.execute_task(
|
||||||
|
task=task,
|
||||||
|
tools=[dummy_tool],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert len(result) > 0
|
||||||
|
assert llm_call_count == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
|
|||||||
Reference in New Issue
Block a user