Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-08 23:58:34 +00:00)
Refactoring task cache to be a tool (#50)
* Refactoring task cache to be a tool

The previous implementation of the task caching system was early-exiting the agent executor because it returned an AgentFinish object. This refactors it to use a cache-specific tool that is dynamically added and forced onto the agent whenever a tool call is repeated with the same input as an earlier execution.
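To make the mechanism concrete, here is a minimal, self-contained sketch of the new control flow, condensed from the diff below. The `Action` dataclass and the plain-dict cache are illustrative stand-ins for the real `AgentAction` and `CacheHandler`; only the `CacheHit` idea and the `tool:<name>|input:<args>` key format come from the actual change.

```python
from dataclasses import dataclass

@dataclass
class Action:
    tool: str
    tool_input: str

@dataclass
class CacheHit:
    action: Action
    cache: dict  # stand-in for CacheHandler: maps "tool-input" -> cached output

def parse(action: Action, cache: dict):
    """Before: a cache hit returned a final answer, ending the run early.
    Now: it returns a CacheHit carrying the original action and the cache."""
    if f"{action.tool}-{action.tool_input}" in cache:
        return CacheHit(action=action, cache=cache)
    return action

def executor_step(output, cache: dict) -> str:
    """Mirror of the executor override: force a cache tool instead of finishing."""
    if isinstance(output, CacheHit):
        # Encode the original call into the tool input, like the executor does.
        key = f"tool:{output.action.tool}|input:{output.action.tool_input}"
        # The "Hit Cache" tool decodes the key and reads straight from the cache.
        tool, tool_input = key.split("tool:")[1].split("|input:")
        return output.cache[f"{tool}-{tool_input}"]
    return "run the real tool"

cache = {"multiplier-2,6": "12"}
print(executor_step(parse(Action("multiplier", "2,6"), cache), cache))  # -> 12
```

Routing the hit through a tool means the executor's normal tool loop keeps running, so scratchpad, callbacks, and subsequent steps behave the same as for any other tool call instead of ending the run early.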
@@ -1,7 +1,6 @@
 import uuid
 from typing import Any, List, Optional
 
-from langchain.agents import AgentExecutor
 from langchain.agents.format_scratchpad import format_log_to_str
 from langchain.chat_models import ChatOpenAI
 from langchain.memory import ConversationSummaryMemory
@@ -18,7 +17,12 @@ from pydantic import (
 )
 from pydantic_core import PydanticCustomError
 
-from crewai.agents import CacheHandler, CrewAgentOutputParser, ToolsHandler
+from crewai.agents import (
+    CacheHandler,
+    CrewAgentExecutor,
+    CrewAgentOutputParser,
+    ToolsHandler,
+)
 from crewai.prompts import Prompts
 
 
@@ -29,7 +33,7 @@ class Agent(BaseModel):
     The agent can also have memory, can operate in verbose mode, and can delegate tasks to other agents.
 
     Attributes:
-        agent_executor: An instance of the AgentExecutor class.
+        agent_executor: An instance of the CrewAgentExecutor class.
         role: The role of the agent.
         goal: The objective of the agent.
         backstory: The backstory of the agent.
@@ -68,8 +72,8 @@ class Agent(BaseModel):
     tools: List[Any] = Field(
         default_factory=list, description="Tools at agents disposal"
     )
-    agent_executor: Optional[InstanceOf[AgentExecutor]] = Field(
-        default=None, description="An instance of the AgentExecutor class."
+    agent_executor: Optional[InstanceOf[CrewAgentExecutor]] = Field(
+        default=None, description="An instance of the CrewAgentExecutor class."
     )
     tools_handler: Optional[InstanceOf[ToolsHandler]] = Field(
         default=None, description="An instance of the ToolsHandler class."
@@ -127,11 +131,11 @@
         self.tools_handler = ToolsHandler(cache=self.cache_handler)
         self.__create_agent_executor()
 
-    def __create_agent_executor(self) -> AgentExecutor:
+    def __create_agent_executor(self) -> CrewAgentExecutor:
         """Create an agent executor for the agent.
 
         Returns:
-            An instance of the AgentExecutor class.
+            An instance of the CrewAgentExecutor class.
         """
         agent_args = {
             "input": lambda x: x["input"],
@@ -170,7 +174,7 @@
                 tools_handler=self.tools_handler, cache=self.cache_handler
             )
         )
-        self.agent_executor = AgentExecutor(agent=inner_agent, **executor_args)
+        self.agent_executor = CrewAgentExecutor(agent=inner_agent, **executor_args)
 
     @staticmethod
     def __tools_names(tools) -> str:
@@ -1,3 +1,4 @@
 from .cache_handler import CacheHandler
+from .executor import CrewAgentExecutor
 from .output_parser import CrewAgentOutputParser
 from .tools_handler import ToolsHandler
crewai/agents/cache_hit.py (new file, 14 lines)
@@ -0,0 +1,14 @@
+from langchain_core.agents import AgentAction
+from pydantic.v1 import BaseModel, Field
+
+from .cache_handler import CacheHandler
+
+
+class CacheHit(BaseModel):
+    """Cache Hit Object."""
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    action: AgentAction = Field(description="Action taken")
+    cache: CacheHandler = Field(description="Cache Handler for the tool")
crewai/agents/executor.py (new file, 130 lines)
@@ -0,0 +1,130 @@
+from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+from langchain.agents import AgentExecutor
+from langchain.agents.agent import ExceptionTool
+from langchain.agents.tools import InvalidTool
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from langchain_core.agents import AgentAction, AgentFinish, AgentStep
+from langchain_core.exceptions import OutputParserException
+from langchain_core.tools import BaseTool
+
+from ..tools.cache_tools import CacheTools
+from .cache_hit import CacheHit
+
+
+class CrewAgentExecutor(AgentExecutor):
+    def _iter_next_step(
+        self,
+        name_to_tool_map: Dict[str, BaseTool],
+        color_mapping: Dict[str, str],
+        inputs: Dict[str, str],
+        intermediate_steps: List[Tuple[AgentAction, str]],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Iterator[Union[AgentFinish, AgentAction, AgentStep]]:
+        """Take a single step in the thought-action-observation loop.
+
+        Override this to take control of how the agent makes and acts on choices.
+        """
+        try:
+            intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
+
+            # Call the LLM to see what to do.
+            output = self.agent.plan(
+                intermediate_steps,
+                callbacks=run_manager.get_child() if run_manager else None,
+                **inputs,
+            )
+        except OutputParserException as e:
+            if isinstance(self.handle_parsing_errors, bool):
+                raise_error = not self.handle_parsing_errors
+            else:
+                raise_error = False
+            if raise_error:
+                raise ValueError(
+                    "An output parsing error occurred. "
+                    "In order to pass this error back to the agent and have it try "
+                    "again, pass `handle_parsing_errors=True` to the AgentExecutor. "
+                    f"This is the error: {str(e)}"
+                )
+            text = str(e)
+            if isinstance(self.handle_parsing_errors, bool):
+                if e.send_to_llm:
+                    observation = str(e.observation)
+                    text = str(e.llm_output)
+                else:
+                    observation = "Invalid or incomplete response"
+            elif isinstance(self.handle_parsing_errors, str):
+                observation = self.handle_parsing_errors
+            elif callable(self.handle_parsing_errors):
+                observation = self.handle_parsing_errors(e)
+            else:
+                raise ValueError("Got unexpected type of `handle_parsing_errors`")
+            output = AgentAction("_Exception", observation, text)
+            if run_manager:
+                run_manager.on_agent_action(output, color="green")
+            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
+            observation = ExceptionTool().run(
+                output.tool_input,
+                verbose=self.verbose,
+                color=None,
+                callbacks=run_manager.get_child() if run_manager else None,
+                **tool_run_kwargs,
+            )
+            yield AgentStep(action=output, observation=observation)
+            return
+
+        # If the tool chosen is the finishing tool, then we end and return.
+        if isinstance(output, AgentFinish):
+            yield output
+            return
+
+        # Override tool usage to use CacheTools
+        if isinstance(output, CacheHit):
+            cache = output.cache
+            action = output.action
+            tool = CacheTools(cache_handler=cache).tool()
+            output = action.copy()
+            output.tool_input = f"tool:{action.tool}|input:{action.tool_input}"
+            output.tool = tool.name
+            name_to_tool_map[tool.name] = tool
+            color_mapping[tool.name] = color_mapping[action.tool]
+
+        actions: List[AgentAction]
+        if isinstance(output, AgentAction):
+            actions = [output]
+        else:
+            actions = output
+        for agent_action in actions:
+            yield agent_action
+        for agent_action in actions:
+            if run_manager:
+                run_manager.on_agent_action(agent_action, color="green")
+            # Otherwise we lookup the tool
+            if agent_action.tool in name_to_tool_map:
+                tool = name_to_tool_map[agent_action.tool]
+                return_direct = tool.return_direct
+                color = color_mapping[agent_action.tool]
+                tool_run_kwargs = self.agent.tool_run_logging_kwargs()
+                if return_direct:
+                    tool_run_kwargs["llm_prefix"] = ""
+                # We then call the tool on the tool input to get an observation
+                observation = tool.run(
+                    agent_action.tool_input,
+                    verbose=self.verbose,
+                    color=color,
+                    callbacks=run_manager.get_child() if run_manager else None,
+                    **tool_run_kwargs,
+                )
+            else:
+                tool_run_kwargs = self.agent.tool_run_logging_kwargs()
+                observation = InvalidTool().run(
+                    {
+                        "requested_tool_name": agent_action.tool,
+                        "available_tool_names": list(name_to_tool_map.keys()),
+                    },
+                    verbose=self.verbose,
+                    color=None,
+                    callbacks=run_manager.get_child() if run_manager else None,
+                    **tool_run_kwargs,
+                )
+            yield AgentStep(action=agent_action, observation=observation)
@@ -6,6 +6,7 @@ from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.exceptions import OutputParserException
 
 from .cache_handler import CacheHandler
+from .cache_hit import CacheHit
 from .tools_handler import ToolsHandler
 
 FINAL_ANSWER_ACTION = "Final Answer:"

@@ -47,17 +48,13 @@ class CrewAgentOutputParser(ReActSingleInputOutputParser):
     tools_handler: ToolsHandler
     cache: CacheHandler
 
-    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
-        includes_answer = FINAL_ANSWER_ACTION in text
+    def parse(self, text: str) -> Union[AgentAction, AgentFinish, CacheHit]:
+        FINAL_ANSWER_ACTION in text
         regex = (
             r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
         )
         action_match = re.search(regex, text, re.DOTALL)
         if action_match:
-            if includes_answer:
-                raise OutputParserException(
-                    f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
-                )
             action = action_match.group(1).strip()
             action_input = action_match.group(2)
             tool_input = action_input.strip(" ")

@@ -76,6 +73,7 @@ class CrewAgentOutputParser(ReActSingleInputOutputParser):
 
             result = self.cache.read(action, tool_input)
             if result:
-                return AgentFinish({"output": result}, text)
+                action = AgentAction(action, tool_input, text)
+                return CacheHit(action=action, cache=self.cache)
 
         return super().parse(text)
@@ -2,6 +2,7 @@ from typing import Any, Dict
 
 from langchain.callbacks.base import BaseCallbackHandler
 
+from ..tools.cache_tools import CacheTools
 from .cache_handler import CacheHandler
 
 

@@ -35,8 +36,9 @@ class ToolsHandler(BaseCallbackHandler):
             and "Invalid or incomplete response" not in output
             and "Invalid Format" not in output
         ):
-            self.cache.add(
-                tool=self.last_used_tool["tool"],
-                input=self.last_used_tool["input"],
-                output=output,
-            )
+            if self.last_used_tool["tool"] != CacheTools().name:
+                self.cache.add(
+                    tool=self.last_used_tool["tool"],
+                    input=self.last_used_tool["input"],
+                    output=output,
+                )
crewai/tools/cache_tools.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+from langchain.tools import Tool
+from pydantic import BaseModel, ConfigDict, Field
+
+from crewai.agents import CacheHandler
+
+
+class CacheTools(BaseModel):
+    """Default tools to hit the cache."""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    name: str = "Hit Cache"
+    cache_handler: CacheHandler = Field(
+        description="Cache Handler for the crew",
+        default=CacheHandler(),
+    )
+
+    def tool(self):
+        return Tool.from_function(
+            func=self.hit_cache,
+            name=self.name,
+            description="Reads directly from the cache",
+        )
+
+    def hit_cache(self, key):
+        split = key.split("tool:")
+        tool = split[1].split("|input:")[0].strip()
+        tool_input = split[1].split("|input:")[1].strip()
+        return self.cache_handler.read(tool, tool_input)
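For reference, a short usage sketch of the new module, assuming the cache already holds a result for `multiplier` with input `2,6`; in practice the agent never calls this tool directly, the executor forces it in on a cache hit.

```python
from crewai.agents import CacheHandler
from crewai.tools.cache_tools import CacheTools

# Seed the cache the same way ToolsHandler does after a successful tool run.
cache = CacheHandler()
cache.add(tool="multiplier", input="2,6", output="12")

# The executor swaps the original action for this tool and encodes the
# original call into the tool input string.
cache_tool = CacheTools(cache_handler=cache).tool()
print(cache_tool.run("tool:multiplier|input:2,6"))  # -> "12"
```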
@@ -1,7 +1,7 @@
 [tool.poetry]
 name = "crewai"
-version = "0.1.15"
+version = "0.1.16"
 description = "Cutting-edge framework for orchestrating role-playing, autonomous AI agents. By fostering collaborative intelligence, CrewAI empowers agents to work together seamlessly, tackling complex tasks."
 authors = ["Joao Moura <joaomdmoura@gmail.com>"]
 readme = "README.md"
 
@@ -145,7 +145,7 @@ def test_cache_hitting():
     def multiplier(numbers) -> float:
         """Useful for when you need to multiply two numbers together.
         The input to this tool should be a comma separated list of numbers of
-        length two, representing the two numbers you want to multiply together.
+        length two and ONLY TWO, representing the two numbers you want to multiply together.
         For example, `1,2` would be the input if you wanted to multiply 1 by 2."""
         a, b = numbers.split(",")
         return int(a) * int(b)

@@ -162,15 +162,22 @@ def test_cache_hitting():
         verbose=True,
     )
 
-    output = agent.execute_task("What is 2 times 6?")
+    output = agent.execute_task("What is 2 times 6 times 3?")
     output = agent.execute_task("What is 3 times 3?")
-    assert cache_handler._cache == {"multiplier-2,6": "12", "multiplier-3,3": "9"}
+    assert cache_handler._cache == {
+        "multiplier-12,3": "36",
+        "multiplier-2,6": "12",
+        "multiplier-3,3": "9",
+    }
 
+    output = agent.execute_task("What is 2 times 6 times 3? Return only the number")
+    assert output == "36"
+
     with patch.object(CacheHandler, "read") as read:
         read.return_value = "0"
         output = agent.execute_task("What is 2 times 6?")
         assert output == "0"
-        read.assert_called_once_with("multiplier", "2,6")
+        read.assert_called_with("multiplier", "2,6")
 
 
 @pytest.mark.vcr(filter_headers=["authorization"])

@@ -194,4 +201,4 @@ def test_agent_execution_with_specific_tools():
     )
 
     output = agent.execute_task(task="What is 3 times 4", tools=[multiplier])
-    assert output == "12"
+    assert output == "3 times 4 is 12."
File diff suppressed because it is too large
@@ -241,12 +241,12 @@ def test_cache_hitting_between_agents():
 
     tasks = [
         Task(
-            description="What is 2 tims 6?",
+            description="What is 2 tims 6? Return only the number.",
             tools=[multiplier],
            agent=ceo,
         ),
         Task(
-            description="What is 2 times 6?",
+            description="What is 2 times 6? Return only the number.",
             tools=[multiplier],
             agent=researcher,
         ),