Refactoring task cache to be a tool (#50)

* Refactoring task cache to be a tool

The previous implementation of the task caching system was early exiting
the agent executor due to the fact it was returning an AgentFinish object.

This now refactors it to use a cache specific tool that is dynamically
added and forced into the agent in case of a task execution that was
already executed with the same input.
This commit is contained in:
João Moura
2024-01-04 21:29:42 -03:00
committed by GitHub
parent fe6bef0af1
commit 6b054651a7
12 changed files with 1851 additions and 576 deletions

View File

@@ -1,7 +1,6 @@
import uuid
from typing import Any, List, Optional
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationSummaryMemory
@@ -18,7 +17,12 @@ from pydantic import (
)
from pydantic_core import PydanticCustomError
from crewai.agents import CacheHandler, CrewAgentOutputParser, ToolsHandler
from crewai.agents import (
CacheHandler,
CrewAgentExecutor,
CrewAgentOutputParser,
ToolsHandler,
)
from crewai.prompts import Prompts
@@ -29,7 +33,7 @@ class Agent(BaseModel):
The agent can also have memory, can operate in verbose mode, and can delegate tasks to other agents.
Attributes:
agent_executor: An instance of the AgentExecutor class.
agent_executor: An instance of the CrewAgentExecutor class.
role: The role of the agent.
goal: The objective of the agent.
backstory: The backstory of the agent.
@@ -68,8 +72,8 @@ class Agent(BaseModel):
tools: List[Any] = Field(
default_factory=list, description="Tools at agents disposal"
)
agent_executor: Optional[InstanceOf[AgentExecutor]] = Field(
default=None, description="An instance of the AgentExecutor class."
agent_executor: Optional[InstanceOf[CrewAgentExecutor]] = Field(
default=None, description="An instance of the CrewAgentExecutor class."
)
tools_handler: Optional[InstanceOf[ToolsHandler]] = Field(
default=None, description="An instance of the ToolsHandler class."
@@ -127,11 +131,11 @@ class Agent(BaseModel):
self.tools_handler = ToolsHandler(cache=self.cache_handler)
self.__create_agent_executor()
def __create_agent_executor(self) -> AgentExecutor:
def __create_agent_executor(self) -> CrewAgentExecutor:
"""Create an agent executor for the agent.
Returns:
An instance of the AgentExecutor class.
An instance of the CrewAgentExecutor class.
"""
agent_args = {
"input": lambda x: x["input"],
@@ -170,7 +174,7 @@ class Agent(BaseModel):
tools_handler=self.tools_handler, cache=self.cache_handler
)
)
self.agent_executor = AgentExecutor(agent=inner_agent, **executor_args)
self.agent_executor = CrewAgentExecutor(agent=inner_agent, **executor_args)
@staticmethod
def __tools_names(tools) -> str:

View File

@@ -1,3 +1,4 @@
from .cache_handler import CacheHandler
from .executor import CrewAgentExecutor
from .output_parser import CrewAgentOutputParser
from .tools_handler import ToolsHandler

View File

@@ -0,0 +1,14 @@
from langchain_core.agents import AgentAction
from pydantic.v1 import BaseModel, Field
from .cache_handler import CacheHandler
class CacheHit(BaseModel):
"""Cache Hit Object."""
class Config:
arbitrary_types_allowed = True
action: AgentAction = Field(description="Action taken")
cache: CacheHandler = Field(description="Cache Handler for the tool")

130
crewai/agents/executor.py Normal file
View File

@@ -0,0 +1,130 @@
from typing import Dict, Iterator, List, Optional, Tuple, Union
from langchain.agents import AgentExecutor
from langchain.agents.agent import ExceptionTool
from langchain.agents.tools import InvalidTool
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain_core.agents import AgentAction, AgentFinish, AgentStep
from langchain_core.exceptions import OutputParserException
from langchain_core.tools import BaseTool
from ..tools.cache_tools import CacheTools
from .cache_hit import CacheHit
class CrewAgentExecutor(AgentExecutor):
def _iter_next_step(
self,
name_to_tool_map: Dict[str, BaseTool],
color_mapping: Dict[str, str],
inputs: Dict[str, str],
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Iterator[Union[AgentFinish, AgentAction, AgentStep]]:
"""Take a single step in the thought-action-observation loop.
Override this to take control of how the agent makes and acts on choices.
"""
try:
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
# Call the LLM to see what to do.
output = self.agent.plan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
except OutputParserException as e:
if isinstance(self.handle_parsing_errors, bool):
raise_error = not self.handle_parsing_errors
else:
raise_error = False
if raise_error:
raise ValueError(
"An output parsing error occurred. "
"In order to pass this error back to the agent and have it try "
"again, pass `handle_parsing_errors=True` to the AgentExecutor. "
f"This is the error: {str(e)}"
)
text = str(e)
if isinstance(self.handle_parsing_errors, bool):
if e.send_to_llm:
observation = str(e.observation)
text = str(e.llm_output)
else:
observation = "Invalid or incomplete response"
elif isinstance(self.handle_parsing_errors, str):
observation = self.handle_parsing_errors
elif callable(self.handle_parsing_errors):
observation = self.handle_parsing_errors(e)
else:
raise ValueError("Got unexpected type of `handle_parsing_errors`")
output = AgentAction("_Exception", observation, text)
if run_manager:
run_manager.on_agent_action(output, color="green")
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = ExceptionTool().run(
output.tool_input,
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
yield AgentStep(action=output, observation=observation)
return
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
yield output
return
# Override tool usage to use CacheTools
if isinstance(output, CacheHit):
cache = output.cache
action = output.action
tool = CacheTools(cache_handler=cache).tool()
output = action.copy()
output.tool_input = f"tool:{action.tool}|input:{action.tool_input}"
output.tool = tool.name
name_to_tool_map[tool.name] = tool
color_mapping[tool.name] = color_mapping[action.tool]
actions: List[AgentAction]
if isinstance(output, AgentAction):
actions = [output]
else:
actions = output
for agent_action in actions:
yield agent_action
for agent_action in actions:
if run_manager:
run_manager.on_agent_action(agent_action, color="green")
# Otherwise we lookup the tool
if agent_action.tool in name_to_tool_map:
tool = name_to_tool_map[agent_action.tool]
return_direct = tool.return_direct
color = color_mapping[agent_action.tool]
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
if return_direct:
tool_run_kwargs["llm_prefix"] = ""
# We then call the tool on the tool input to get an observation
observation = tool.run(
agent_action.tool_input,
verbose=self.verbose,
color=color,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
else:
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = InvalidTool().run(
{
"requested_tool_name": agent_action.tool,
"available_tool_names": list(name_to_tool_map.keys()),
},
verbose=self.verbose,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
yield AgentStep(action=agent_action, observation=observation)

View File

@@ -6,6 +6,7 @@ from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
from .cache_handler import CacheHandler
from .cache_hit import CacheHit
from .tools_handler import ToolsHandler
FINAL_ANSWER_ACTION = "Final Answer:"
@@ -47,17 +48,13 @@ class CrewAgentOutputParser(ReActSingleInputOutputParser):
tools_handler: ToolsHandler
cache: CacheHandler
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
includes_answer = FINAL_ANSWER_ACTION in text
def parse(self, text: str) -> Union[AgentAction, AgentFinish, CacheHit]:
FINAL_ANSWER_ACTION in text
regex = (
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
)
action_match = re.search(regex, text, re.DOTALL)
if action_match:
if includes_answer:
raise OutputParserException(
f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
)
action = action_match.group(1).strip()
action_input = action_match.group(2)
tool_input = action_input.strip(" ")
@@ -76,6 +73,7 @@ class CrewAgentOutputParser(ReActSingleInputOutputParser):
result = self.cache.read(action, tool_input)
if result:
return AgentFinish({"output": result}, text)
action = AgentAction(action, tool_input, text)
return CacheHit(action=action, cache=self.cache)
return super().parse(text)

View File

@@ -2,6 +2,7 @@ from typing import Any, Dict
from langchain.callbacks.base import BaseCallbackHandler
from ..tools.cache_tools import CacheTools
from .cache_handler import CacheHandler
@@ -35,8 +36,9 @@ class ToolsHandler(BaseCallbackHandler):
and "Invalid or incomplete response" not in output
and "Invalid Format" not in output
):
self.cache.add(
tool=self.last_used_tool["tool"],
input=self.last_used_tool["input"],
output=output,
)
if self.last_used_tool["tool"] != CacheTools().name:
self.cache.add(
tool=self.last_used_tool["tool"],
input=self.last_used_tool["input"],
output=output,
)

View File

@@ -0,0 +1,28 @@
from langchain.tools import Tool
from pydantic import BaseModel, ConfigDict, Field
from crewai.agents import CacheHandler
class CacheTools(BaseModel):
"""Default tools to hit the cache."""
model_config = ConfigDict(arbitrary_types_allowed=True)
name: str = "Hit Cache"
cache_handler: CacheHandler = Field(
description="Cache Handler for the crew",
default=CacheHandler(),
)
def tool(self):
return Tool.from_function(
func=self.hit_cache,
name=self.name,
description="Reads directly from the cache",
)
def hit_cache(self, key):
split = key.split("tool:")
tool = split[1].split("|input:")[0].strip()
tool_input = split[1].split("|input:")[1].strip()
return self.cache_handler.read(tool, tool_input)

View File

@@ -1,7 +1,7 @@
[tool.poetry]
name = "crewai"
version = "0.1.15"
version = "0.1.16"
description = "Cutting-edge framework for orchestrating role-playing, autonomous AI agents. By fostering collaborative intelligence, CrewAI empowers agents to work together seamlessly, tackling complex tasks."
authors = ["Joao Moura <joaomdmoura@gmail.com>"]
readme = "README.md"

View File

@@ -145,7 +145,7 @@ def test_cache_hitting():
def multiplier(numbers) -> float:
"""Useful for when you need to multiply two numbers together.
The input to this tool should be a comma separated list of numbers of
length two, representing the two numbers you want to multiply together.
length two and ONLY TWO, representing the two numbers you want to multiply together.
For example, `1,2` would be the input if you wanted to multiply 1 by 2."""
a, b = numbers.split(",")
return int(a) * int(b)
@@ -162,15 +162,22 @@ def test_cache_hitting():
verbose=True,
)
output = agent.execute_task("What is 2 times 6?")
output = agent.execute_task("What is 2 times 6 times 3?")
output = agent.execute_task("What is 3 times 3?")
assert cache_handler._cache == {"multiplier-2,6": "12", "multiplier-3,3": "9"}
assert cache_handler._cache == {
"multiplier-12,3": "36",
"multiplier-2,6": "12",
"multiplier-3,3": "9",
}
output = agent.execute_task("What is 2 times 6 times 3? Return only the number")
assert output == "36"
with patch.object(CacheHandler, "read") as read:
read.return_value = "0"
output = agent.execute_task("What is 2 times 6?")
assert output == "0"
read.assert_called_once_with("multiplier", "2,6")
read.assert_called_with("multiplier", "2,6")
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -194,4 +201,4 @@ def test_agent_execution_with_specific_tools():
)
output = agent.execute_task(task="What is 3 times 4", tools=[multiplier])
assert output == "12"
assert output == "3 times 4 is 12."

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -241,12 +241,12 @@ def test_cache_hitting_between_agents():
tasks = [
Task(
description="What is 2 tims 6?",
description="What is 2 tims 6? Return only the number.",
tools=[multiplier],
agent=ceo,
),
Task(
description="What is 2 times 6?",
description="What is 2 times 6? Return only the number.",
tools=[multiplier],
agent=researcher,
),