mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Revamp max iteration Logic (#111)
This now will allow to add a max_inter option to agents while also making sure to force the agent to give it's best final answer before running out of it's max_inter.
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
||||
import time
|
||||
from textwrap import dedent
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.agents.agent import ExceptionTool
|
||||
@@ -6,13 +8,85 @@ from langchain.agents.tools import InvalidTool
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain_core.agents import AgentAction, AgentFinish, AgentStep
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.input import get_color_mapping
|
||||
|
||||
from ..tools.cache_tools import CacheTools
|
||||
from .cache.cache_hit import CacheHit
|
||||
|
||||
|
||||
class CrewAgentExecutor(AgentExecutor):
|
||||
iterations: int = 0
|
||||
max_iterations: Optional[int] = 15
|
||||
force_answer_max_iterations: Optional[int] = None
|
||||
|
||||
@root_validator()
|
||||
def set_force_answer_max_iterations(cls, values: Dict) -> Dict:
|
||||
values["force_answer_max_iterations"] = values["max_iterations"] - 2
|
||||
return values
|
||||
|
||||
def _should_force_answer(self) -> bool:
|
||||
return True if self.iterations == self.force_answer_max_iterations else False
|
||||
|
||||
def _force_answer(self, output: AgentAction):
|
||||
return AgentStep(
|
||||
action=output,
|
||||
observation=dedent(
|
||||
"""\
|
||||
I've used too many tools for this task.
|
||||
I'm going to give you my absolute BEST Final answer now and
|
||||
not use any more tools."""
|
||||
),
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run text through and get agent response."""
|
||||
# Construct a mapping of tool name to tool for easy lookup
|
||||
name_to_tool_map = {tool.name: tool for tool in self.tools}
|
||||
# We construct a mapping from each tool to a color, used for logging.
|
||||
color_mapping = get_color_mapping(
|
||||
[tool.name for tool in self.tools], excluded_colors=["green", "red"]
|
||||
)
|
||||
intermediate_steps: List[Tuple[AgentAction, str]] = []
|
||||
# Let's start tracking the number of iterations and time elapsed
|
||||
self.iterations = 0
|
||||
time_elapsed = 0.0
|
||||
start_time = time.time()
|
||||
# We now enter the agent loop (until it returns something).
|
||||
while self._should_continue(self.iterations, time_elapsed):
|
||||
next_step_output = self._take_next_step(
|
||||
name_to_tool_map,
|
||||
color_mapping,
|
||||
inputs,
|
||||
intermediate_steps,
|
||||
run_manager=run_manager,
|
||||
)
|
||||
if isinstance(next_step_output, AgentFinish):
|
||||
return self._return(
|
||||
next_step_output, intermediate_steps, run_manager=run_manager
|
||||
)
|
||||
|
||||
intermediate_steps.extend(next_step_output)
|
||||
if len(next_step_output) == 1:
|
||||
next_step_action = next_step_output[0]
|
||||
# See if tool should return directly
|
||||
tool_return = self._get_tool_return(next_step_action)
|
||||
if tool_return is not None:
|
||||
return self._return(
|
||||
tool_return, intermediate_steps, run_manager=run_manager
|
||||
)
|
||||
self.iterations += 1
|
||||
time_elapsed = time.time() - start_time
|
||||
output = self.agent.return_stopped_response(
|
||||
self.early_stopping_method, intermediate_steps, **inputs
|
||||
)
|
||||
return self._return(output, intermediate_steps, run_manager=run_manager)
|
||||
|
||||
def _iter_next_step(
|
||||
self,
|
||||
name_to_tool_map: Dict[str, BaseTool],
|
||||
@@ -34,6 +108,14 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**inputs,
|
||||
)
|
||||
if self._should_force_answer():
|
||||
if isinstance(output, AgentAction):
|
||||
output = output
|
||||
else:
|
||||
output = output.action
|
||||
yield self._force_answer(output)
|
||||
return
|
||||
|
||||
except OutputParserException as e:
|
||||
if isinstance(self.handle_parsing_errors, bool):
|
||||
raise_error = not self.handle_parsing_errors
|
||||
@@ -70,6 +152,11 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
|
||||
if self._should_force_answer():
|
||||
yield self._force_answer(output)
|
||||
return
|
||||
|
||||
yield AgentStep(action=output, observation=observation)
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user