Revamping tool usage

This commit is contained in:
João Moura
2024-02-10 10:36:34 -08:00
parent 58bb181b48
commit 5a102251cf
22 changed files with 1233 additions and 584 deletions

View File

@@ -1,4 +1,3 @@
from .cache.cache_handler import CacheHandler
from .executor import CrewAgentExecutor
from .output_parser import CrewAgentOutputParser
from .tools_handler import ToolsHandler

View File

@@ -1,2 +1 @@
from .cache_handler import CacheHandler
from .cache_hit import CacheHit

View File

@@ -10,9 +10,7 @@ class CacheHandler:
self._cache = {}
def add(self, tool, input, output):
input = input.strip()
self._cache[f"{tool}-{input}"] = output
def read(self, tool, input) -> Optional[str]:
input = input.strip()
return self._cache.get(f"{tool}-{input}")

View File

@@ -1,18 +0,0 @@
from typing import Any
from pydantic import BaseModel, Field
from .cache_handler import CacheHandler
class CacheHit(BaseModel):
"""Cache Hit Object."""
class Config:
arbitrary_types_allowed = True
# Making it Any instead of AgentAction to avoind
# pydantic v1 vs v2 incompatibility, langchain should
# soon be updated to pydantic v2
action: Any = Field(description="Action taken")
cache: CacheHandler = Field(description="Cache Handler for the tool")

View File

@@ -1,30 +0,0 @@
from langchain_core.exceptions import OutputParserException
from crewai.utilities import I18N
class TaskRepeatedUsageException(OutputParserException):
"""Exception raised when a task is used twice in a roll."""
i18n: I18N = I18N()
error: str = "TaskRepeatedUsageException"
message: str
def __init__(self, i18n: I18N, tool: str, tool_input: str, text: str):
self.i18n = i18n
self.text = text
self.tool = tool
self.tool_input = tool_input
self.message = self.i18n.errors("task_repeated_usage").format(
tool=tool, tool_input=tool_input
)
super().__init__(
error=self.error,
observation=self.message,
send_to_llm=True,
llm_output=self.text,
)
def __str__(self):
return self.message

View File

