Improving tool calling to pass dictionaries instead of strings

This commit is contained in:
Brandon Hancock
2025-01-08 15:22:55 -05:00
parent b3504e768c
commit 2107512e84
2 changed files with 47 additions and 73 deletions

View File

@@ -130,6 +130,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
try: try:
self._format_answer(answer) self._format_answer(answer)
except OutputParserException as e: except OutputParserException as e:
print("ERROR ATTEMPTING TO PARSE ANSWER: ", answer)
if ( if (
FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE
in e.error in e.error
@@ -147,7 +148,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
# Directly append the result to the messages if the # Directly append the result to the messages if the
# tool is "Add image to content" in case of multimodal # tool is "Add image to content" in case of multimodal
# agents # agents
if formatted_answer.tool == self._i18n.tools("add_image")["name"]: if (
formatted_answer.tool
== self._i18n.tools("add_image")["name"]
):
self.messages.append(tool_result.result) self.messages.append(tool_result.result)
continue continue
@@ -155,7 +159,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if self.step_callback: if self.step_callback:
self.step_callback(tool_result) self.step_callback(tool_result)
formatted_answer.text += f"\nObservation: {tool_result.result}" formatted_answer.text += (
f"\nObservation: {tool_result.result}"
)
formatted_answer.result = tool_result.result formatted_answer.result = tool_result.result
if tool_result.result_as_answer: if tool_result.result_as_answer:
@@ -272,7 +278,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
agent=self.agent, agent=self.agent,
action=agent_action, action=agent_action,
) )
tool_calling = tool_usage.parse(agent_action.text) tool_calling = tool_usage.parse_tool_calling(agent_action.text)
if isinstance(tool_calling, ToolUsageErrorException): if isinstance(tool_calling, ToolUsageErrorException):
tool_result = tool_calling.message tool_result = tool_calling.message

View File

