diff --git a/lib/crewai/src/crewai/agents/crew_agent_executor.py b/lib/crewai/src/crewai/agents/crew_agent_executor.py index df51807f7..8da511864 100644 --- a/lib/crewai/src/crewai/agents/crew_agent_executor.py +++ b/lib/crewai/src/crewai/agents/crew_agent_executor.py @@ -894,11 +894,16 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): ToolUsageStartedEvent, ) + json_parse_error: str | None = None if isinstance(func_args, str): try: args_dict = json.loads(func_args) - except json.JSONDecodeError: + except json.JSONDecodeError as e: args_dict = {} + json_parse_error = ( + f"Error: Failed to parse tool arguments as JSON: {e}. " + f"Please provide valid JSON arguments for the '{func_name}' tool." + ) else: args_dict = func_args @@ -980,7 +985,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): color="red", ) - if hook_blocked: + if json_parse_error: + result = json_parse_error + elif hook_blocked: result = f"Tool execution blocked by hook. Tool: {func_name}" elif max_usage_reached and original_tool: result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore." diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py index 8a10cdfa3..e787c25c2 100644 --- a/lib/crewai/src/crewai/tools/base_tool.py +++ b/lib/crewai/src/crewai/tools/base_tool.py @@ -155,9 +155,17 @@ class BaseTool(BaseModel, ABC): *args: Any, **kwargs: Any, ) -> Any: + if kwargs and self.args_schema is not None and self.args_schema.model_fields: + try: + validated = self.args_schema.model_validate(kwargs) + kwargs = validated.model_dump() + except Exception as e: + raise ValueError( + f"Tool '{self.name}' arguments validation failed: {e}" + ) from e + result = self._run(*args, **kwargs) - # If _run is async, we safely run it if asyncio.iscoroutine(result): result = asyncio.run(result) @@ -331,6 +339,15 @@ class Tool(BaseTool, Generic[P, R]): Returns: The result of the tool execution. """ + if kwargs and self.args_schema is not None and self.args_schema.model_fields: + try: + validated = self.args_schema.model_validate(kwargs) + kwargs = validated.model_dump() + except Exception as e: + raise ValueError( + f"Tool '{self.name}' arguments validation failed: {e}" + ) from e + result = self.func(*args, **kwargs) if asyncio.iscoroutine(result): diff --git a/lib/crewai/tests/agents/test_native_tool_calling.py b/lib/crewai/tests/agents/test_native_tool_calling.py index 26b0a8e4a..558c34bb1 100644 --- a/lib/crewai/tests/agents/test_native_tool_calling.py +++ b/lib/crewai/tests/agents/test_native_tool_calling.py @@ -11,7 +11,7 @@ import os import threading import time from collections import Counter -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest from pydantic import BaseModel, Field @@ -1129,3 +1129,150 @@ class TestMaxUsageCountWithNativeToolCalling: # Verify the requested calls occurred while keeping usage bounded. assert tool.current_usage_count >= 2 assert tool.current_usage_count <= tool.max_usage_count + + +# ============================================================================= +# JSON Parse Error Handling Tests +# ============================================================================= + + +class TestNativeToolCallingJsonParseError: + """Tests that malformed JSON tool arguments produce clear errors + instead of silently dropping all arguments.""" + + def _make_executor(self, tools: list[BaseTool]) -> "CrewAgentExecutor": + """Create a minimal CrewAgentExecutor with mocked dependencies.""" + from crewai.agents.crew_agent_executor import CrewAgentExecutor + from crewai.tools.base_tool import to_langchain + + structured_tools = to_langchain(tools) + mock_agent = Mock() + mock_agent.key = "test_agent" + mock_agent.role = "tester" + mock_agent.verbose = False + mock_agent.fingerprint = None + mock_agent.tools_results = [] + + mock_task = Mock() + mock_task.name = "test" + mock_task.description = "test" + mock_task.id = "test-id" + + executor = object.__new__(CrewAgentExecutor) + executor.agent = mock_agent + executor.task = mock_task + executor.crew = Mock() + executor.tools = structured_tools + executor.original_tools = tools + executor.tools_handler = None + executor._printer = Mock() + executor.messages = [] + + return executor + + def test_malformed_json_returns_parse_error(self) -> None: + """Malformed JSON args must return a descriptive error, not silently become {}.""" + + class CodeTool(BaseTool): + name: str = "execute_code" + description: str = "Run code" + + def _run(self, code: str) -> str: + return f"ran: {code}" + + tool = CodeTool() + executor = self._make_executor([tool]) + + from crewai.utilities.agent_utils import convert_tools_to_openai_schema + _, available_functions = convert_tools_to_openai_schema([tool]) + + malformed_json = '{"code": "print("hello")"}' + + result = executor._execute_single_native_tool_call( + call_id="call_123", + func_name="execute_code", + func_args=malformed_json, + available_functions=available_functions, + ) + + assert "Failed to parse tool arguments as JSON" in result["result"] + assert tool.current_usage_count == 0 + + def test_valid_json_still_executes_normally(self) -> None: + """Valid JSON args should execute the tool as before.""" + + class CodeTool(BaseTool): + name: str = "execute_code" + description: str = "Run code" + + def _run(self, code: str) -> str: + return f"ran: {code}" + + tool = CodeTool() + executor = self._make_executor([tool]) + + from crewai.utilities.agent_utils import convert_tools_to_openai_schema + _, available_functions = convert_tools_to_openai_schema([tool]) + + valid_json = '{"code": "print(1)"}' + + result = executor._execute_single_native_tool_call( + call_id="call_456", + func_name="execute_code", + func_args=valid_json, + available_functions=available_functions, + ) + + assert result["result"] == "ran: print(1)" + + def test_dict_args_bypass_json_parsing(self) -> None: + """When func_args is already a dict, no JSON parsing occurs.""" + + class CodeTool(BaseTool): + name: str = "execute_code" + description: str = "Run code" + + def _run(self, code: str) -> str: + return f"ran: {code}" + + tool = CodeTool() + executor = self._make_executor([tool]) + + from crewai.utilities.agent_utils import convert_tools_to_openai_schema + _, available_functions = convert_tools_to_openai_schema([tool]) + + result = executor._execute_single_native_tool_call( + call_id="call_789", + func_name="execute_code", + func_args={"code": "x = 42"}, + available_functions=available_functions, + ) + + assert result["result"] == "ran: x = 42" + + def test_schema_validation_catches_missing_args_on_native_path(self) -> None: + """The native function calling path should now enforce args_schema, + catching missing required fields before _run is called.""" + + class StrictTool(BaseTool): + name: str = "strict_tool" + description: str = "A tool with required args" + + def _run(self, code: str, language: str) -> str: + return f"{language}: {code}" + + tool = StrictTool() + executor = self._make_executor([tool]) + + from crewai.utilities.agent_utils import convert_tools_to_openai_schema + _, available_functions = convert_tools_to_openai_schema([tool]) + + result = executor._execute_single_native_tool_call( + call_id="call_schema", + func_name="strict_tool", + func_args={"code": "print(1)"}, + available_functions=available_functions, + ) + + assert "Error" in result["result"] + assert "validation failed" in result["result"].lower() or "missing" in result["result"].lower() diff --git a/lib/crewai/tests/tools/test_base_tool.py b/lib/crewai/tests/tools/test_base_tool.py index 4a6850ce1..a61851142 100644 --- a/lib/crewai/tests/tools/test_base_tool.py +++ b/lib/crewai/tests/tools/test_base_tool.py @@ -3,6 +3,8 @@ from typing import Callable from unittest.mock import patch import pytest +from pydantic import BaseModel, Field + from crewai.agent import Agent from crewai.crew import Crew from crewai.task import Task @@ -230,3 +232,120 @@ def test_max_usage_count_is_respected(): crew.kickoff() assert tool.max_usage_count == 5 assert tool.current_usage_count == 5 + + +# ============================================================================= +# Schema Validation in run() Tests +# ============================================================================= + + +class CodeExecutorInput(BaseModel): + code: str = Field(description="The code to execute") + language: str = Field(default="python", description="Programming language") + + +class CodeExecutorTool(BaseTool): + name: str = "code_executor" + description: str = "Execute code snippets" + args_schema: type[BaseModel] = CodeExecutorInput + + def _run(self, code: str, language: str = "python") -> str: + return f"Executed {language}: {code}" + + +class TestBaseToolRunValidation: + """Tests for args_schema validation in BaseTool.run().""" + + def test_run_with_valid_kwargs_passes_validation(self) -> None: + """Valid keyword arguments should pass schema validation and execute.""" + t = CodeExecutorTool() + result = t.run(code="print('hello')") + assert result == "Executed python: print('hello')" + + def test_run_with_all_kwargs_passes_validation(self) -> None: + """All keyword arguments including optional ones should pass.""" + t = CodeExecutorTool() + result = t.run(code="console.log('hi')", language="javascript") + assert result == "Executed javascript: console.log('hi')" + + def test_run_with_missing_required_kwarg_raises(self) -> None: + """Missing required kwargs should raise ValueError from schema validation.""" + t = CodeExecutorTool() + with pytest.raises(ValueError, match="validation failed"): + t.run(language="python") + + def test_run_with_wrong_field_name_raises(self) -> None: + """Kwargs not matching any schema field should trigger validation error + for missing required fields.""" + t = CodeExecutorTool() + with pytest.raises(ValueError, match="validation failed"): + t.run(wrong_arg="value") + + def test_run_with_positional_args_skips_validation(self) -> None: + """Positional-arg calls should bypass schema validation (backwards compat).""" + class SimpleTool(BaseTool): + name: str = "simple" + description: str = "A simple tool" + + def _run(self, question: str) -> str: + return question + + t = SimpleTool() + result = t.run("What is life?") + assert result == "What is life?" + + def test_run_strips_extra_kwargs_from_llm(self) -> None: + """Extra kwargs not in the schema should be silently stripped, + preventing unexpected-keyword crashes in _run.""" + t = CodeExecutorTool() + result = t.run(code="1+1", extra_hallucinated_field="junk") + assert result == "Executed python: 1+1" + + def test_run_increments_usage_after_validation(self) -> None: + """Usage count should still increment after validated execution.""" + t = CodeExecutorTool() + assert t.current_usage_count == 0 + t.run(code="x = 1") + assert t.current_usage_count == 1 + + def test_run_does_not_increment_usage_on_validation_error(self) -> None: + """Usage count should NOT increment when validation fails.""" + t = CodeExecutorTool() + assert t.current_usage_count == 0 + with pytest.raises(ValueError): + t.run(wrong="bad") + assert t.current_usage_count == 0 + + +class TestToolDecoratorRunValidation: + """Tests for args_schema validation in Tool.run() (decorator-based tools).""" + + def test_decorator_tool_run_validates_kwargs(self) -> None: + """Decorator-created tools should also validate kwargs against schema.""" + @tool("execute_code") + def execute_code(code: str, language: str = "python") -> str: + """Execute a code snippet.""" + return f"Executed {language}: {code}" + + result = execute_code.run(code="x = 1") + assert result == "Executed python: x = 1" + + def test_decorator_tool_run_rejects_missing_required(self) -> None: + """Decorator tools should reject missing required args via validation.""" + @tool("execute_code") + def execute_code(code: str) -> str: + """Execute a code snippet.""" + return f"Executed: {code}" + + with pytest.raises(ValueError, match="validation failed"): + execute_code.run(wrong_arg="value") + + def test_decorator_tool_positional_args_still_work(self) -> None: + """Positional args to decorator tools should bypass validation.""" + @tool("greet") + def greet(name: str) -> str: + """Greet someone.""" + return f"Hello, {name}!" + + result = greet.run("World") + assert result == "Hello, World!"