diff --git a/pyproject.toml b/pyproject.toml index 8dbe56fd8..4a9343697 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "tomli-w>=1.1.0", "tomli>=2.0.2", "blinker>=1.9.0", + "json5>=0.10.0", ] [project.urls] diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 82e8d084b..dec0effd7 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -1,15 +1,12 @@ -import os import shutil import subprocess from typing import Any, Dict, List, Literal, Optional, Union -from litellm import AuthenticationError as LiteLLMAuthenticationError from pydantic import Field, InstanceOf, PrivateAttr, model_validator from crewai.agents import CacheHandler from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.crew_agent_executor import CrewAgentExecutor -from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context @@ -257,8 +254,8 @@ class Agent(BaseAgent): } )["output"] except Exception as e: - if isinstance(e, LiteLLMAuthenticationError): - # Do not retry on authentication errors + if e.__class__.__module__.startswith("litellm"): + # Do not retry on litellm errors raise e self._times_executed += 1 if self._times_executed > self.max_retry_limit: diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index d7bf97795..b9797193c 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -3,8 +3,6 @@ import re from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union -from litellm.exceptions import AuthenticationError as LiteLLMAuthenticationError - from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin from crewai.agents.parser import ( @@ -103,7 +101,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): try: formatted_answer = self._invoke_loop() except Exception as e: - raise e + if e.__class__.__module__.startswith("litellm"): + # Do not retry on litellm errors + raise e + else: + self._handle_unknown_error(e) + raise e if self.ask_for_human_input: formatted_answer = self._handle_human_feedback(formatted_answer) @@ -146,6 +149,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): formatted_answer = self._handle_output_parser_exception(e) except Exception as e: + if e.__class__.__module__.startswith("litellm"): + # Do not retry on litellm errors + raise e if self._is_context_length_exceeded(e): self._handle_context_length() continue diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 334759a6d..761cc52ad 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -350,7 +350,10 @@ def chat(): Start a conversation with the Crew, collecting user-supplied inputs, and using the Chat LLM to generate responses. 
""" - click.echo("Starting a conversation with the Crew") + click.secho( + "\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n", + ) + run_chat() diff --git a/src/crewai/cli/crew_chat.py b/src/crewai/cli/crew_chat.py index f1695f0a4..cd0da2bb8 100644 --- a/src/crewai/cli/crew_chat.py +++ b/src/crewai/cli/crew_chat.py @@ -1,6 +1,9 @@ import json +import platform import re import sys +import threading +import time from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple @@ -18,27 +21,29 @@ from crewai.utilities.llm_utils import create_llm MIN_REQUIRED_VERSION = "0.98.0" -def check_conversational_crews_version(crewai_version: str, pyproject_data: dict) -> bool: +def check_conversational_crews_version( + crewai_version: str, pyproject_data: dict +) -> bool: """ Check if the installed crewAI version supports conversational crews. - + Args: - crewai_version: The current version of crewAI - pyproject_data: Dictionary containing pyproject.toml data - + crewai_version: The current version of crewAI. + pyproject_data: Dictionary containing pyproject.toml data. + Returns: - bool: True if version check passes, False otherwise + bool: True if version check passes, False otherwise. """ try: if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION): click.secho( "You are using an older version of crewAI that doesn't support conversational crews. " "Run 'uv upgrade crewai' to get the latest version.", - fg="red" + fg="red", ) return False except version.InvalidVersion: - click.secho("Invalid crewAI version format detected", fg="red") + click.secho("Invalid crewAI version format detected.", fg="red") return False return True @@ -54,20 +59,42 @@ def run_chat(): if not check_conversational_crews_version(crewai_version, pyproject_data): return + crew, crew_name = load_crew_and_name() chat_llm = initialize_chat_llm(crew) if not chat_llm: return - crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm) - crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs) - system_message = build_system_message(crew_chat_inputs) - - # Call the LLM to generate the introductory message - introductory_message = chat_llm.call( - messages=[{"role": "system", "content": system_message}] + # Indicate that the crew is being analyzed + click.secho( + "\nAnalyzing crew and required inputs - this may take 3 to 30 seconds " + "depending on the complexity of your crew.", + fg="white", ) - click.secho(f"\nAssistant: {introductory_message}\n", fg="green") + + # Start loading indicator + loading_complete = threading.Event() + loading_thread = threading.Thread(target=show_loading, args=(loading_complete,)) + loading_thread.start() + + try: + crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm) + crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs) + system_message = build_system_message(crew_chat_inputs) + + # Call the LLM to generate the introductory message + introductory_message = chat_llm.call( + messages=[{"role": "system", "content": system_message}] + ) + finally: + # Stop loading indicator + loading_complete.set() + loading_thread.join() + + # Indicate that the analysis is complete + click.secho("\nFinished analyzing crew.\n", fg="white") + + click.secho(f"Assistant: {introductory_message}\n", fg="green") messages = [ {"role": "system", "content": system_message}, @@ -78,15 +105,17 @@ def run_chat(): crew_chat_inputs.crew_name: create_tool_function(crew, messages), } - click.secho( - "\nEntering an interactive chat 
loop with function-calling.\n" - "Type 'exit' or Ctrl+C to quit.\n", - fg="cyan", - ) - chat_loop(chat_llm, messages, crew_tool_schema, available_functions) +def show_loading(event: threading.Event): + """Display animated loading dots while processing.""" + while not event.is_set(): + print(".", end="", flush=True) + time.sleep(1) + print() + + def initialize_chat_llm(crew: Crew) -> Optional[LLM]: """Initializes the chat LLM and handles exceptions.""" try: @@ -120,7 +149,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str: "Please keep your responses concise and friendly. " "If a user asks a question outside the crew's scope, provide a brief answer and remind them of the crew's purpose. " "After calling the tool, be prepared to take user feedback and make adjustments as needed. " - "If you are ever unsure about a user's request or need clarification, ask the user for more information." + "If you are ever unsure about a user's request or need clarification, ask the user for more information. " "Before doing anything else, introduce yourself with a friendly message like: 'Hey! I'm here to help you with [crew's purpose]. Could you please provide me with [inputs] so we can get started?' " "For example: 'Hey! I'm here to help you with uncovering and reporting cutting-edge developments through thorough research and detailed analysis. Could you please provide me with a topic you're interested in? This will help us generate a comprehensive research report and detailed analysis.'" f"\nCrew Name: {crew_chat_inputs.crew_name}" @@ -137,25 +166,33 @@ def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any: return run_crew_tool_with_messages +def flush_input(): + """Flush any pending input from the user.""" + if platform.system() == "Windows": + # Windows platform + import msvcrt + + while msvcrt.kbhit(): + msvcrt.getch() + else: + # Unix-like platforms (Linux, macOS) + import termios + + termios.tcflush(sys.stdin, termios.TCIFLUSH) + + def chat_loop(chat_llm, messages, crew_tool_schema, available_functions): """Main chat loop for interacting with the user.""" while True: try: - user_input = click.prompt("You", type=str) - if user_input.strip().lower() in ["exit", "quit"]: - click.echo("Exiting chat. Goodbye!") - break + # Flush any pending input before accepting new input + flush_input() - messages.append({"role": "user", "content": user_input}) - final_response = chat_llm.call( - messages=messages, - tools=[crew_tool_schema], - available_functions=available_functions, + user_input = get_user_input() + handle_user_input( + user_input, chat_llm, messages, crew_tool_schema, available_functions ) - messages.append({"role": "assistant", "content": final_response}) - click.secho(f"\nAssistant: {final_response}\n", fg="green") - except KeyboardInterrupt: click.echo("\nExiting chat. Goodbye!") break @@ -164,6 +201,55 @@ def chat_loop(chat_llm, messages, crew_tool_schema, available_functions): break +def get_user_input() -> str: + """Collect multi-line user input with exit handling.""" + click.secho( + "\nYou (type your message below. 
Press 'Enter' twice when you're done):", + fg="blue", + ) + user_input_lines = [] + while True: + line = input() + if line.strip().lower() == "exit": + return "exit" + if line == "": + break + user_input_lines.append(line) + return "\n".join(user_input_lines) + + +def handle_user_input( + user_input: str, + chat_llm: LLM, + messages: List[Dict[str, str]], + crew_tool_schema: Dict[str, Any], + available_functions: Dict[str, Any], +) -> None: + if user_input.strip().lower() == "exit": + click.echo("Exiting chat. Goodbye!") + return + + if not user_input.strip(): + click.echo("Empty message. Please provide input or type 'exit' to quit.") + return + + messages.append({"role": "user", "content": user_input}) + + # Indicate that assistant is processing + click.echo() + click.secho("Assistant is processing your input. Please wait...", fg="green") + + # Process assistant's response + final_response = chat_llm.call( + messages=messages, + tools=[crew_tool_schema], + available_functions=available_functions, + ) + + messages.append({"role": "assistant", "content": final_response}) + click.secho(f"\nAssistant: {final_response}\n", fg="green") + + def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict: """ Dynamically build a Littellm 'function' schema for the given crew. @@ -358,10 +444,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> ): # Replace placeholders with input names task_description = placeholder_pattern.sub( - lambda m: m.group(1), task.description + lambda m: m.group(1), task.description or "" ) expected_output = placeholder_pattern.sub( - lambda m: m.group(1), task.expected_output + lambda m: m.group(1), task.expected_output or "" ) context_texts.append(f"Task Description: {task_description}") context_texts.append(f"Expected Output: {expected_output}") @@ -372,10 +458,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> or f"{{{input_name}}}" in agent.backstory ): # Replace placeholders with input names - agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role) - agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal) + agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "") + agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "") agent_backstory = placeholder_pattern.sub( - lambda m: m.group(1), agent.backstory + lambda m: m.group(1), agent.backstory or "" ) context_texts.append(f"Agent Role: {agent_role}") context_texts.append(f"Agent Goal: {agent_goal}") @@ -416,18 +502,20 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: for task in crew.tasks: # Replace placeholders with input names task_description = placeholder_pattern.sub( - lambda m: m.group(1), task.description + lambda m: m.group(1), task.description or "" ) expected_output = placeholder_pattern.sub( - lambda m: m.group(1), task.expected_output + lambda m: m.group(1), task.expected_output or "" ) context_texts.append(f"Task Description: {task_description}") context_texts.append(f"Expected Output: {expected_output}") for agent in crew.agents: # Replace placeholders with input names - agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role) - agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal) - agent_backstory = placeholder_pattern.sub(lambda m: m.group(1), agent.backstory) + agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "") + agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or 
"") + agent_backstory = placeholder_pattern.sub( + lambda m: m.group(1), agent.backstory or "" + ) context_texts.append(f"Agent Role: {agent_role}") context_texts.append(f"Agent Goal: {agent_goal}") context_texts.append(f"Agent Backstory: {agent_backstory}") diff --git a/src/crewai/cli/templates/crew/.gitignore b/src/crewai/cli/templates/crew/.gitignore index d50a09fc9..7279347af 100644 --- a/src/crewai/cli/templates/crew/.gitignore +++ b/src/crewai/cli/templates/crew/.gitignore @@ -1,2 +1,3 @@ .env __pycache__/ +.DS_Store diff --git a/src/crewai/cli/templates/flow/.gitignore b/src/crewai/cli/templates/flow/.gitignore index 02dc677b9..3b6f1bec0 100644 --- a/src/crewai/cli/templates/flow/.gitignore +++ b/src/crewai/cli/templates/flow/.gitignore @@ -1,3 +1,4 @@ .env __pycache__/ lib/ +.DS_Store diff --git a/src/crewai/tools/tool_usage.py b/src/crewai/tools/tool_usage.py index a59ed7b50..218410ef7 100644 --- a/src/crewai/tools/tool_usage.py +++ b/src/crewai/tools/tool_usage.py @@ -1,12 +1,13 @@ import ast import datetime import json -import re import time from difflib import SequenceMatcher +from json import JSONDecodeError from textwrap import dedent -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union +import json5 from json_repair import repair_json import crewai.utilities.events as events @@ -407,28 +408,55 @@ class ToolUsage: ) return self._tool_calling(tool_string) - def _validate_tool_input(self, tool_input: str) -> Dict[str, Any]: + def _validate_tool_input(self, tool_input: Optional[str]) -> Dict[str, Any]: + if tool_input is None: + return {} + + if not isinstance(tool_input, str) or not tool_input.strip(): + raise Exception( + "Tool input must be a valid dictionary in JSON or Python literal format" + ) + + # Attempt 1: Parse as JSON try: - # Replace Python literals with JSON equivalents - replacements = { - r"'": '"', - r"None": "null", - r"True": "true", - r"False": "false", - } - for pattern, replacement in replacements.items(): - tool_input = re.sub(pattern, replacement, tool_input) - arguments = json.loads(tool_input) - except json.JSONDecodeError: - # Attempt to repair JSON string - repaired_input = repair_json(tool_input) - try: - arguments = json.loads(repaired_input) - except json.JSONDecodeError as e: - raise Exception(f"Invalid tool input JSON: {e}") + if isinstance(arguments, dict): + return arguments + except (JSONDecodeError, TypeError): + pass # Continue to the next parsing attempt - return arguments + # Attempt 2: Parse as Python literal + try: + arguments = ast.literal_eval(tool_input) + if isinstance(arguments, dict): + return arguments + except (ValueError, SyntaxError): + pass # Continue to the next parsing attempt + + # Attempt 3: Parse as JSON5 + try: + arguments = json5.loads(tool_input) + if isinstance(arguments, dict): + return arguments + except (JSONDecodeError, ValueError, TypeError): + pass # Continue to the next parsing attempt + + # Attempt 4: Repair JSON + try: + repaired_input = repair_json(tool_input) + self._printer.print( + content=f"Repaired JSON: {repaired_input}", color="blue" + ) + arguments = json.loads(repaired_input) + if isinstance(arguments, dict): + return arguments + except Exception as e: + self._printer.print(content=f"Failed to repair JSON: {e}", color="red") + + # If all parsing attempts fail, raise an error + raise Exception( + "Tool input must be a valid dictionary in JSON or Python literal format" + ) def on_tool_error(self, tool: Any, tool_calling: ToolCalling, e: Exception) -> None: 
event_data = self._prepare_event_data(tool, tool_calling) diff --git a/tests/agent_test.py b/tests/agent_test.py index fa519e7ac..b0efef82b 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -1663,7 +1663,7 @@ def test_litellm_auth_error_handling(): agent=agent, ) - # Mock the LLM call to raise LiteLLMAuthenticationError + # Mock the LLM call to raise AuthenticationError with ( patch.object(LLM, "call") as mock_llm_call, pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"), @@ -1679,7 +1679,7 @@ def test_litellm_auth_error_handling(): def test_crew_agent_executor_litellm_auth_error(): """Test that CrewAgentExecutor handles LiteLLM authentication errors by raising them.""" - from litellm import AuthenticationError as LiteLLMAuthenticationError + from litellm.exceptions import AuthenticationError from crewai.agents.tools_handler import ToolsHandler from crewai.utilities import Printer @@ -1712,13 +1712,13 @@ def test_crew_agent_executor_litellm_auth_error(): tools_handler=ToolsHandler(), ) - # Mock the LLM call to raise LiteLLMAuthenticationError + # Mock the LLM call to raise AuthenticationError with ( patch.object(LLM, "call") as mock_llm_call, patch.object(Printer, "print") as mock_printer, - pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"), + pytest.raises(AuthenticationError) as exc_info, ): - mock_llm_call.side_effect = LiteLLMAuthenticationError( + mock_llm_call.side_effect = AuthenticationError( message="Invalid API key", llm_provider="openai", model="gpt-4" ) executor.invoke( @@ -1729,14 +1729,53 @@ def test_crew_agent_executor_litellm_auth_error(): } ) - # Verify error handling + # Verify error handling messages + error_message = f"Error during LLM call: {str(mock_llm_call.side_effect)}" mock_printer.assert_any_call( - content="An unknown error occurred. 
Please check the details below.", - color="red", - ) - mock_printer.assert_any_call( - content="Error details: litellm.AuthenticationError: Invalid API key", + content=error_message, color="red", ) + # Verify the call was only made once (no retries) mock_llm_call.assert_called_once() + + # Assert that the exception was raised and has the expected attributes + assert exc_info.type is AuthenticationError + assert "Invalid API key".lower() in exc_info.value.message.lower() + assert exc_info.value.llm_provider == "openai" + assert exc_info.value.model == "gpt-4" + + +def test_litellm_anthropic_error_handling(): + """Test that AnthropicError from LiteLLM is handled correctly and not retried.""" + from litellm.llms.anthropic.common_utils import AnthropicError + + # Create an agent with a mocked LLM that uses an Anthropic model + agent = Agent( + role="test role", + goal="test goal", + backstory="test backstory", + llm=LLM(model="claude-3.5-sonnet-20240620"), + max_retry_limit=0, + ) + + # Create a task + task = Task( + description="Test task", + expected_output="Test output", + agent=agent, + ) + + # Mock the LLM call to raise AnthropicError + with ( + patch.object(LLM, "call") as mock_llm_call, + pytest.raises(AnthropicError, match="Test Anthropic error"), + ): + mock_llm_call.side_effect = AnthropicError( + status_code=500, + message="Test Anthropic error", + ) + agent.execute_task(task) + + # Verify the LLM call was only made once (no retries) + mock_llm_call.assert_called_once() diff --git a/tests/tools/test_tool_usage.py b/tests/tools/test_tool_usage.py index 952011339..7b2ccd416 100644 --- a/tests/tools/test_tool_usage.py +++ b/tests/tools/test_tool_usage.py @@ -231,3 +231,255 @@ def test_validate_tool_input_with_special_characters(): arguments = tool_usage._validate_tool_input(tool_input) assert arguments == expected_arguments + + +def test_validate_tool_input_none_input(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + arguments = tool_usage._validate_tool_input(None) + assert arguments == {} + + +def test_validate_tool_input_valid_json(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = '{"key": "value", "number": 42, "flag": true}' + expected_arguments = {"key": "value", "number": 42, "flag": True} + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_python_dict(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = "{'key': 'value', 'number': 42, 'flag': True}" + expected_arguments = {"key": "value", "number": 42, "flag": True} + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_json5_unquoted_keys(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = "{key: 'value', number: 
42, flag: true}" + expected_arguments = {"key": "value", "number": 42, "flag": True} + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_with_trailing_commas(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = '{"key": "value", "number": 42, "flag": true,}' + expected_arguments = {"key": "value", "number": 42, "flag": True} + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_invalid_input(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + invalid_inputs = [ + "Just a string", + "['list', 'of', 'values']", + "12345", + "", + ] + + for invalid_input in invalid_inputs: + with pytest.raises(Exception) as e_info: + tool_usage._validate_tool_input(invalid_input) + assert ( + "Tool input must be a valid dictionary in JSON or Python literal format" + in str(e_info.value) + ) + + # Test for None input separately + arguments = tool_usage._validate_tool_input(None) + assert arguments == {} # Expecting an empty dictionary + + +def test_validate_tool_input_complex_structure(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = """ + { + "user": { + "name": "Alice", + "age": 30 + }, + "items": [ + {"id": 1, "value": "Item1"}, + {"id": 2, "value": "Item2",} + ], + "active": true, + } + """ + expected_arguments = { + "user": {"name": "Alice", "age": 30}, + "items": [ + {"id": 1, "value": "Item1"}, + {"id": 2, "value": "Item2"}, + ], + "active": True, + } + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_code_content(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = '{"filename": "script.py", "content": "def hello():\\n print(\'Hello, world!\')"}' + expected_arguments = { + "filename": "script.py", + "content": "def hello():\n print('Hello, world!')", + } + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_with_escaped_quotes(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + tool_input = '{"text": "He said, \\"Hello, world!\\""}' + expected_arguments = {"text": 'He said, "Hello, world!"'} + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_large_json_content(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + 
action=MagicMock(), + ) + + # Simulate a large JSON content + tool_input = ( + '{"data": ' + json.dumps([{"id": i, "value": i * 2} for i in range(1000)]) + "}" + ) + expected_arguments = {"data": [{"id": i, "value": i * 2} for i in range(1000)]} + + arguments = tool_usage._validate_tool_input(tool_input) + assert arguments == expected_arguments + + +def test_validate_tool_input_none_input(): + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[], + original_tools=[], + tools_description="", + tools_names="", + task=MagicMock(), + function_calling_llm=None, + agent=MagicMock(), + action=MagicMock(), + ) + + arguments = tool_usage._validate_tool_input(None) + assert arguments == {} # Expecting an empty dictionary diff --git a/uv.lock b/uv.lock index f38c1d582..b925fd95f 100644 --- a/uv.lock +++ b/uv.lock @@ -659,6 +659,7 @@ dependencies = [ { name = "click" }, { name = "instructor" }, { name = "json-repair" }, + { name = "json5" }, { name = "jsonref" }, { name = "litellm" }, { name = "openai" }, @@ -737,6 +738,7 @@ requires-dist = [ { name = "fastembed", marker = "extra == 'fastembed'", specifier = ">=0.4.1" }, { name = "instructor", specifier = ">=1.3.3" }, { name = "json-repair", specifier = ">=0.25.2" }, + { name = "json5", specifier = ">=0.10.0" }, { name = "jsonref", specifier = ">=1.1.0" }, { name = "litellm", specifier = "==1.57.4" }, { name = "mem0ai", marker = "extra == 'mem0'", specifier = ">=0.1.29" }, @@ -2077,6 +2079,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/23/38/34cb843cee4c5c27aa5c822e90e99bf96feb3dfa705713b5b6e601d17f5c/json_repair-0.30.0-py3-none-any.whl", hash = "sha256:bda4a5552dc12085c6363ff5acfcdb0c9cafc629989a2112081b7e205828228d", size = 17641 }, ] +[[package]] +name = "json5" +version = "0.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/85/3d/bbe62f3d0c05a689c711cff57b2e3ac3d3e526380adb7c781989f075115c/json5-0.10.0.tar.gz", hash = "sha256:e66941c8f0a02026943c52c2eb34ebeb2a6f819a0be05920a6f5243cd30fd559", size = 48202 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/42/797895b952b682c3dafe23b1834507ee7f02f4d6299b65aaa61425763278/json5-0.10.0-py3-none-any.whl", hash = "sha256:19b23410220a7271e8377f81ba8aacba2fdd56947fbb137ee5977cbe1f5e8dfa", size = 34049 }, +] + [[package]] name = "jsonlines" version = "3.1.0"
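
The retry short-circuit added in agent.py and crew_agent_executor.py keys off the exception class's module rather than importing specific LiteLLM exception types. A minimal sketch of that check, assuming only that litellm is installed; the helper name below is illustrative and not part of the diff, and the AuthenticationError arguments mirror the ones used in the tests above:

```python
from litellm.exceptions import AuthenticationError


def should_reraise_without_retry(exc: Exception) -> bool:
    # Mirrors the check in agent.py / crew_agent_executor.py: any exception whose
    # class is defined in a litellm module is surfaced immediately instead of retried.
    return exc.__class__.__module__.startswith("litellm")


try:
    raise AuthenticationError(
        message="Invalid API key", llm_provider="openai", model="gpt-4"
    )
except Exception as e:
    print(should_reraise_without_retry(e))                    # True  -> re-raise, no retry
    print(should_reraise_without_retry(ValueError("boom")))   # False -> normal retry path
```

Matching on the module prefix keeps the executor decoupled from LiteLLM's exception hierarchy, which is why the new AnthropicError test passes without any extra imports in production code.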
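run_chat() now wraps the crew analysis in a background thread that prints progress dots, coordinated with a plain threading.Event. A self-contained sketch of the same pattern; analyze_crew below is a stand-in for generate_crew_chat_inputs and the introductory LLM call, not a real function in the codebase:

```python
import threading
import time


def show_loading(event: threading.Event) -> None:
    # Print a dot every second until the main thread signals completion.
    while not event.is_set():
        print(".", end="", flush=True)
        time.sleep(1)
    print()


def analyze_crew() -> str:
    time.sleep(3)  # stands in for the slow, LLM-backed crew analysis
    return "crew analysis done"


loading_complete = threading.Event()
loading_thread = threading.Thread(target=show_loading, args=(loading_complete,))
loading_thread.start()
try:
    result = analyze_crew()
finally:
    # The finally block guarantees the indicator stops even if the analysis raises.
    loading_complete.set()
    loading_thread.join()
print(result)
```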
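_validate_tool_input now tries progressively looser parsers before giving up. A condensed, standalone sketch of that fallback chain, assuming json5 and json_repair are installed as declared above; parse_tool_input is an illustrative stand-in, while the real method lives on ToolUsage and reports repairs through its printer:

```python
import ast
import json
from json import JSONDecodeError

import json5
from json_repair import repair_json


def parse_tool_input(tool_input):
    """Best-effort conversion of an LLM-produced tool-call string into a dict."""
    if tool_input is None:
        return {}
    if not isinstance(tool_input, str) or not tool_input.strip():
        raise ValueError("tool input must be a dictionary-like string")

    # 1. Strict JSON
    try:
        parsed = json.loads(tool_input)
        if isinstance(parsed, dict):
            return parsed
    except (JSONDecodeError, TypeError):
        pass

    # 2. Python literal: single quotes, True/False/None
    try:
        parsed = ast.literal_eval(tool_input)
        if isinstance(parsed, dict):
            return parsed
    except (ValueError, SyntaxError):
        pass

    # 3. JSON5: unquoted keys, trailing commas, comments
    try:
        parsed = json5.loads(tool_input)
        if isinstance(parsed, dict):
            return parsed
    except ValueError:
        pass

    # 4. Last resort: repair the string, then parse as JSON again
    try:
        parsed = json.loads(repair_json(tool_input))
        if isinstance(parsed, dict):
            return parsed
    except Exception:
        pass

    raise ValueError("tool input must be a dictionary-like string")


print(parse_tool_input('{"key": "value", "flag": true}'))   # strict JSON
print(parse_tool_input("{'key': 'value', 'flag': True}"))   # Python literal
print(parse_tool_input("{key: 'value', number: 42,}"))      # JSON5-style
```

The JSON5 attempt is what lets the new tests with unquoted keys and trailing commas pass without first round-tripping through repair_json.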