@@ -10,16 +10,19 @@ from langchain_core.exceptions import OutputParserException
from langchain_core.pydantic_v1 import root_validator
from langchain_core.tools import BaseTool
from langchain_core.utils.input import get_color_mapping
from pydantic import InstanceOf
from crewai.agents.cache.cache_hit import CacheHit
from crewai.tools.cache_tools import CacheTools
from crewai.agents.tools_handler import ToolsHandler
from crewai.tools.tool_usage import ToolUsage
from crewai.utilities import I18N
class CrewAgentExecutor(AgentExecutor):
i18n: I18N = I18N()
llm: Any = None
iterations: int = 0
request_within_rpm_limit: Any = None
tools_handler: InstanceOf[ToolsHandler] = None
max_iterations: Optional[int] = 15
force_answer_max_iterations: Optional[int] = None
step_callback: Optional[Any] = None
@@ -32,11 +35,6 @@ class CrewAgentExecutor(AgentExecutor):
def _should_force_answer(self) -> bool:
return True if self.iterations == self.force_answer_max_iterations else False
def _force_answer(self, output: AgentAction):
return AgentStep(
action=output, observation=self.i18n.errors("force_final_answer")
)
def _call(
self,
inputs: Dict[str, str],
@@ -110,16 +108,17 @@ class CrewAgentExecutor(AgentExecutor):
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
if self._should_force_answer():
if isinstance(output, AgentAction) or isinstance(output, AgentFinish):
output = output
elif isinstance(output, CacheHit):
output = output.action
else:
raise ValueError(
f"Unexpected output type from agent: {type(output)}"
)
yield self._force_answer(output)
yield AgentStep(
action=output, observation=self.i18n.errors("force_final_answer")
)
return
except OutputParserException as e:
@@ -160,7 +159,9 @@ class CrewAgentExecutor(AgentExecutor):
)
if self._should_force_answer():
yield self._force_answer(output)
yield AgentStep(
action=output, observation=self.i18n.errors("force_final_answer")
)
return
yield AgentStep(action=output, observation=observation)
@@ -171,17 +172,6 @@ class CrewAgentExecutor(AgentExecutor):
yield output
return
# Override tool usage to use CacheTools
if isinstance(output, CacheHit):
cache = output.cache
action = output.action
tool = CacheTools(cache_handler=cache).tool()
output = action.copy()
output.tool_input = f"tool:{action.tool}|input:{action.tool_input}"
output.tool = tool.name
name_to_tool_map[tool.name] = tool
color_mapping[tool.name] = color_mapping[action.tool]
actions: List[AgentAction]
actions = [output] if isinstance(output, AgentAction) else output
yield from actions
@@ -192,18 +182,13 @@ class CrewAgentExecutor(AgentExecutor):
if agent_action.tool in name_to_tool_map:
tool = name_to_tool_map[agent_action.tool]
return_direct = tool.return_direct
color = color_mapping[agent_action.tool]
color_mapping[agent_action.tool]
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
observation = tool.run(
agent_action.tool_input,
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
observation = ToolUsage(
tools_handler=self.tools_handler, tools=self.tools, llm=self.llm
).use(agent_action.log)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = InvalidTool().run(

View File

@@ -1,79 +0,0 @@
import re
from typing import Union
from langchain.agents.output_parsers import ReActSingleInputOutputParser
from langchain_core.agents import AgentAction, AgentFinish
from crewai.agents.cache import CacheHandler, CacheHit
from crewai.agents.exceptions import TaskRepeatedUsageException
from crewai.agents.tools_handler import ToolsHandler
from crewai.utilities import I18N
FINAL_ANSWER_ACTION = "Final Answer:"
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
"Parsing LLM output produced both a final answer and a parse-able action:"
)
class CrewAgentOutputParser(ReActSingleInputOutputParser):
"""Parses ReAct-style LLM calls that have a single tool input.
Expects output to be in one of two formats.
If the output signals that an action should be taken,
should be in the below format. This will result in an AgentAction
being returned.
```
Thought: agent thought here
Action: search
Action Input: what is the temperature in SF?
```
If the output signals that a final answer should be given,
should be in the below format. This will result in an AgentFinish
being returned.
```
Thought: agent thought here
Final Answer: The temperature is 100 degrees
```
It also prevents tools from being reused in a roll.
"""
class Config:
arbitrary_types_allowed = True
tools_handler: ToolsHandler
cache: CacheHandler
i18n: I18N
def parse(self, text: str) -> Union[AgentAction, AgentFinish, CacheHit]:
regex = (
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
)
if action_match := re.search(regex, text, re.DOTALL):
action = action_match.group(1).strip()
action_input = action_match.group(2)
tool_input = action_input.strip(" ")
tool_input = tool_input.strip('"')
if last_tool_usage := self.tools_handler.last_used_tool:
usage = {
"tool": action,
"input": tool_input,
}
if usage == last_tool_usage:
raise TaskRepeatedUsageException(
text=text,
tool=action,
tool_input=tool_input,
i18n=self.i18n,
)
if self.cache.read(action, tool_input):
action = AgentAction(action, tool_input, text)
return CacheHit(action=action, cache=self.cache)
return super().parse(text)

View File

@@ -1,44 +1,30 @@
from typing import Any, Dict
from langchain.callbacks.base import BaseCallbackHandler
from typing import Any
from ..tools.cache_tools import CacheTools
from ..tools.tool_calling import ToolCalling
from .cache.cache_handler import CacheHandler
class ToolsHandler(BaseCallbackHandler):
class ToolsHandler:
"""Callback handler for tool usage."""
last_used_tool: Dict[str, Any] = {}
last_used_tool: ToolCalling = {}
cache: CacheHandler
def __init__(self, cache: CacheHandler, **kwargs: Any):
def __init__(self, cache: CacheHandler):
"""Initialize the callback handler."""
self.cache = cache
super().__init__(**kwargs)
self.last_used_tool = {}
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
def on_tool_start(self, calling: ToolCalling) -> Any:
"""Run when tool starts running."""
name = serialized.get("name")
if name not in ["invalid_tool", "_Exception"]:
tools_usage = {
"tool": name,
"input": input_str,
}
self.last_used_tool = tools_usage
self.last_used_tool = calling
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
def on_tool_end(self, calling: ToolCalling, output: str) -> Any:
"""Run when tool ends running."""
if (
"is not a valid tool" not in output
and "Invalid or incomplete response" not in output
and "Invalid Format" not in output
):
if self.last_used_tool["tool"] != CacheTools().name:
self.cache.add(
tool=self.last_used_tool["tool"],
input=self.last_used_tool["input"],
output=output,
)
if self.last_used_tool.function_name != CacheTools().name:
self.cache.add(
tool=calling.function_name,
input=calling.arguments,
output=output,
)