mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
smal fixes and better guardrail for parsing small models tools usage
This commit is contained in:
@@ -26,13 +26,13 @@ class ToolUsage:
|
|||||||
Class that represents the usage of a tool by an agent.
|
Class that represents the usage of a tool by an agent.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
task: Task being executed.
|
task: Task being executed.
|
||||||
tools_handler: Tools handler that will manage the tool usage.
|
tools_handler: Tools handler that will manage the tool usage.
|
||||||
tools: List of tools available for the agent.
|
tools: List of tools available for the agent.
|
||||||
original_tools: Original tools available for the agent before being converted to BaseTool.
|
original_tools: Original tools available for the agent before being converted to BaseTool.
|
||||||
tools_description: Description of the tools available for the agent.
|
tools_description: Description of the tools available for the agent.
|
||||||
tools_names: Names of the tools available for the agent.
|
tools_names: Names of the tools available for the agent.
|
||||||
function_calling_llm: Language model to be used for the tool usage.
|
function_calling_llm: Language model to be used for the tool usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -265,12 +265,12 @@ class ToolUsage:
|
|||||||
model=model,
|
model=model,
|
||||||
instructions=dedent(
|
instructions=dedent(
|
||||||
"""\
|
"""\
|
||||||
The schema should have the following structure, only two keys:
|
The schema should have the following structure, only two keys:
|
||||||
- tool_name: str
|
- tool_name: str
|
||||||
- arguments: dict (with all arguments being passed)
|
- arguments: dict (with all arguments being passed)
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
{"tool_name": "tool name", "arguments": {"arg_name1": "value", "arg_name2": 2}}""",
|
{"tool_name": "tool name", "arguments": {"arg_name1": "value", "arg_name2": 2}}""",
|
||||||
),
|
),
|
||||||
max_attemps=1,
|
max_attemps=1,
|
||||||
)
|
)
|
||||||
@@ -282,7 +282,8 @@ class ToolUsage:
|
|||||||
tool_name = self.action.tool
|
tool_name = self.action.tool
|
||||||
tool = self._select_tool(tool_name)
|
tool = self._select_tool(tool_name)
|
||||||
try:
|
try:
|
||||||
arguments = ast.literal_eval(self.action.tool_input)
|
tool_input = self._validate_tool_input(self.action.tool_input)
|
||||||
|
arguments = ast.literal_eval(tool_input)
|
||||||
except Exception:
|
except Exception:
|
||||||
return ToolUsageErrorException(
|
return ToolUsageErrorException(
|
||||||
f'{self._i18n.errors("tool_arguments_error")}'
|
f'{self._i18n.errors("tool_arguments_error")}'
|
||||||
@@ -308,3 +309,50 @@ class ToolUsage:
|
|||||||
return self._tool_calling(tool_string)
|
return self._tool_calling(tool_string)
|
||||||
|
|
||||||
return calling
|
return calling
|
||||||
|
|
||||||
|
def _validate_tool_input(self, tool_input: str) -> dict:
|
||||||
|
try:
|
||||||
|
ast.literal_eval(tool_input)
|
||||||
|
return tool_input
|
||||||
|
except Exception:
|
||||||
|
# Clean and ensure the string is properly enclosed in braces
|
||||||
|
tool_input = tool_input.strip()
|
||||||
|
if not tool_input.startswith("{"):
|
||||||
|
tool_input = "{" + tool_input
|
||||||
|
if not tool_input.endswith("}"):
|
||||||
|
tool_input += "}"
|
||||||
|
|
||||||
|
# Manually split the input into key-value pairs
|
||||||
|
entries = tool_input.strip("{} ").split(",")
|
||||||
|
formatted_entries = []
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
if ":" not in entry:
|
||||||
|
continue # Skip malformed entries
|
||||||
|
key, value = entry.split(":", 1)
|
||||||
|
|
||||||
|
key = key.strip().strip(
|
||||||
|
'"'
|
||||||
|
) # Remove extraneous white spaces and quotes
|
||||||
|
value = value.strip()
|
||||||
|
|
||||||
|
# Check and format the value based on its type
|
||||||
|
if value.isdigit(): # Check if value is a digit, hence integer
|
||||||
|
formatted_value = value
|
||||||
|
elif value.lower() in [
|
||||||
|
"true",
|
||||||
|
"false",
|
||||||
|
"null",
|
||||||
|
]: # Check for boolean and null values
|
||||||
|
formatted_value = value.lower()
|
||||||
|
else:
|
||||||
|
# Assume the value is a string and needs quotes
|
||||||
|
formatted_value = '"' + value.strip('"').replace('"', '\\"') + '"'
|
||||||
|
|
||||||
|
# Rebuild the entry with proper quoting
|
||||||
|
formatted_entry = f'"{key}": {formatted_value}'
|
||||||
|
formatted_entries.append(formatted_entry)
|
||||||
|
|
||||||
|
# Reconstruct the JSON string
|
||||||
|
new_json_string = "{" + ", ".join(formatted_entries) + "}"
|
||||||
|
return new_json_string
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -648,10 +648,10 @@ def test_agent_usage_metrics_are_captured_for_sequential_process():
|
|||||||
result = crew.kickoff()
|
result = crew.kickoff()
|
||||||
assert result == "Howdy!"
|
assert result == "Howdy!"
|
||||||
assert crew.usage_metrics == {
|
assert crew.usage_metrics == {
|
||||||
"completion_tokens": 51,
|
"completion_tokens": 17,
|
||||||
"prompt_tokens": 483,
|
"prompt_tokens": 160,
|
||||||
"successful_requests": 3,
|
"successful_requests": 1,
|
||||||
"total_tokens": 534,
|
"total_tokens": 177,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -678,10 +678,10 @@ def test_agent_usage_metrics_are_captured_for_hierarchical_process():
|
|||||||
result = crew.kickoff()
|
result = crew.kickoff()
|
||||||
assert result == '"Howdy!"'
|
assert result == '"Howdy!"'
|
||||||
assert crew.usage_metrics == {
|
assert crew.usage_metrics == {
|
||||||
"total_tokens": 2592,
|
"total_tokens": 1650,
|
||||||
"prompt_tokens": 2048,
|
"prompt_tokens": 1367,
|
||||||
"completion_tokens": 544,
|
"completion_tokens": 283,
|
||||||
"successful_requests": 6,
|
"successful_requests": 3,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user