mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Revamping tool usage
This commit is contained in:
@@ -1,44 +1,30 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from typing import Any
|
||||
|
||||
from ..tools.cache_tools import CacheTools
|
||||
from ..tools.tool_calling import ToolCalling
|
||||
from .cache.cache_handler import CacheHandler
|
||||
|
||||
|
||||
class ToolsHandler(BaseCallbackHandler):
|
||||
class ToolsHandler:
|
||||
"""Callback handler for tool usage."""
|
||||
|
||||
last_used_tool: Dict[str, Any] = {}
|
||||
last_used_tool: ToolCalling = {}
|
||||
cache: CacheHandler
|
||||
|
||||
def __init__(self, cache: CacheHandler, **kwargs: Any):
|
||||
def __init__(self, cache: CacheHandler):
|
||||
"""Initialize the callback handler."""
|
||||
self.cache = cache
|
||||
super().__init__(**kwargs)
|
||||
self.last_used_tool = {}
|
||||
|
||||
def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> Any:
|
||||
def on_tool_start(self, calling: ToolCalling) -> Any:
|
||||
"""Run when tool starts running."""
|
||||
name = serialized.get("name")
|
||||
if name not in ["invalid_tool", "_Exception"]:
|
||||
tools_usage = {
|
||||
"tool": name,
|
||||
"input": input_str,
|
||||
}
|
||||
self.last_used_tool = tools_usage
|
||||
self.last_used_tool = calling
|
||||
|
||||
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
|
||||
def on_tool_end(self, calling: ToolCalling, output: str) -> Any:
|
||||
"""Run when tool ends running."""
|
||||
if (
|
||||
"is not a valid tool" not in output
|
||||
and "Invalid or incomplete response" not in output
|
||||
and "Invalid Format" not in output
|
||||
):
|
||||
if self.last_used_tool["tool"] != CacheTools().name:
|
||||
self.cache.add(
|
||||
tool=self.last_used_tool["tool"],
|
||||
input=self.last_used_tool["input"],
|
||||
output=output,
|
||||
)
|
||||
if self.last_used_tool.function_name != CacheTools().name:
|
||||
self.cache.add(
|
||||
tool=calling.function_name,
|
||||
input=calling.arguments,
|
||||
output=output,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user