mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Better json parsing for smaller models
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
import ast
|
import ast
|
||||||
|
from difflib import SequenceMatcher
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any, List, Union
|
from typing import Any, List, Union
|
||||||
from difflib import SequenceMatcher
|
|
||||||
|
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
@@ -218,7 +218,10 @@ class ToolUsage:
|
|||||||
for tool in self.tools:
|
for tool in self.tools:
|
||||||
if (
|
if (
|
||||||
tool.name.lower().strip() == tool_name.lower().strip()
|
tool.name.lower().strip() == tool_name.lower().strip()
|
||||||
or SequenceMatcher(None, tool.name.lower().strip(), tool_name.lower().strip()).ratio() > 0.9
|
or SequenceMatcher(
|
||||||
|
None, tool.name.lower().strip(), tool_name.lower().strip()
|
||||||
|
).ratio()
|
||||||
|
> 0.9
|
||||||
):
|
):
|
||||||
return tool
|
return tool
|
||||||
self.task.increment_tools_errors()
|
self.task.increment_tools_errors()
|
||||||
@@ -314,7 +317,7 @@ class ToolUsage:
|
|||||||
|
|
||||||
return calling
|
return calling
|
||||||
|
|
||||||
def _validate_tool_input(self, tool_input: str) -> dict:
|
def _validate_tool_input(self, tool_input: str) -> str:
|
||||||
try:
|
try:
|
||||||
ast.literal_eval(tool_input)
|
ast.literal_eval(tool_input)
|
||||||
return tool_input
|
return tool_input
|
||||||
@@ -335,13 +338,17 @@ class ToolUsage:
|
|||||||
continue # Skip malformed entries
|
continue # Skip malformed entries
|
||||||
key, value = entry.split(":", 1)
|
key, value = entry.split(":", 1)
|
||||||
|
|
||||||
key = key.strip().strip(
|
# Remove extraneous white spaces and quotes, replace single quotes
|
||||||
'"'
|
key = key.strip().strip('"').replace("'", '"')
|
||||||
) # Remove extraneous white spaces and quotes
|
|
||||||
value = value.strip()
|
value = value.strip()
|
||||||
|
|
||||||
# Check and format the value based on its type
|
# Handle replacement of single quotes at the start and end of the value string
|
||||||
if value.isdigit(): # Check if value is a digit, hence integer
|
if value.startswith("'") and value.endswith("'"):
|
||||||
|
value = value[1:-1] # Remove single quotes
|
||||||
|
value = (
|
||||||
|
'"' + value.replace('"', '\\"') + '"'
|
||||||
|
) # Re-encapsulate with double quotes
|
||||||
|
elif value.isdigit(): # Check if value is a digit, hence integer
|
||||||
formatted_value = value
|
formatted_value = value
|
||||||
elif value.lower() in [
|
elif value.lower() in [
|
||||||
"true",
|
"true",
|
||||||
@@ -351,7 +358,7 @@ class ToolUsage:
|
|||||||
formatted_value = value.lower()
|
formatted_value = value.lower()
|
||||||
else:
|
else:
|
||||||
# Assume the value is a string and needs quotes
|
# Assume the value is a string and needs quotes
|
||||||
formatted_value = '"' + value.strip('"').replace('"', '\\"') + '"'
|
formatted_value = '"' + value.replace('"', '\\"') + '"'
|
||||||
|
|
||||||
# Rebuild the entry with proper quoting
|
# Rebuild the entry with proper quoting
|
||||||
formatted_entry = f'"{key}": {formatted_value}'
|
formatted_entry = f'"{key}": {formatted_value}'
|
||||||
|
|||||||
Reference in New Issue
Block a user