mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
45 lines
1.4 KiB
Python
45 lines
1.4 KiB
Python
from typing import Any, Dict
|
|
|
|
from langchain.callbacks.base import BaseCallbackHandler
|
|
|
|
from ..tools.cache_tools import CacheTools
|
|
from .cache.cache_handler import CacheHandler
|
|
|
|
|
|
class ToolsHandler(BaseCallbackHandler):
    """Callback handler for tool usage.

    ``on_tool_start`` records the most recent tool invocation (tool name and
    raw input) in ``last_used_tool``, skipping the ignored pseudo-tool names.
    ``on_tool_end`` stores the recorded invocation's output in ``cache``
    unless the output contains a known error marker, no invocation has been
    recorded yet, or the recorded tool is the cache tool itself.
    """

    # Tool names for which no invocation is recorded. NOTE(review): these
    # appear to be the names of LangChain's invalid-tool / exception
    # pseudo-tools — confirm against the agent executor in use.
    _IGNORED_TOOL_NAMES = ("invalid_tool", "_Exception")

    # Substrings that mark a tool output as an error message; such outputs
    # are never cached.
    _ERROR_OUTPUT_MARKERS = (
        "is not a valid tool",
        "Invalid or incomplete response",
        "Invalid Format",
    )

    last_used_tool: Dict[str, Any] = {}
    cache: CacheHandler

    def __init__(self, cache: CacheHandler, **kwargs: Any):
        """Initialize the callback handler.

        Args:
            cache: Cache that receives the outputs of completed tool calls.
            **kwargs: Forwarded to ``BaseCallbackHandler.__init__``.
        """
        self.cache = cache
        # Give every handler its own dict rather than sharing the single
        # class-level default object across all instances.
        self.last_used_tool = {}
        super().__init__(**kwargs)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> Any:
        """Record the tool that is about to run.

        Args:
            serialized: Serialized tool description; only its ``"name"``
                entry is read.
            input_str: Raw input passed to the tool.
            **kwargs: Extra callback arguments (unused).
        """
        name = serialized.get("name")
        if name in self._IGNORED_TOOL_NAMES:
            # NOTE(review): last_used_tool is left pointing at the previous
            # tool here, so on_tool_end relies solely on the error markers to
            # avoid caching this call's output under that stale entry.
            return
        self.last_used_tool = {"tool": name, "input": input_str}

    def on_tool_end(self, output: str, **kwargs: Any) -> Any:
        """Cache the output of the most recently recorded tool call.

        Args:
            output: Text produced by the tool.
            **kwargs: Extra callback arguments (unused).
        """
        if any(marker in output for marker in self._ERROR_OUTPUT_MARKERS):
            return

        recorded = self.last_used_tool
        if "tool" not in recorded or "input" not in recorded:
            # No invocation has been recorded yet (e.g. only ignored tools
            # have started); indexing the empty dict used to raise KeyError.
            return

        # Skip the cache tool itself — presumably its output was already read
        # from the cache, so storing it again would be redundant.
        if recorded["tool"] == CacheTools().name:
            return

        self.cache.add(
            tool=recorded["tool"],
            input=recorded["input"],
            output=output,
        )
|