mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
Revamping tool usage
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from .cache.cache_handler import CacheHandler
|
||||
from .executor import CrewAgentExecutor
|
||||
from .output_parser import CrewAgentOutputParser
|
||||
from .tools_handler import ToolsHandler
|
||||
|
||||
1
src/crewai/agents/cache/__init__.py
vendored
1
src/crewai/agents/cache/__init__.py
vendored
@@ -1,2 +1 @@
|
||||
from .cache_handler import CacheHandler
|
||||
from .cache_hit import CacheHit
|
||||
|
||||
2
src/crewai/agents/cache/cache_handler.py
vendored
2
src/crewai/agents/cache/cache_handler.py
vendored
@@ -10,9 +10,7 @@ class CacheHandler:
|
||||
self._cache = {}
|
||||
|
||||
def add(self, tool, input, output):
|
||||
input = input.strip()
|
||||
self._cache[f"{tool}-{input}"] = output
|
||||
|
||||
def read(self, tool, input) -> Optional[str]:
|
||||
input = input.strip()
|
||||
return self._cache.get(f"{tool}-{input}")
|
||||
|
||||
18
src/crewai/agents/cache/cache_hit.py
vendored
18
src/crewai/agents/cache/cache_hit.py
vendored
@@ -1,18 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .cache_handler import CacheHandler
|
||||
|
||||
|
||||
class CacheHit(BaseModel):
|
||||
"""Cache Hit Object."""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
# Making it Any instead of AgentAction to avoind
|
||||
# pydantic v1 vs v2 incompatibility, langchain should
|
||||
# soon be updated to pydantic v2
|
||||
action: Any = Field(description="Action taken")
|
||||
cache: CacheHandler = Field(description="Cache Handler for the tool")
|
||||
@@ -1,30 +0,0 @@
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
|
||||
from crewai.utilities import I18N
|
||||
|
||||
|
||||
class TaskRepeatedUsageException(OutputParserException):
|
||||
"""Exception raised when a task is used twice in a roll."""
|
||||
|
||||
i18n: I18N = I18N()
|
||||
error: str = "TaskRepeatedUsageException"
|
||||
message: str
|
||||
|
||||
def __init__(self, i18n: I18N, tool: str, tool_input: str, text: str):
|
||||
self.i18n = i18n
|
||||
self.text = text
|
||||
self.tool = tool
|
||||
self.tool_input = tool_input
|
||||
self.message = self.i18n.errors("task_repeated_usage").format(
|
||||
tool=tool, tool_input=tool_input
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
error=self.error,
|
||||
observation=self.message,
|
||||
send_to_llm=True,
|
||||
llm_output=self.text,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
@@ -10,16 +10,19 @@ from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.input import get_color_mapping
|
||||
from pydantic import InstanceOf
|
||||
|
||||
from crewai.agents.cache.cache_hit import CacheHit
|
||||
from crewai.tools.cache_tools import CacheTools
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.tools.tool_usage import ToolUsage
|
||||
from crewai.utilities import I18N
|
||||
|
||||
|
||||
class CrewAgentExecutor(AgentExecutor):
|
||||
i18n: I18N = I18N()
|
||||
llm: Any = None
|
||||
iterations: int = 0
|
||||
request_within_rpm_limit: Any = None
|
||||
tools_handler: InstanceOf[ToolsHandler] = None
|
||||
max_iterations: Optional[int] = 15
|
||||
force_answer_max_iterations: Optional[int] = None
|
||||
step_callback: Optional[Any] = None
|
||||
@@ -32,11 +35,6 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
def _should_force_answer(self) -> bool:
|
||||
return True if self.iterations == self.force_answer_max_iterations else False
|
||||
|
||||
def _force_answer(self, output: AgentAction):
|
||||
return AgentStep(
|
||||
action=output, observation=self.i18n.errors("force_final_answer")
|
||||
)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
@@ -110,16 +108,17 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**inputs,
|
||||
)
|
||||
|
||||
if self._should_force_answer():
|
||||
if isinstance(output, AgentAction) or isinstance(output, AgentFinish):
|
||||
output = output
|
||||
elif isinstance(output, CacheHit):
|
||||
output = output.action
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected output type from agent: {type(output)}"
|
||||
)
|
||||
yield self._force_answer(output)
|
||||
yield AgentStep(
|
||||
action=output, observation=self.i18n.errors("force_final_answer")
|
||||
)
|
||||
return
|
||||
|
||||
except OutputParserException as e:
|
||||
@@ -160,7 +159,9 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
)
|
||||
|
||||
if self._should_force_answer():
|
||||
yield self._force_answer(output)
|
||||
yield AgentStep(
|
||||
action=output, observation=self.i18n.errors("force_final_answer")
|
||||
)
|
||||
return
|
||||
|
||||
yield AgentStep(action=output, observation=observation)
|
||||
@@ -171,17 +172,6 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
yield output
|
||||
return
|
||||
|
||||
# Override tool usage to use CacheTools
|
||||
if isinstance(output, CacheHit):
|
||||
cache = output.cache
|
||||
action = output.action
|
||||
tool = CacheTools(cache_handler=cache).tool()
|
||||
output = action.copy()
|
||||
output.tool_input = f"tool:{action.tool}|input:{action.tool_input}"
|
||||
output.tool = tool.name
|
||||
name_to_tool_map[tool.name] = tool
|
||||
color_mapping[tool.name] = color_mapping[action.tool]
|
||||
|
||||
actions: List[AgentAction]
|
||||
actions = [output] if isinstance(output, AgentAction) else output
|
||||
yield from actions
|
||||
@@ -192,18 +182,13 @@ class CrewAgentExecutor(AgentExecutor):
|
||||
if agent_action.tool in name_to_tool_map:
|
||||
tool = name_to_tool_map[agent_action.tool]
|
||||
return_direct = tool.return_direct
|
||||
color = color_mapping[agent_action.tool]
|
||||
color_mapping[agent_action.tool]
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
if return_direct:
|
||||
tool_run_kwargs["llm_prefix"] = ""
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = tool.run(
|
||||
agent_action.tool_input,
|
||||
verbose=self.verbose,
|
||||
color=color,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**tool_run_kwargs,
|
||||
)
|
||||
observation = ToolUsage(
|
||||
tools_handler=self.tools_handler, tools=self.tools, llm=self.llm
|
||||
).use(agent_action.log)
|
||||
else:
|
||||
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
|
||||
observation = InvalidTool().run(
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from langchain.agents.output_parsers import ReActSingleInputOutputParser
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
|
||||
from crewai.agents.cache import CacheHandler, CacheHit
|
||||
from crewai.agents.exceptions import TaskRepeatedUsageException
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.utilities import I18N
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = (
|
||||
"Parsing LLM output produced both a final answer and a parse-able action:"
|
||||
)
|
||||
|
||||
|
||||
class CrewAgentOutputParser(ReActSingleInputOutputParser):
|
||||
"""Parses ReAct-style LLM calls that have a single tool input.
|
||||
|
||||
Expects output to be in one of two formats.
|
||||
|
||||
If the output signals that an action should be taken,
|
||||
should be in the below format. This will result in an AgentAction
|
||||
being returned.
|
||||
|
||||
```
|
||||
Thought: agent thought here
|
||||
Action: search
|
||||
Action Input: what is the temperature in SF?
|
||||
```
|
||||
|
||||
If the output signals that a final answer should be given,
|
||||
should be in the below format. This will result in an AgentFinish
|
||||
being returned.
|
||||
|
||||
```
|
||||
Thought: agent thought here
|
||||
Final Answer: The temperature is 100 degrees
|
||||
```
|
||||
|
||||
It also prevents tools from being reused in a roll.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
tools_handler: ToolsHandler
|
||||
cache: CacheHandler
|
||||
i18n: I18N
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish, CacheHit]:
|
||||
regex = (
|
||||
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
|
||||
)
|
||||
if action_match := re.search(regex, text, re.DOTALL):
|
||||
action = action_match.group(1).strip()
|
||||
action_input = action_match.group(2)
|
||||
tool_input = action_input.strip(" ")
|
||||
tool_input = tool_input.strip('"')
|
||||
|
||||
if last_tool_usage := self.tools_handler.last_used_tool:
|
||||
usage = {
|
||||
"tool": action,
|
||||
"input": tool_input,
|
||||
}
|
||||
if usage == last_tool_usage:
|
||||
raise TaskRepeatedUsageException(
|
||||
text=text,
|
||||
tool=action,
|
||||
tool_input=tool_input,
|
||||
i18n=self.i18n,
|
||||
)
|
||||
|
||||
if self.cache.read(action, tool_input):
|
||||
action = AgentAction(action, tool_input, text)
|
||||
return CacheHit(action=action, cache=self.cache)
|
||||
|
||||
return super().parse(text)
|
||||
@@ -1,44 +1,30 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from typing import Any
|
||||
|
||||
from ..tools.cache_tools import CacheTools
|
||||
from ..tools.tool_calling import ToolCalling
|
||||
from .cache.cache_handler import CacheHandler
|
||||
|
||||
|
||||
class ToolsHandler(BaseCallbackHandler):
|
||||
class ToolsHandler:
|
||||
"""Callback handler for tool usage."""
|
||||
|
||||
last_used_tool: Dict[str, Any] = {}
|
||||
last_used_tool: ToolCalling = {}
|
||||
cache: CacheHandler
|
||||
|
||||
def __init__(self, cache: CacheHandler, **kwargs: Any):
|
||||
def __init__(self, cache: CacheHandler):
|
||||
"""Initialize the callback handler."""
|
||||
self.cache = cache
|
||||
super().__init__(**kwargs)
|
||||
self.last_used_tool = {}
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> Any:
|
||||
def on_tool_start(self, calling: ToolCalling) -> Any:
|
||||
"""Run when tool starts running."""
|
||||
name = serialized.get("name")
|
||||
if name not in ["invalid_tool", "_Exception"]:
|
||||
tools_usage = {
|
||||
"tool": name,
|
||||
"input": input_str,
|
||||
}
|
||||
self.last_used_tool = tools_usage
|
||||
self.last_used_tool = calling
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
|
||||
def on_tool_end(self, calling: ToolCalling, output: str) -> Any:
|
||||
"""Run when tool ends running."""
|
||||
if (
|
||||
"is not a valid tool" not in output
|
||||
and "Invalid or incomplete response" not in output
|
||||
and "Invalid Format" not in output
|
||||
):
|
||||
if self.last_used_tool["tool"] != CacheTools().name:
|
||||
self.cache.add(
|
||||
tool=self.last_used_tool["tool"],
|
||||
input=self.last_used_tool["input"],
|
||||
output=output,
|
||||
)
|
||||
if self.last_used_tool.function_name != CacheTools().name:
|
||||
self.cache.add(
|
||||
tool=calling.function_name,
|
||||
input=calling.arguments,
|
||||
output=output,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user