mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Refactoring task cache to be a tool (#50)
* Refactoring task cache to be a tool The previous implementation of the task caching system was early exiting the agent executor due to the fact it was returning an AgentFinish object. This now refactors it to use a cache specific tool that is dynamically added and forced into the agent in case of a task execution that was already executed with the same input.
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import uuid
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.agents.format_scratchpad import format_log_to_str
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.memory import ConversationSummaryMemory
|
||||
@@ -18,7 +17,12 @@ from pydantic import (
|
||||
)
|
||||
from pydantic_core import PydanticCustomError
|
||||
|
||||
from crewai.agents import CacheHandler, CrewAgentOutputParser, ToolsHandler
|
||||
from crewai.agents import (
|
||||
CacheHandler,
|
||||
CrewAgentExecutor,
|
||||
CrewAgentOutputParser,
|
||||
ToolsHandler,
|
||||
)
|
||||
from crewai.prompts import Prompts
|
||||
|
||||
|
||||
@@ -29,7 +33,7 @@ class Agent(BaseModel):
|
||||
The agent can also have memory, can operate in verbose mode, and can delegate tasks to other agents.
|
||||
|
||||
Attributes:
|
||||
agent_executor: An instance of the AgentExecutor class.
|
||||
agent_executor: An instance of the CrewAgentExecutor class.
|
||||
role: The role of the agent.
|
||||
goal: The objective of the agent.
|
||||
backstory: The backstory of the agent.
|
||||
@@ -68,8 +72,8 @@ class Agent(BaseModel):
|
||||
tools: List[Any] = Field(
|
||||
default_factory=list, description="Tools at agents disposal"
|
||||
)
|
||||
agent_executor: Optional[InstanceOf[AgentExecutor]] = Field(
|
||||
default=None, description="An instance of the AgentExecutor class."
|
||||
agent_executor: Optional[InstanceOf[CrewAgentExecutor]] = Field(
|
||||
default=None, description="An instance of the CrewAgentExecutor class."
|
||||
)
|
||||
tools_handler: Optional[InstanceOf[ToolsHandler]] = Field(
|
||||
default=None, description="An instance of the ToolsHandler class."
|
||||
@@ -127,11 +131,11 @@ class Agent(BaseModel):
|
||||
self.tools_handler = ToolsHandler(cache=self.cache_handler)
|
||||
self.__create_agent_executor()
|
||||
|
||||
def __create_agent_executor(self) -> AgentExecutor:
|
||||
def __create_agent_executor(self) -> CrewAgentExecutor:
|
||||
"""Create an agent executor for the agent.
|
||||
|
||||
Returns:
|
||||
An instance of the AgentExecutor class.
|
||||
An instance of the CrewAgentExecutor class.
|
||||
"""
|
||||
agent_args = {
|
||||
"input": lambda x: x["input"],
|
||||
@@ -170,7 +174,7 @@ class Agent(BaseModel):
|
||||
tools_handler=self.tools_handler, cache=self.cache_handler
|
||||
)
|
||||
)
|
||||
self.agent_executor = AgentExecutor(agent=inner_agent, **executor_args)
|
||||
self.agent_executor = CrewAgentExecutor(agent=inner_agent, **executor_args)
|
||||
|
||||
@staticmethod
|
||||
def __tools_names(tools) -> str:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .cache_handler import CacheHandler
|
||||
from .executor import CrewAgentExecutor
|
||||
from .output_parser import CrewAgentOutputParser
|
||||
from .tools_handler import ToolsHandler
|
||||
|
||||
14
crewai/agents/cache_hit.py
Normal file
14
crewai/agents/cache_hit.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from langchain_core.agents import AgentAction
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
from .cache_handler import CacheHandler
|
||||
|
||||
|
||||
class CacheHit(BaseModel):
|
||||
"""Cache Hit Object."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
action: AgentAction = Field(description="Action taken")
|
||||
cache: CacheHandler = Field(description="Cache Handler for the tool")
|
||||
130
crewai/agents/executor.py
Normal file
130
crewai/agents/executor.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.agents.agent import ExceptionTool
|
||||
from langchain.agents.tools import InvalidTool
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain_core.agents import AgentAction, AgentFinish, AgentStep
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from ..tools.cache_tools import CacheTools
|
||||
from .cache_hit import CacheHit
|
||||
|
||||
|
||||
class CrewAgentExecutor(AgentExecutor):
|
||||
def _iter_next_step(
|
||||
self,
|
||||
name_to_tool_map: Dict[str, BaseTool],
|
||||
color_mapping: Dict[str, str],
|
||||
inputs: Dict[str, str],
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Iterator[Union[AgentFinish, AgentAction, AgentStep]]:
|
||||
"""Take a single step in the thought-action-observation loop.
|
||||
|
||||
Override this to take control of how the agent makes and acts on choices.
|
||||
"""
|
||||
try:
|
||||
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
|
||||
|
||||
# Call the LLM to see what to do.
|
||||
output = self.agent.plan(
|
||||
intermediate_steps,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**inputs,
|
||||
)
|
||||
except OutputParserException as e:
|
||||
if isinstance(self.handle_parsing_errors, bool):
|
||||
raise_error = not self.handle_parsing_errors
|
||||
else:
|
||||
raise_error = False
|
||||
if raise_error:
|
||||
raise ValueError(
|
||||
"An output parsing error occurred. "
|
||||
"In order to pass this error back to the agent and have it try "
|
||||
"again, pass `handle_parsing_errors=True` to the AgentExecutor. "
|
||||
f"This is the error: {str(e)}"
|
||||
)
|
||||
text = str(e)
|
||||
if isinstance(self.handle_parsing_errors, bool):
|
||||
if e.send_to_llm:
|
||||
observation = str(e.observation)
|
||||
text = str(e.llm_output)
|
||||
else:
|
||||
observation = "Invalid or incomplete response"
|
||||
elif isinstance(self.handle_parsing_errors, str):
|
||||
observation = self.handle_parsing_errors
|
||||
elif callable(self.handle_parsing_errors):
|
||||
observation = self.handle_parsing_errors(e)
|
||||
else:
|
||||
raise ValueError("Got unexpected type of `handle_parsing_errors`")
|
||||
output = AgentAction("_Exception", observation, text)
|
||||
if run_manager:
|
||||
run_manager.on_agent_action(output, color="green")
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
observation = ExceptionTool().run(
|
||||
output.tool_input,
|
||||
verbose=self.verbose,
|
||||
color=None,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
yield AgentStep(action=output, observation=observation)
|
||||
return
|
||||
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
yield output
|
||||
return
|
||||
|
||||
# Override tool usage to use CacheTools
|
||||
if isinstance(output, CacheHit):
|
||||
cache = output.cache
|
||||
action = output.action
|
||||
tool = CacheTools(cache_handler=cache).tool()
|
||||
output = action.copy()
|
||||
output.tool_input = f"tool:{action.tool}|input:{action.tool_input}"
|
||||
output.tool = tool.name
|
||||
name_to_tool_map[tool.name] = tool
|
||||
color_mapping[tool.name] = color_mapping[action.tool]
|
||||
|
||||
actions: List[AgentAction]
|
||||
if isinstance(output, AgentAction):
|
||||
actions = [output]
|
||||
else:
|
||||
actions = output
|
||||
for agent_action in actions:
|
||||
yield agent_action
|
||||
for agent_action in actions:
|
||||
if run_manager:
|
||||
run_manager.on_agent_action(agent_action, color="green")
|
||||
# Otherwise we lookup the tool
|
||||
if agent_action.tool in name_to_tool_map:
|
||||
tool = name_to_tool_map[agent_action.tool]
|
||||
return_direct = tool.return_direct
|
||||
color = color_mapping[agent_action.tool]
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
if return_direct:
|
||||
tool_run_kwargs["llm_prefix"] = ""
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = tool.run(
|
||||
agent_action.tool_input,
|
||||
verbose=self.verbose,
|
||||
color=color,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
else:
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
observation = InvalidTool().run(
|
||||
{
|
||||
"requested_tool_name": agent_action.tool,
|
||||
"available_tool_names": list(name_to_tool_map.keys()),
|
||||
},
|
||||
verbose=self.verbose,
|
||||
color=None,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
yield AgentStep(action=agent_action, observation=observation)
|
||||
@@ -6,6 +6,7 @@ from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
|
||||
from .cache_handler import CacheHandler
|
||||
from .cache_hit import CacheHit
|
||||
from .tools_handler import ToolsHandler
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
@@ -47,17 +48,13 @@ class CrewAgentOutputParser(ReActSingleInputOutputParser):
|
||||
tools_handler: ToolsHandler
|
||||
cache: CacheHandler
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
includes_answer = FINAL_ANSWER_ACTION in text
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish, CacheHit]:
|
||||
FINAL_ANSWER_ACTION in text
|
||||
regex = (
|
||||
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
|
||||
)
|
||||
action_match = re.search(regex, text, re.DOTALL)
|
||||
if action_match:
|
||||
if includes_answer:
|
||||
raise OutputParserException(
|
||||
f"{FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: {text}"
|
||||
)
|
||||
action = action_match.group(1).strip()
|
||||
action_input = action_match.group(2)
|
||||
tool_input = action_input.strip(" ")
|
||||
@@ -76,6 +73,7 @@ class CrewAgentOutputParser(ReActSingleInputOutputParser):
|
||||
|
||||
result = self.cache.read(action, tool_input)
|
||||
if result:
|
||||
return AgentFinish({"output": result}, text)
|
||||
action = AgentAction(action, tool_input, text)
|
||||
return CacheHit(action=action, cache=self.cache)
|
||||
|
||||
return super().parse(text)
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Any, Dict
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
|
||||
from ..tools.cache_tools import CacheTools
|
||||
from .cache_handler import CacheHandler
|
||||
|
||||
|
||||
@@ -35,8 +36,9 @@ class ToolsHandler(BaseCallbackHandler):
|
||||
and "Invalid or incomplete response" not in output
|
||||
and "Invalid Format" not in output
|
||||
):
|
||||
self.cache.add(
|
||||
tool=self.last_used_tool["tool"],
|
||||
input=self.last_used_tool["input"],
|
||||
output=output,
|
||||
)
|
||||
if self.last_used_tool["tool"] != CacheTools().name:
|
||||
self.cache.add(
|
||||
tool=self.last_used_tool["tool"],
|
||||
input=self.last_used_tool["input"],
|
||||
output=output,
|
||||
)
|
||||
|
||||
28
crewai/tools/cache_tools.py
Normal file
28
crewai/tools/cache_tools.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from langchain.tools import Tool
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from crewai.agents import CacheHandler
|
||||
|
||||
|
||||
class CacheTools(BaseModel):
|
||||
"""Default tools to hit the cache."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
name: str = "Hit Cache"
|
||||
cache_handler: CacheHandler = Field(
|
||||
description="Cache Handler for the crew",
|
||||
default=CacheHandler(),
|
||||
)
|
||||
|
||||
def tool(self):
|
||||
return Tool.from_function(
|
||||
func=self.hit_cache,
|
||||
name=self.name,
|
||||
description="Reads directly from the cache",
|
||||
)
|
||||
|
||||
def hit_cache(self, key):
|
||||
split = key.split("tool:")
|
||||
tool = split[1].split("|input:")[0].strip()
|
||||
tool_input = split[1].split("|input:")[1].strip()
|
||||
return self.cache_handler.read(tool, tool_input)
|
||||
Reference in New Issue
Block a user