Merge pull request #1382 from crewAIInc/tm-basic-event-structure

Add tool usage events
This commit is contained in:
Thiago Moretto
2024-10-02 12:54:51 -03:00
committed by GitHub
4 changed files with 134 additions and 6 deletions

View File

@@ -1,5 +1,7 @@
import ast import ast
import datetime
import os import os
import time
from difflib import SequenceMatcher from difflib import SequenceMatcher
from textwrap import dedent from textwrap import dedent
from typing import Any, List, Union from typing import Any, List, Union
@@ -8,7 +10,10 @@ from crewai.agents.tools_handler import ToolsHandler
from crewai.task import Task from crewai.task import Task
from crewai.telemetry import Telemetry from crewai.telemetry import Telemetry
from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling from crewai.tools.tool_calling import InstructorToolCalling, ToolCalling
from crewai.tools.tool_usage_events import ToolUsageError, ToolUsageFinished
from crewai.utilities import I18N, Converter, ConverterError, Printer from crewai.utilities import I18N, Converter, ConverterError, Printer
import crewai.utilities.events as events
agentops = None agentops = None
if os.environ.get("AGENTOPS_API_KEY"): if os.environ.get("AGENTOPS_API_KEY"):
@@ -126,12 +131,16 @@ class ToolUsage:
except Exception: except Exception:
self.task.increment_tools_errors() self.task.increment_tools_errors()
result = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str") started_at = time.time()
from_cache = False
result = None # type: ignore # Incompatible types in assignment (expression has type "None", variable has type "str")
# check if cache is available
if self.tools_handler.cache: if self.tools_handler.cache:
result = self.tools_handler.cache.read( # type: ignore # Incompatible types in assignment (expression has type "str | None", variable has type "str") result = self.tools_handler.cache.read( # type: ignore # Incompatible types in assignment (expression has type "str | None", variable has type "str")
tool=calling.tool_name, input=calling.arguments tool=calling.tool_name, input=calling.arguments
) )
from_cache = result is not None
original_tool = next( original_tool = next(
(ot for ot in self.original_tools if ot.name == tool.name), None (ot for ot in self.original_tools if ot.name == tool.name), None
@@ -163,6 +172,7 @@ class ToolUsage:
else: else:
result = tool.invoke(input={}) result = tool.invoke(input={})
except Exception as e: except Exception as e:
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
self._run_attempts += 1 self._run_attempts += 1
if self._run_attempts > self._max_parsing_attempts: if self._run_attempts > self._max_parsing_attempts:
self._telemetry.tool_usage_error(llm=self.function_calling_llm) self._telemetry.tool_usage_error(llm=self.function_calling_llm)
@@ -214,6 +224,13 @@ class ToolUsage:
"tool_args": calling.arguments, "tool_args": calling.arguments,
} }
self.on_tool_use_finished(
tool=tool,
tool_calling=calling,
from_cache=from_cache,
started_at=started_at,
)
if ( if (
hasattr(original_tool, "result_as_answer") hasattr(original_tool, "result_as_answer")
and original_tool.result_as_answer # type: ignore # Item "None" of "Any | None" has no attribute "cache_function" and original_tool.result_as_answer # type: ignore # Item "None" of "Any | None" has no attribute "cache_function"
@@ -431,3 +448,34 @@ class ToolUsage:
# Reconstruct the JSON string # Reconstruct the JSON string
new_json_string = "{" + ", ".join(formatted_entries) + "}" new_json_string = "{" + ", ".join(formatted_entries) + "}"
return new_json_string return new_json_string
def on_tool_error(self, tool: Any, tool_calling: ToolCalling, e: Exception) -> None:
event_data = self._prepare_event_data(tool, tool_calling)
events.emit(
source=self, event=ToolUsageError(**{**event_data, "error": str(e)})
)
def on_tool_use_finished(
self, tool: Any, tool_calling: ToolCalling, from_cache: bool, started_at: float
) -> None:
finished_at = time.time()
event_data = self._prepare_event_data(tool, tool_calling)
event_data.update(
{
"started_at": datetime.datetime.fromtimestamp(started_at),
"finished_at": datetime.datetime.fromtimestamp(finished_at),
"from_cache": from_cache,
}
)
events.emit(source=self, event=ToolUsageFinished(**event_data))
def _prepare_event_data(self, tool: Any, tool_calling: ToolCalling) -> dict:
return {
"agent_key": self.agent.key,
"agent_role": (self.agent._original_role or self.agent.role),
"run_attempts": self._run_attempts,
"delegations": self.task.delegations,
"tool_name": tool.name,
"tool_args": tool_calling.arguments,
"tool_class": tool.__class__.__name__,
}

View File

@@ -0,0 +1,23 @@
from typing import Any, Dict
from pydantic import BaseModel
from datetime import datetime
class ToolUsageEvent(BaseModel):
agent_key: str
agent_role: str
tool_name: str
tool_args: Dict[str, Any]
tool_class: str
run_attempts: int | None = None
delegations: int | None = None
class ToolUsageFinished(ToolUsageEvent):
started_at: datetime
finished_at: datetime
from_cache: bool = False
class ToolUsageError(ToolUsageEvent):
error: str

View File

@@ -0,0 +1,44 @@
from typing import Any, Callable, Generic, List, Dict, Type, TypeVar
from functools import wraps
from pydantic import BaseModel
T = TypeVar("T")
EVT = TypeVar("EVT", bound=BaseModel)
class Emitter(Generic[T, EVT]):
_listeners: Dict[Type[EVT], List[Callable]] = {}
def on(self, event_type: Type[EVT]):
def decorator(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
self._listeners.setdefault(event_type, []).append(wrapper)
return wrapper
return decorator
def emit(self, source: T, event: EVT) -> None:
event_type = type(event)
for func in self._listeners.get(event_type, []):
func(source, event)
default_emitter = Emitter[Any, BaseModel]()
def emit(source: Any, event: BaseModel, raise_on_error: bool = False) -> None:
try:
default_emitter.emit(source, event)
except Exception as e:
if raise_on_error:
raise e
else:
print(f"Error emitting event: {e}")
def on(event_type: Type[BaseModel]) -> Callable:
return default_emitter.on(event_type)

View File

@@ -12,9 +12,11 @@ from crewai.llm import LLM
from crewai.agents.parser import CrewAgentParser, OutputParserException from crewai.agents.parser import CrewAgentParser, OutputParserException
from crewai.tools.tool_calling import InstructorToolCalling from crewai.tools.tool_calling import InstructorToolCalling
from crewai.tools.tool_usage import ToolUsage from crewai.tools.tool_usage import ToolUsage
from crewai.tools.tool_usage_events import ToolUsageFinished
from crewai.utilities import RPMController from crewai.utilities import RPMController
from crewai_tools import tool from crewai_tools import tool
from crewai.agents.parser import AgentAction from crewai.agents.parser import AgentAction
from crewai.utilities.events import Emitter
def test_agent_llm_creation_with_env_vars(): def test_agent_llm_creation_with_env_vars():
@@ -71,7 +73,7 @@ def test_agent_creation():
def test_agent_default_values(): def test_agent_default_values():
agent = Agent(role="test role", goal="test goal", backstory="test backstory") agent = Agent(role="test role", goal="test goal", backstory="test backstory")
assert agent.llm.model == "gpt-4o" assert agent.llm.model == "gpt-4o-mini"
assert agent.allow_delegation is False assert agent.allow_delegation is False
@@ -178,8 +180,15 @@ def test_agent_execution_with_tools():
agent=agent, agent=agent,
expected_output="The result of the multiplication.", expected_output="The result of the multiplication.",
) )
output = agent.execute_task(task) with patch.object(Emitter, "emit") as emit:
assert output == "The result of the multiplication is 12." output = agent.execute_task(task)
assert output == "The result of the multiplication is 12."
assert emit.call_count == 1
args, _ = emit.call_args
assert isinstance(args[1], ToolUsageFinished)
assert not args[1].from_cache
assert args[1].tool_name == "multiplier"
assert args[1].tool_args == {"first_number": 3, "second_number": 4}
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@@ -197,7 +206,7 @@ def test_logging_tool_usage():
verbose=True, verbose=True,
) )
assert agent.llm.model == "gpt-4o" assert agent.llm.model == "gpt-4o-mini"
assert agent.tools_handler.last_used_tool == {} assert agent.tools_handler.last_used_tool == {}
task = Task( task = Task(
description="What is 3 times 4?", description="What is 3 times 4?",
@@ -267,7 +276,7 @@ def test_cache_hitting():
"multiplier-{'first_number': 12, 'second_number': 3}": 36, "multiplier-{'first_number': 12, 'second_number': 3}": 36,
} }
with patch.object(CacheHandler, "read") as read: with patch.object(CacheHandler, "read") as read, patch.object(Emitter, "emit") as emit:
read.return_value = "0" read.return_value = "0"
task = Task( task = Task(
description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.", description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
@@ -279,6 +288,10 @@ def test_cache_hitting():
read.assert_called_with( read.assert_called_with(
tool="multiplier", input={"first_number": 2, "second_number": 6} tool="multiplier", input={"first_number": 2, "second_number": 6}
) )
assert emit.call_count == 1
args, _ = emit.call_args
assert isinstance(args[1], ToolUsageFinished)
assert args[1].from_cache
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])