@@ -1,9 +1,12 @@
import ast import ast
import datetime import datetime
import json
import time import time
from difflib import SequenceMatcher from difflib import SequenceMatcher
from textwrap import dedent from textwrap import dedent
from typing import Any, List, Union from typing import Any, Dict, List, Union
from json_repair import repair_json
import crewai.utilities.events as events import crewai.utilities.events as events
from crewai.agents.tools_handler import ToolsHandler from crewai.agents.tools_handler import ToolsHandler
@@ -19,7 +22,15 @@ try:
import agentops # type: ignore import agentops # type: ignore
except ImportError: except ImportError:
agentops = None agentops = None
OPENAI_BIGGER_MODELS = ["gpt-4", "gpt-4o", "o1-preview", "o1-mini", "o1", "o3", "o3-mini"] OPENAI_BIGGER_MODELS = [
"gpt-4",
"gpt-4o",
"o1-preview",
"o1-mini",
"o1",
"o3",
"o3-mini",
]
class ToolUsageErrorException(Exception): class ToolUsageErrorException(Exception):
@@ -80,7 +91,7 @@ class ToolUsage:
self._max_parsing_attempts = 2 self._max_parsing_attempts = 2
self._remember_format_after_usages = 4 self._remember_format_after_usages = 4
def parse(self, tool_string: str): def parse_tool_calling(self, tool_string: str):
"""Parse the tool string and return the tool calling.""" """Parse the tool string and return the tool calling."""
return self._tool_calling(tool_string) return self._tool_calling(tool_string)
@@ -94,7 +105,6 @@ class ToolUsage:
self.task.increment_tools_errors() self.task.increment_tools_errors()
return error return error
# BUG? The code below seems to be unreachable
try: try:
tool = self._select_tool(calling.tool_name) tool = self._select_tool(calling.tool_name)
except Exception as e: except Exception as e:
@@ -104,19 +114,7 @@ class ToolUsage:
self._printer.print(content=f"\n\n{error}\n", color="red") self._printer.print(content=f"\n\n{error}\n", color="red")
return error return error
if isinstance(tool, CrewStructuredTool) and tool.name == self._i18n.tools("add_image")["name"]: # type: ignore return f"{self._use(tool_string=tool_string, tool=tool, calling=calling)}"
try:
result = self._use(tool_string=tool_string, tool=tool, calling=calling)
return result
except Exception as e:
error = getattr(e, "message", str(e))
self.task.increment_tools_errors()
if self.agent.verbose:
self._printer.print(content=f"\n\n{error}\n", color="red")
return error
return f"{self._use(tool_string=tool_string, tool=tool, calling=calling)}" # type: ignore # BUG?: "_use" of "ToolUsage" does not return a value (it only ever returns None)
def _use( def _use(
self, self,
@@ -349,13 +347,14 @@ class ToolUsage:
tool_name = self.action.tool tool_name = self.action.tool
tool = self._select_tool(tool_name) tool = self._select_tool(tool_name)
try: try:
tool_input = self._validate_tool_input(self.action.tool_input) arguments = self._validate_tool_input(self.action.tool_input)
arguments = ast.literal_eval(tool_input) print("Arguments:", arguments)
print("Arguments type:", type(arguments))
except Exception: except Exception:
if raise_error: if raise_error:
raise raise
else: else:
return ToolUsageErrorException( # type: ignore # Incompatible return value type (got "ToolUsageErrorException", expected "ToolCalling | InstructorToolCalling") return ToolUsageErrorException(
f'{self._i18n.errors("tool_arguments_error")}' f'{self._i18n.errors("tool_arguments_error")}'
) )
@@ -363,14 +362,14 @@ class ToolUsage:
if raise_error: if raise_error:
raise raise
else: else:
return ToolUsageErrorException( # type: ignore # Incompatible return value type (got "ToolUsageErrorException", expected "ToolCalling | InstructorToolCalling") return ToolUsageErrorException(
f'{self._i18n.errors("tool_arguments_error")}' f'{self._i18n.errors("tool_arguments_error")}'
) )
return ToolCalling( return ToolCalling(
tool_name=tool.name, tool_name=tool.name,
arguments=arguments, arguments=arguments,
log=tool_string, # type: ignore log=tool_string,
) )
def _tool_calling( def _tool_calling(
@@ -396,57 +395,26 @@ class ToolUsage:
) )
return self._tool_calling(tool_string) return self._tool_calling(tool_string)
def _validate_tool_input(self, tool_input: str) -> str: def _validate_tool_input(self, tool_input: str) -> Dict[str, Any]:
print("tool_input:", tool_input)
try: try:
ast.literal_eval(tool_input) # Try to parse with json.loads directly
return tool_input arguments = json.loads(tool_input)
except Exception: return arguments
# Clean and ensure the string is properly enclosed in braces except json.JSONDecodeError:
tool_input = tool_input.strip() # Fix common issues in the tool_input string
if not tool_input.startswith("{"):
tool_input = "{" + tool_input
if not tool_input.endswith("}"):
tool_input += "}"
# Manually split the input into key-value pairs # Replace single quotes with double quotes
entries = tool_input.strip("{} ").split(",") tool_input = tool_input.replace("'", '"')
formatted_entries = []
for entry in entries: # Use json_repair to fix common JSON issues
if ":" not in entry: repaired_input = repair_json(tool_input)
continue # Skip malformed entries try:
key, value = entry.split(":", 1) arguments = json.loads(repaired_input)
return arguments
# Remove extraneous white spaces and quotes, replace single quotes except json.JSONDecodeError as e:
key = key.strip().strip('"').replace("'", '"') # If all else fails, raise an error
value = value.strip() raise Exception(f"Invalid tool input JSON: {e}")
# Handle replacement of single quotes at the start and end of the value string
if value.startswith("'") and value.endswith("'"):
value = value[1:-1] # Remove single quotes
value = (
'"' + value.replace('"', '\\"') + '"'
) # Re-encapsulate with double quotes
elif value.isdigit(): # Check if value is a digit, hence integer
value = value
elif value.lower() in [
"true",
"false",
]: # Check for boolean and null values
value = value.lower().capitalize()
elif value.lower() == "null":
value = "None"
else:
# Assume the value is a string and needs quotes
value = '"' + value.replace('"', '\\"') + '"'
# Rebuild the entry with proper quoting
formatted_entry = f'"{key}": {value}'
formatted_entries.append(formatted_entry)
# Reconstruct the JSON string
new_json_string = "{" + ", ".join(formatted_entries) + "}"
return new_json_string
def on_tool_error(self, tool: Any, tool_calling: ToolCalling, e: Exception) -> None: def on_tool_error(self, tool: Any, tool_calling: ToolCalling, e: Exception) -> None:
event_data = self._prepare_event_data(tool, tool_calling) event_data = self._prepare_event_data(tool, tool_calling)