Revamp max iteration Logic (#111)

This now will allow to add a max_inter option to agents while also making sure to force the agent to give it's best final answer before running out of it's max_inter.
This commit is contained in:
João Moura
2024-01-11 12:32:54 -03:00
committed by GitHub
parent 8cc51d5e9e
commit ea7759b322
6 changed files with 1035 additions and 17 deletions

View File

@@ -1,4 +1,6 @@
from typing import Dict, Iterator, List, Optional, Tuple, Union
import time
from textwrap import dedent
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from langchain.agents import AgentExecutor
from langchain.agents.agent import ExceptionTool
@@ -6,13 +8,85 @@ from langchain.agents.tools import InvalidTool
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain_core.agents import AgentAction, AgentFinish, AgentStep
from langchain_core.exceptions import OutputParserException
from langchain_core.pydantic_v1 import root_validator
from langchain_core.tools import BaseTool
from langchain_core.utils.input import get_color_mapping
from ..tools.cache_tools import CacheTools
from .cache.cache_hit import CacheHit
class CrewAgentExecutor(AgentExecutor):
iterations: int = 0
max_iterations: Optional[int] = 15
force_answer_max_iterations: Optional[int] = None
@root_validator()
def set_force_answer_max_iterations(cls, values: Dict) -> Dict:
values["force_answer_max_iterations"] = values["max_iterations"] - 2
return values
def _should_force_answer(self) -> bool:
return True if self.iterations == self.force_answer_max_iterations else False
def _force_answer(self, output: AgentAction):
return AgentStep(
action=output,
observation=dedent(
"""\
I've used too many tools for this task.
I'm going to give you my absolute BEST Final answer now and
not use any more tools."""
),
)
def _call(
self,
inputs: Dict[str, str],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run text through and get agent response."""
# Construct a mapping of tool name to tool for easy lookup
name_to_tool_map = {tool.name: tool for tool in self.tools}
# We construct a mapping from each tool to a color, used for logging.
color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green", "red"]
)
intermediate_steps: List[Tuple[AgentAction, str]] = []
# Let's start tracking the number of iterations and time elapsed
self.iterations = 0
time_elapsed = 0.0
start_time = time.time()
# We now enter the agent loop (until it returns something).
while self._should_continue(self.iterations, time_elapsed):
next_step_output = self._take_next_step(
name_to_tool_map,
color_mapping,
inputs,
intermediate_steps,
run_manager=run_manager,
)
if isinstance(next_step_output, AgentFinish):
return self._return(
next_step_output, intermediate_steps, run_manager=run_manager
)
intermediate_steps.extend(next_step_output)
if len(next_step_output) == 1:
next_step_action = next_step_output[0]
# See if tool should return directly
tool_return = self._get_tool_return(next_step_action)
if tool_return is not None:
return self._return(
tool_return, intermediate_steps, run_manager=run_manager
)
self.iterations += 1
time_elapsed = time.time() - start_time
output = self.agent.return_stopped_response(
self.early_stopping_method, intermediate_steps, **inputs
)
return self._return(output, intermediate_steps, run_manager=run_manager)
def _iter_next_step(
self,
name_to_tool_map: Dict[str, BaseTool],
@@ -34,6 +108,14 @@ class CrewAgentExecutor(AgentExecutor):
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
if self._should_force_answer():
if isinstance(output, AgentAction):
output = output
else:
output = output.action
yield self._force_answer(output)
return
except OutputParserException as e:
if isinstance(self.handle_parsing_errors, bool):
raise_error = not self.handle_parsing_errors
@@ -70,6 +152,11 @@ class CrewAgentExecutor(AgentExecutor):
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
if self._should_force_answer():
yield self._force_answer(output)
return
yield AgentStep(action=output, observation=observation)
return