From dea6ed7ef0b4c1cc023fa0e91dfcc98b8d53d1ae Mon Sep 17 00:00:00 2001 From: "Brandon Hancock (bhancock_ai)" <109994880+bhancockio@users.noreply.github.com> Date: Mon, 27 Jan 2025 17:35:17 -0500 Subject: [PATCH] fix issue pointed out by mike (#1986) * fix issue pointed out by mike * clean up * Drop logger * drop unused imports --- pyproject.toml | 1 + src/crewai/tools/tool_usage.py | 70 ++++++--- tests/tools/test_tool_usage.py | 252 +++++++++++++++++++++++++++++++++ uv.lock | 11 ++ 4 files changed, 313 insertions(+), 21 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8dbe56fd8..4a9343697 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "tomli-w>=1.1.0", "tomli>=2.0.2", "blinker>=1.9.0", + "json5>=0.10.0", ] [project.urls] diff --git a/src/crewai/tools/tool_usage.py b/src/crewai/tools/tool_usage.py index a59ed7b50..218410ef7 100644 --- a/src/crewai/tools/tool_usage.py +++ b/src/crewai/tools/tool_usage.py @@ -1,12 +1,13 @@ import ast import datetime import json -import re import time from difflib import SequenceMatcher +from json import JSONDecodeError from textwrap import dedent -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union +import json5 from json_repair import repair_json import crewai.utilities.events as events @@ -407,28 +408,55 @@ class ToolUsage: ) return self._tool_calling(tool_string) - def _validate_tool_input(self, tool_input: str) -> Dict[str, Any]: + def _validate_tool_input(self, tool_input: Optional[str]) -> Dict[str, Any]: + if tool_input is None: + return {} + + if not isinstance(tool_input, str) or not tool_input.strip(): + raise Exception( + "Tool input must be a valid dictionary in JSON or Python literal format" + ) + + # Attempt 1: Parse as JSON try: - # Replace Python literals with JSON equivalents - replacements = { - r"'": '"', - r"None": "null", - r"True": "true", - r"False": "false", - } - for pattern, replacement in replacements.items(): - tool_input = re.sub(pattern, replacement, tool_input) - arguments = json.loads(tool_input) - except json.JSONDecodeError: - # Attempt to repair JSON string - repaired_input = repair_json(tool_input) - try: - arguments = json.loads(repaired_input) - except json.JSONDecodeError as e: - raise Exception(f"Invalid tool input JSON: {e}") + if isinstance(arguments, dict): + return arguments + except (JSONDecodeError, TypeError): + pass # Continue to the next parsing attempt - return arguments + # Attempt 2: Parse as Python literal + try: + arguments = ast.literal_eval(tool_input) + if isinstance(arguments, dict): + return arguments + except (ValueError, SyntaxError): + pass # Continue to the next parsing attempt + + # Attempt 3: Parse as JSON5 + try: + arguments = json5.loads(tool_input) + if isinstance(arguments, dict): + return arguments + except (JSONDecodeError, ValueError, TypeError): + pass # Continue to the next parsing attempt + + # Attempt 4: Repair JSON + try: + repaired_input = repair_json(tool_input) + self._printer.print( + content=f"Repaired JSON: {repaired_input}", color="blue" + ) + arguments = json.loads(repaired_input) + if isinstance(arguments, dict): + return arguments + except Exception as e: + self._printer.print(content=f"Failed to repair JSON: {e}", color="red") + + # If all parsing attempts fail, raise an error + raise Exception( + "Tool input must be a valid dictionary in JSON or Python literal format" + ) def on_tool_error(self, tool: Any, tool_calling: ToolCalling, e: Exception) -> None: event_data = self._prepare_event_data(tool, tool_calling) diff --git a/tests/tools/test_tool_usage.py b/tests/tools/test_tool_usage.py index 952011339..7b2ccd416 100644 --- a/tests/tools/test_tool_usage.py +++ b/tests/tools/test_tool_usage.py @@ -231,3 +231,255 @@ def test_validate_tool_input_with_special_characters(): arguments = tool_usage._validate_tool_input(tool_input) assert arguments == expected_arguments + + +def test_validate_tool_input_none_input(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + arguments = tool_usage._validate_tool_input(None) + assert arguments == {} + + +def test_validate_tool_input_valid_json(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = '{"key": "value", "number": 42, "flag": true}' + expected_arguments = {"key": "value", "number": 42, "flag": True} + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_python_dict(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = "{'key': 'value', 'number': 42, 'flag': True}" + expected_arguments = {"key": "value", "number": 42, "flag": True} + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_json5_unquoted_keys(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = "{key: 'value', number: 42, flag: true}" + expected_arguments = {"key": "value", "number": 42, "flag": True} + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_with_trailing_commas(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = '{"key": "value", "number": 42, "flag": true,}' + expected_arguments = {"key": "value", "number": 42, "flag": True} + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_invalid_input(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + invalid_inputs = [ + "Just a string", + "['list', 'of', 'values']", + "12345", + "", + ] + + for invalid_input in invalid_inputs: + with pytest.raises(Exception) as e_info: + tool_usage._validate_tool_input(invalid_input) + assert ( + "Tool input must be a valid dictionary in JSON or Python literal format" + in str(e_info.value) + ) + + # Test for None input separately + arguments = tool_usage._validate_tool_input(None) + assert arguments == {} # Expecting an empty dictionary + + +def test_validate_tool_input_complex_structure(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = """ + { + "user": { + "name": "Alice", + "age": 30 + }, + "items": [ + {"id": 1, "value": "Item1"}, + {"id": 2, "value": "Item2",} + ], + "active": true, + } + """ + expected_arguments = { + "user": {"name": "Alice", "age": 30}, + "items": [ + {"id": 1, "value": "Item1"}, + {"id": 2, "value": "Item2"}, + ], + "active": True, + } + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_code_content(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = '{"filename": "script.py", "content": "def hello():\\n print(\'Hello, world!\')"}' + expected_arguments = { + "filename": "script.py", + "content": "def hello():\n print('Hello, world!')", + } + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_with_escaped_quotes(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = '{"text": "He said, \\"Hello, world!\\""}' + expected_arguments = {"text": 'He said, "Hello, world!"'} + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_large_json_content(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + # Simulate a large JSON content + tool_input = ( + '{"data": ' + json.dumps([{"id": i, "value": i * 2} for i in range(1000)]) + "}" + ) + expected_arguments = {"data": [{"id": i, "value": i * 2} for i in range(1000)]} + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_none_input(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + arguments = tool_usage._validate_tool_input(None) + assert arguments == {} # Expecting an empty dictionary diff --git a/uv.lock b/uv.lock index f38c1d582..b925fd95f 100644 --- a/uv.lock +++ b/uv.lock @@ -659,6 +659,7 @@ dependencies = [ { name = "click" }, { name = "instructor" }, { name = "json-repair" }, + { name = "json5" }, { name = "jsonref" }, { name = "litellm" }, { name = "openai" }, @@ -737,6 +738,7 @@ requires-dist = [ { name = "fastembed", marker = "extra == 'fastembed'", specifier = ">=0.4.1" }, { name = "instructor", specifier = ">=1.3.3" }, { name = "json-repair", specifier = ">=0.25.2" }, + { name = "json5", specifier = ">=0.10.0" }, { name = "jsonref", specifier = ">=1.1.0" }, { name = "litellm", specifier = "==1.57.4" }, { name = "mem0ai", marker = "extra == 'mem0'", specifier = ">=0.1.29" }, @@ -2077,6 +2079,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/23/38/34cb843cee4c5c27aa5c822e90e99bf96feb3dfa705713b5b6e601d17f5c/json_repair-0.30.0-py3-none-any.whl", hash = "sha256:bda4a5552dc12085c6363ff5acfcdb0c9cafc629989a2112081b7e205828228d", size = 17641 }, ] +[[package]] +name = "json5" +version = "0.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/85/3d/bbe62f3d0c05a689c711cff57b2e3ac3d3e526380adb7c781989f075115c/json5-0.10.0.tar.gz", hash = "sha256:e66941c8f0a02026943c52c2eb34ebeb2a6f819a0be05920a6f5243cd30fd559", size = 48202 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/42/797895b952b682c3dafe23b1834507ee7f02f4d6299b65aaa61425763278/json5-0.10.0-py3-none-any.whl", hash = "sha256:19b23410220a7271e8377f81ba8aacba2fdd56947fbb137ee5977cbe1f5e8dfa", size = 34049 }, +] + [[package]] name = "jsonlines" version = "3.1.0"