Adding tool caching a loop execution prevention. (#25)

* Adding tool caching a loop execution prevention.

This adds some guardrails, to both prevent the same tool to be used
consecutively and also caching tool's results across the entire crew
so it cuts down execution time and eventual LLM calls.

This plays a huge role for smaller opensource models that usually fall
into those behaviors patterns.

It also includes some smaller improvements around the tool prompt and
agent tools, all with the same intention of guiding models into
better conform with agent instructions.
This commit is contained in:
João Moura
2023-12-29 22:35:23 -03:00
committed by GitHub
parent 5cc230263c
commit af9e749edb
14 changed files with 3046 additions and 54 deletions

View File

@@ -1,14 +1,14 @@
"""Generic agent."""
from typing import Any, List, Optional
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import ReActSingleInputOutputParser
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationSummaryMemory
from langchain.tools.render import render_text_description
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
from langchain_core.runnables.config import RunnableConfig
from pydantic import BaseModel, Field, InstanceOf, model_validator
from .agents import CacheHandler, CrewAgentOutputParser, ToolsHandler
from .prompts import Prompts
@@ -29,9 +29,9 @@ class Agent(BaseModel):
allow_delegation: Whether the agent is allowed to delegate tasks to other agents.
"""
agent_executor: Optional[InstanceOf[AgentExecutor]] = Field(
default=None, description="An instance of the AgentExecutor class."
)
class Config:
arbitrary_types_allowed = True
role: str = Field(description="Role of the agent")
goal: str = Field(description="Objective of the agent")
backstory: str = Field(description="Backstory of the agent")
@@ -54,15 +54,59 @@ class Agent(BaseModel):
tools: List[Any] = Field(
default_factory=list, description="Tools at agents disposal"
)
_task_calls: List[Any] = PrivateAttr()
agent_executor: Optional[InstanceOf[AgentExecutor]] = Field(
default=None, description="An instance of the AgentExecutor class."
)
tools_handler: Optional[InstanceOf[ToolsHandler]] = Field(
default=None, description="An instance of the ToolsHandler class."
)
cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
default=CacheHandler(), description="An instance of the CacheHandler class."
)
@model_validator(mode="after")
def check_agent_executor(self) -> "Agent":
if not self.agent_executor:
self.agent_executor = self._create_agent_executor()
self.set_cache_handler(self.cache_handler)
return self
def _create_agent_executor(self) -> AgentExecutor:
def execute_task(
self, task: str, context: str = None, tools: List[Any] = None
) -> str:
"""Execute a task with the agent.
Args:
task: Task to execute.
context: Context to execute the task in.
tools: Tools to use for the task.
Returns:
Output of the agent
"""
if context:
task = "\n".join(
[task, "\nThis is the context you are working with:", context]
)
tools = tools or self.tools
self.agent_executor.tools = tools
return self.agent_executor.invoke(
{
"input": task,
"tool_names": self.__tools_names(tools),
"tools": render_text_description(tools),
},
RunnableConfig(callbacks=[self.tools_handler]),
)["output"]
def set_cache_handler(self, cache_handler) -> None:
print(f"cache_handler: {cache_handler}")
self.cache_handler = cache_handler
self.tools_handler = ToolsHandler(cache=self.cache_handler)
self.__create_agent_executor()
def __create_agent_executor(self) -> AgentExecutor:
"""Create an agent executor for the agent.
Returns:
@@ -98,38 +142,14 @@ class Agent(BaseModel):
bind = self.llm.bind(stop=["\nObservation"])
inner_agent = (
agent_args | execution_prompt | bind | ReActSingleInputOutputParser()
)
return AgentExecutor(agent=inner_agent, **executor_args)
def execute_task(
self, task: str, context: str = None, tools: List[Any] = None
) -> str:
"""Execute a task with the agent.
Args:
task: Task to execute.
context: Context to execute the task in.
tools: Tools to use for the task.
Returns:
Output of the agent
"""
if context:
task = "\n".join(
[task, "\nThis is the context you are working with:", context]
agent_args
| execution_prompt
| bind
| CrewAgentOutputParser(
tools_handler=self.tools_handler, cache=self.cache_handler
)
tools = tools or self.tools
self.agent_executor.tools = tools
return self.agent_executor.invoke(
{
"input": task,
"tool_names": self.__tools_names(tools),
"tools": render_text_description(tools),
}
)["output"]
)
self.agent_executor = AgentExecutor(agent=inner_agent, **executor_args)
@staticmethod
def __tools_names(tools) -> str:

View File

@@ -0,0 +1,3 @@
from .cache_handler import CacheHandler
from .output_parser import CrewAgentOutputParser
from .tools_handler import ToolsHandler

View File

@@ -0,0 +1,20 @@
from typing import Optional
from pydantic import PrivateAttr
class CacheHandler:
"""Callback handler for tool usage."""
_cache: PrivateAttr = {}
def __init__(self):
self._cache = {}
def add(self, tool, input, output):
input = input.strip()
self._cache[f"{tool}-{input}"] = output
def read(self, tool, input) -> Optional[str]:
input = input.strip()
return self._cache.get(f"{tool}-{input}")

View File

@@ -0,0 +1,81 @@
import re
from typing import Union
from langchain.agents.output_parsers import ReActSingleInputOutputParser
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
from .cache_handler import CacheHandler
from .tools_handler import ToolsHandler
FINAL_ANSWER_ACTION = "Final Answer:"
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
"Parsing LLM output produced both a final answer and a parse-able action:"
)
class CrewAgentOutputParser(ReActSingleInputOutputParser):
"""Parses ReAct-style LLM calls that have a single tool input.
Expects output to be in one of two formats.
If the output signals that an action should be taken,
should be in the below format. This will result in an AgentAction
being returned.
```
Thought: agent thought here
Action: search
Action Input: what is the temperature in SF?
```
If the output signals that a final answer should be given,
should be in the below format. This will result in an AgentFinish
being returned.
```
Thought: agent thought here
Final Answer: The temperature is 100 degrees
```
It also prevents tools from being reused in a roll.
"""
class Config:
arbitrary_types_allowed = True
tools_handler: ToolsHandler
cache: CacheHandler
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
includes_answer = FINAL_ANSWER_ACTION in text
regex = (
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
)
action_match = re.search(regex, text, re.DOTALL)
if action_match:
if includes_answer:
raise OutputParserException(
f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
)
action = action_match.group(1).strip()
action_input = action_match.group(2)
tool_input = action_input.strip(" ")
tool_input = tool_input.strip('"')
last_tool_usage = self.tools_handler.last_used_tool
if last_tool_usage:
usage = {
"tool": action,
"input": tool_input,
}
if usage == last_tool_usage:
raise OutputParserException(
f"""\nI just used the {action} tool with input {tool_input}. So I already knwo the result of that."""
)
result = self.cache.read(action, tool_input)
if result:
return AgentFinish({"output": result}, text)
return super().parse(text)

