Better json parsing for smaller models

This commit is contained in:
João Moura
2024-05-02 21:57:41 -03:00
parent bcb57ce5f9
commit 0b781065d2

View File

@@ -1,7 +1,7 @@
import ast import ast
from difflib import SequenceMatcher
from textwrap import dedent from textwrap import dedent
from typing import Any, List, Union from typing import Any, List, Union
from difflib import SequenceMatcher
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
@@ -218,7 +218,10 @@ class ToolUsage:
for tool in self.tools: for tool in self.tools:
if ( if (
tool.name.lower().strip() == tool_name.lower().strip() tool.name.lower().strip() == tool_name.lower().strip()
or SequenceMatcher(None, tool.name.lower().strip(), tool_name.lower().strip()).ratio() > 0.9 or SequenceMatcher(
None, tool.name.lower().strip(), tool_name.lower().strip()
).ratio()
> 0.9
): ):
return tool return tool
self.task.increment_tools_errors() self.task.increment_tools_errors()
@@ -314,7 +317,7 @@ class ToolUsage:
return calling return calling
def _validate_tool_input(self, tool_input: str) -> dict: def _validate_tool_input(self, tool_input: str) -> str:
try: try:
ast.literal_eval(tool_input) ast.literal_eval(tool_input)
return tool_input return tool_input
@@ -335,13 +338,17 @@ class ToolUsage:
continue # Skip malformed entries continue # Skip malformed entries
key, value = entry.split(":", 1) key, value = entry.split(":", 1)
key = key.strip().strip( # Remove extraneous white spaces and quotes, replace single quotes
'"' key = key.strip().strip('"').replace("'", '"')
) # Remove extraneous white spaces and quotes
value = value.strip() value = value.strip()
# Check and format the value based on its type # Handle replacement of single quotes at the start and end of the value string
if value.isdigit(): # Check if value is a digit, hence integer if value.startswith("'") and value.endswith("'"):
value = value[1:-1] # Remove single quotes
value = (
'"' + value.replace('"', '\\"') + '"'
) # Re-encapsulate with double quotes
elif value.isdigit(): # Check if value is a digit, hence integer
formatted_value = value formatted_value = value
elif value.lower() in [ elif value.lower() in [
"true", "true",
@@ -351,7 +358,7 @@ class ToolUsage:
formatted_value = value.lower() formatted_value = value.lower()
else: else:
# Assume the value is a string and needs quotes # Assume the value is a string and needs quotes
formatted_value = '"' + value.strip('"').replace('"', '\\"') + '"' formatted_value = '"' + value.replace('"', '\\"') + '"'
# Rebuild the entry with proper quoting # Rebuild the entry with proper quoting
formatted_entry = f'"{key}": {formatted_value}' formatted_entry = f'"{key}": {formatted_value}'