mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
Fix agent loop not stopping after max iterations reached
Fixes #3847 The agent execution loop was not stopping after max_iter was reached. After calling handle_max_iterations_exceeded(), the loop continued and made additional LLM calls, overwriting the formatted answer. Changes: - Modified handle_max_iterations_exceeded() to always return AgentFinish instead of AgentAction | AgentFinish, ensuring a final answer is always produced when max iterations are exceeded - Added early return in _invoke_loop() after handle_max_iterations_exceeded() to prevent additional LLM calls - Updated test_agent_custom_max_iterations to expect 2 LLM calls instead of 3 (1 initial + 1 from max_iter handler, no longer 3 due to extra call bug) - Added test_agent_max_iterations_zero to verify no get_llm_response calls when max_iter=0 - Added test_agent_max_iterations_one_stops_after_two_calls to verify exactly 2 LLM calls with max_iter=1 - Fixed type annotations to properly handle AgentAction | AgentFinish | None Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -202,7 +202,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
formatted_answer = None
|
||||
formatted_answer: AgentAction | AgentFinish | None = None
|
||||
while not isinstance(formatted_answer, AgentFinish):
|
||||
try:
|
||||
if has_reached_max_iterations(self.iterations, self.max_iter):
|
||||
@@ -214,6 +214,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
@@ -257,9 +259,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
formatted_answer = self._handle_agent_action(
|
||||
formatted_answer, tool_result
|
||||
)
|
||||
|
||||
self._invoke_step_callback(formatted_answer)
|
||||
self._append_message(formatted_answer.text)
|
||||
|
||||
self._invoke_step_callback(formatted_answer)
|
||||
self._append_message(formatted_answer.text)
|
||||
elif isinstance(formatted_answer, AgentFinish):
|
||||
self._invoke_step_callback(formatted_answer)
|
||||
self._append_message(formatted_answer.text)
|
||||
|
||||
except OutputParserError as e: # noqa: PERF203
|
||||
formatted_answer = handle_output_parser_exception(
|
||||
|
||||
@@ -127,7 +127,7 @@ def handle_max_iterations_exceeded(
|
||||
messages: list[LLMMessage],
|
||||
llm: LLM | BaseLLM,
|
||||
callbacks: list[TokenCalcHandler],
|
||||
) -> AgentAction | AgentFinish:
|
||||
) -> AgentFinish:
|
||||
"""Handles the case when the maximum number of iterations is exceeded. Performs one more LLM call to get the final answer.
|
||||
|
||||
Args:
|
||||
@@ -139,7 +139,7 @@ def handle_max_iterations_exceeded(
|
||||
callbacks: List of callbacks for the LLM call.
|
||||
|
||||
Returns:
|
||||
The final formatted answer after exceeding max iterations.
|
||||
The final formatted answer after exceeding max iterations (always AgentFinish).
|
||||
"""
|
||||
printer.print(
|
||||
content="Maximum iterations reached. Requesting final answer.",
|
||||
@@ -168,8 +168,17 @@ def handle_max_iterations_exceeded(
|
||||
)
|
||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||
|
||||
# Return the formatted answer, regardless of its type
|
||||
return format_answer(answer=answer)
|
||||
# Parse the answer and ensure it's always an AgentFinish
|
||||
parsed_answer = format_answer(answer=answer)
|
||||
|
||||
if isinstance(parsed_answer, AgentAction):
|
||||
return AgentFinish(
|
||||
thought=parsed_answer.thought,
|
||||
output=parsed_answer.text,
|
||||
text=parsed_answer.text,
|
||||
)
|
||||
|
||||
return parsed_answer
|
||||
|
||||
|
||||
def format_message_for_llm(
|
||||
|
||||
@@ -508,7 +508,96 @@ def test_agent_custom_max_iterations():
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
assert call_count > 0
|
||||
assert call_count == 3
|
||||
assert call_count == 2
|
||||
|
||||
|
||||
def test_agent_max_iterations_zero():
|
||||
"""Test that with max_iter=0, get_llm_response is never called but handle_max_iterations_exceeded makes one LLM call."""
|
||||
from unittest.mock import MagicMock
|
||||
from crewai.agents import crew_agent_executor
|
||||
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
max_iter=0,
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
original_call = agent.llm.call
|
||||
llm_call_count = 0
|
||||
|
||||
def counting_llm_call(*args, **kwargs):
|
||||
nonlocal llm_call_count
|
||||
llm_call_count += 1
|
||||
return original_call(*args, **kwargs)
|
||||
|
||||
agent.llm.call = counting_llm_call
|
||||
|
||||
get_llm_response_call_count = 0
|
||||
original_get_llm_response = crew_agent_executor.get_llm_response
|
||||
|
||||
def counting_get_llm_response(*args, **kwargs):
|
||||
nonlocal get_llm_response_call_count
|
||||
get_llm_response_call_count += 1
|
||||
return original_get_llm_response(*args, **kwargs)
|
||||
|
||||
crew_agent_executor.get_llm_response = counting_get_llm_response
|
||||
|
||||
try:
|
||||
task = Task(
|
||||
description="What is 2+2?",
|
||||
expected_output="The answer",
|
||||
)
|
||||
result = agent.execute_task(task=task)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
assert get_llm_response_call_count == 0
|
||||
assert llm_call_count == 1
|
||||
finally:
|
||||
crew_agent_executor.get_llm_response = original_get_llm_response
|
||||
|
||||
|
||||
def test_agent_max_iterations_one_stops_after_two_calls():
|
||||
"""Test that with max_iter=1, exactly 2 LLM calls are made (initial + max_iter handler)."""
|
||||
@tool
|
||||
def dummy_tool() -> str:
|
||||
"""A dummy tool that returns a value."""
|
||||
return "dummy result"
|
||||
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
max_iter=1,
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
original_call = agent.llm.call
|
||||
llm_call_count = 0
|
||||
|
||||
def counting_llm_call(*args, **kwargs):
|
||||
nonlocal llm_call_count
|
||||
llm_call_count += 1
|
||||
return original_call(*args, **kwargs)
|
||||
|
||||
agent.llm.call = counting_llm_call
|
||||
|
||||
task = Task(
|
||||
description="Keep using the dummy_tool repeatedly.",
|
||||
expected_output="The final answer",
|
||||
)
|
||||
result = agent.execute_task(
|
||||
task=task,
|
||||
tools=[dummy_tool],
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
assert llm_call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
|
||||
Reference in New Issue
Block a user