View File

@@ -0,0 +1,42 @@
from typing import Any, Dict
from langchain.callbacks.base import BaseCallbackHandler
from .cache_handler import CacheHandler
class ToolsHandler(BaseCallbackHandler):
"""Callback handler for tool usage."""
last_used_tool: Dict[str, Any] = {}
cache: CacheHandler = None
def __init__(self, cache: CacheHandler = None, **kwargs: Any):
"""Initialize the callback handler."""
self.cache = cache
super().__init__(**kwargs)
def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
"""Run when tool starts running."""
name = serialized.get("name")
if name not in ["invalid_tool", "_Exception"]:
tools_usage = {
"tool": name,
"input": input_str,
}
self.last_used_tool = tools_usage
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
"""Run when tool ends running."""
if (
"is not a valid tool" not in output
and "Invalid or incomplete response" not in output
and "Invalid Format" not in output
):
self.cache.add(
tool=self.last_used_tool["tool"],
input=self.last_used_tool["input"],
output=output,
)

View File

@@ -1,10 +1,18 @@
import json
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field, Json, field_validator, model_validator
from pydantic import (
BaseModel,
Field,
InstanceOf,
Json,
field_validator,
model_validator,
)
from pydantic_core import PydanticCustomError
from .agent import Agent
from .agents import CacheHandler
from .process import Process
from .task import Task
from .tools.agent_tools import AgentTools
@@ -13,6 +21,9 @@ from .tools.agent_tools import AgentTools
class Crew(BaseModel):
"""Class that represents a group of agents, how they should work together and their tasks."""
class Config:
arbitrary_types_allowed = True
tasks: List[Task] = Field(description="List of tasks", default_factory=list)
agents: List[Agent] = Field(
description="List of agents in this crew.", default_factory=list
@@ -26,6 +37,9 @@ class Crew(BaseModel):
config: Optional[Union[Json, Dict[str, Any]]] = Field(
description="Configuration of the crew.", default=None
)
cache_handler: Optional[InstanceOf[CacheHandler]] = Field(
default=CacheHandler(), description="An instance of the CacheHandler class."
)
@classmethod
@field_validator("config", mode="before")
@@ -58,6 +72,10 @@ class Crew(BaseModel):
tasks.append(Task(**task, agent=task_agent))
self.tasks = tasks
if self.agents:
for agent in self.agents:
agent.set_cache_handler(self.cache_handler)
return self
def kickoff(self) -> str:
@@ -66,6 +84,9 @@ class Crew(BaseModel):
Returns:
Output of the crew for each task.
"""
for agent in self.agents:
agent.cache_handler = self.cache_handler
if self.process == Process.sequential:
return self.__sequential_loop()

View File

@@ -48,7 +48,7 @@ class Prompts(BaseModel):
```
Thought: Do I need to use a tool? Yes
Action: the action to take, should be one of [{tool_names}]
Action: the action to take, should be one of [{tool_names}], just the name.
Action Input: the input to the action
Observation: the result of the action
```

View File

@@ -16,7 +16,7 @@ class AgentTools(BaseModel):
return [
Tool.from_function(
func=self.delegate_work,
name="Delegate Work to Co-Worker",
name="Delegate work to co-worker",
description=dedent(
f"""Useful to delegate a specific task to one of the
following co-workers: [{', '.join([agent.role for agent in self.agents])}].
@@ -28,7 +28,7 @@ class AgentTools(BaseModel):
),
Tool.from_function(
func=self.ask_question,
name="Ask Question to Co-Worker",
name="Ask question to co-worker",
description=dedent(
f"""Useful to ask a question, opinion or take from on
of the following co-workers: [{', '.join([agent.role for agent in self.agents])}].
@@ -53,10 +53,10 @@ class AgentTools(BaseModel):
try:
agent, task, information = command.split("|")
except ValueError:
return "Error executing tool. Missing exact 3 pipe (|) separated values. For example, `coworker|task|information`."
return "\nError executing tool. Missing exact 3 pipe (|) separated values. For example, `coworker|task|information`."
if not agent or not task or not information:
return "Error executing tool. Missing exact 3 pipe (|) separated values. For example, `coworker|question|information`."
return "\nError executing tool. Missing exact 3 pipe (|) separated values. For example, `coworker|question|information`."
agent = [
available_agent
@@ -65,9 +65,7 @@ class AgentTools(BaseModel):
]
if len(agent) == 0:
return (
"Error executing tool. Co-worker not found, double check the co-worker."
)
return f"\nError executing tool. Co-worker mentioned on the Action Input not found, it must to be one of the following options: {', '.join([agent.role for agent in self.agents])}."
agent = agent[0]
result = agent.execute_task(task, information)