From 0b781065d277564077fdaf630d46995c210cc9d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Moura?= Date: Thu, 2 May 2024 21:57:41 -0300 Subject: [PATCH] Better json parsing for smaller models --- src/crewai/tools/tool_usage.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/crewai/tools/tool_usage.py b/src/crewai/tools/tool_usage.py index 6f78313d7..89e5103c3 100644 --- a/src/crewai/tools/tool_usage.py +++ b/src/crewai/tools/tool_usage.py @@ -1,7 +1,7 @@ import ast +from difflib import SequenceMatcher from textwrap import dedent from typing import Any, List, Union -from difflib import SequenceMatcher from langchain_core.tools import BaseTool from langchain_openai import ChatOpenAI @@ -218,7 +218,10 @@ class ToolUsage: for tool in self.tools: if ( tool.name.lower().strip() == tool_name.lower().strip() - or SequenceMatcher(None, tool.name.lower().strip(), tool_name.lower().strip()).ratio() > 0.9 + or SequenceMatcher( + None, tool.name.lower().strip(), tool_name.lower().strip() + ).ratio() + > 0.9 ): return tool self.task.increment_tools_errors() @@ -314,7 +317,7 @@ class ToolUsage: return calling - def _validate_tool_input(self, tool_input: str) -> dict: + def _validate_tool_input(self, tool_input: str) -> str: try: ast.literal_eval(tool_input) return tool_input @@ -335,13 +338,17 @@ class ToolUsage: continue # Skip malformed entries key, value = entry.split(":", 1) - key = key.strip().strip( - '"' - ) # Remove extraneous white spaces and quotes + # Remove extraneous white spaces and quotes, replace single quotes + key = key.strip().strip('"').replace("'", '"') value = value.strip() - # Check and format the value based on its type - if value.isdigit(): # Check if value is a digit, hence integer + # Handle replacement of single quotes at the start and end of the value string + if value.startswith("'") and value.endswith("'"): + value = value[1:-1] # Remove single quotes + value = ( + '"' + value.replace('"', '\\"') + '"' + ) # Re-encapsulate with double quotes + elif value.isdigit(): # Check if value is a digit, hence integer formatted_value = value elif value.lower() in [ "true", @@ -351,7 +358,7 @@ class ToolUsage: formatted_value = value.lower() else: # Assume the value is a string and needs quotes - formatted_value = '"' + value.strip('"').replace('"', '\\"') + '"' + formatted_value = '"' + value.replace('"', '\\"') + '"' # Rebuild the entry with proper quoting formatted_entry = f'"{key}": {formatted_value}'