diff --git a/lib/crewai/src/crewai/agents/crew_agent_executor.py b/lib/crewai/src/crewai/agents/crew_agent_executor.py
index df51807f7..56abaae02 100644
--- a/lib/crewai/src/crewai/agents/crew_agent_executor.py
+++ b/lib/crewai/src/crewai/agents/crew_agent_executor.py
@@ -50,6 +50,7 @@ from crewai.utilities.agent_utils import (
     handle_unknown_error,
     has_reached_max_iterations,
     is_context_length_exceeded,
+    parse_tool_call_args,
     process_llm_response,
     track_delegation_if_needed,
 )
@@ -894,13 +895,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
             ToolUsageStartedEvent,
         )

-        if isinstance(func_args, str):
-            try:
-                args_dict = json.loads(func_args)
-            except json.JSONDecodeError:
-                args_dict = {}
-        else:
-            args_dict = func_args
+        args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id, original_tool)
+        if parse_error is not None:
+            return parse_error

         if original_tool is None:
             for tool in self.original_tools or []:
diff --git a/lib/crewai/src/crewai/experimental/agent_executor.py b/lib/crewai/src/crewai/experimental/agent_executor.py
index 56c4da030..e568dc0d4 100644
--- a/lib/crewai/src/crewai/experimental/agent_executor.py
+++ b/lib/crewai/src/crewai/experimental/agent_executor.py
@@ -66,6 +66,7 @@ from crewai.utilities.agent_utils import (
     has_reached_max_iterations,
     is_context_length_exceeded,
     is_inside_event_loop,
+    parse_tool_call_args,
     process_llm_response,
     track_delegation_if_needed,
 )
@@ -848,13 +849,9 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
         call_id, func_name, func_args = info

         # Parse arguments
-        if isinstance(func_args, str):
-            try:
-                args_dict = json.loads(func_args)
-            except json.JSONDecodeError:
-                args_dict = {}
-        else:
-            args_dict = func_args
+        args_dict, parse_error = parse_tool_call_args(func_args, func_name, call_id)
+        if parse_error is not None:
+            return parse_error

         # Get agent_key for event tracking
         agent_key = getattr(self.agent, "key", "unknown") if self.agent else "unknown"
diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py
index 8a10cdfa3..88c0826a9 100644
--- a/lib/crewai/src/crewai/tools/base_tool.py
+++ b/lib/crewai/src/crewai/tools/base_tool.py
@@ -18,6 +18,7 @@ from pydantic import (
     BaseModel as PydanticBaseModel,
     ConfigDict,
     Field,
+    ValidationError,
     create_model,
     field_validator,
 )
@@ -150,14 +151,37 @@ class BaseTool(BaseModel, ABC):

         super().model_post_init(__context)

+    def _validate_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
+        """Validate keyword arguments against args_schema if present.
+
+        Args:
+            kwargs: The keyword arguments to validate.
+
+        Returns:
+            Validated (and possibly coerced) keyword arguments.
+
+        Raises:
+            ValueError: If validation against args_schema fails.
+        """
+        if kwargs and self.args_schema is not None and self.args_schema.model_fields:
+            try:
+                validated = self.args_schema.model_validate(kwargs)
+                return validated.model_dump()
+            except ValidationError as e:
+                raise ValueError(
+                    f"Tool '{self.name}' arguments validation failed: {e}"
+                ) from e
+        return kwargs
+
     def run(
         self,
         *args: Any,
         **kwargs: Any,
     ) -> Any:
+        kwargs = self._validate_kwargs(kwargs)
+
         result = self._run(*args, **kwargs)

-        # If _run is async, we safely run it
         if asyncio.iscoroutine(result):
             result = asyncio.run(result)

@@ -179,6 +203,7 @@
         Returns:
             The result of the tool execution.
""" + kwargs = self._validate_kwargs(kwargs) result = await self._arun(*args, **kwargs) self.current_usage_count += 1 return result @@ -331,6 +356,8 @@ class Tool(BaseTool, Generic[P, R]): Returns: The result of the tool execution. """ + kwargs = self._validate_kwargs(kwargs) + result = self.func(*args, **kwargs) if asyncio.iscoroutine(result): @@ -361,6 +388,7 @@ class Tool(BaseTool, Generic[P, R]): Returns: The result of the tool execution. """ + kwargs = self._validate_kwargs(kwargs) result = await self._arun(*args, **kwargs) self.current_usage_count += 1 return result diff --git a/lib/crewai/src/crewai/utilities/agent_utils.py b/lib/crewai/src/crewai/utilities/agent_utils.py index 80c80dbb6..7cad2ad67 100644 --- a/lib/crewai/src/crewai/utilities/agent_utils.py +++ b/lib/crewai/src/crewai/utilities/agent_utils.py @@ -1146,6 +1146,36 @@ def extract_tool_call_info( return None +def parse_tool_call_args( + func_args: dict[str, Any] | str, + func_name: str, + call_id: str, + original_tool: Any = None, +) -> tuple[dict[str, Any], None] | tuple[None, dict[str, Any]]: + """Parse tool call arguments from a JSON string or dict. + + Returns: + ``(args_dict, None)`` on success, or ``(None, error_result)`` on + JSON parse failure where ``error_result`` is a ready-to-return dict + with the same shape as ``_execute_single_native_tool_call`` return values. + """ + if isinstance(func_args, str): + try: + return json.loads(func_args), None + except json.JSONDecodeError as e: + return None, { + "call_id": call_id, + "func_name": func_name, + "result": ( + f"Error: Failed to parse tool arguments as JSON: {e}. " + f"Please provide valid JSON arguments for the '{func_name}' tool." + ), + "from_cache": False, + "original_tool": original_tool, + } + return func_args, None + + def _setup_before_llm_call_hooks( executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None, printer: Printer, diff --git a/lib/crewai/tests/agents/test_native_tool_calling.py b/lib/crewai/tests/agents/test_native_tool_calling.py index 26b0a8e4a..558c34bb1 100644 --- a/lib/crewai/tests/agents/test_native_tool_calling.py +++ b/lib/crewai/tests/agents/test_native_tool_calling.py @@ -11,7 +11,7 @@ import os import threading import time from collections import Counter -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest from pydantic import BaseModel, Field @@ -1129,3 +1129,150 @@ class TestMaxUsageCountWithNativeToolCalling: # Verify the requested calls occurred while keeping usage bounded. 
         assert tool.current_usage_count >= 2
         assert tool.current_usage_count <= tool.max_usage_count
+
+
+# =============================================================================
+# JSON Parse Error Handling Tests
+# =============================================================================
+
+
+class TestNativeToolCallingJsonParseError:
+    """Tests that malformed JSON tool arguments produce clear errors
+    instead of silently dropping all arguments."""
+
+    def _make_executor(self, tools: list[BaseTool]) -> "CrewAgentExecutor":
+        """Create a minimal CrewAgentExecutor with mocked dependencies."""
+        from crewai.agents.crew_agent_executor import CrewAgentExecutor
+        from crewai.tools.base_tool import to_langchain
+
+        structured_tools = to_langchain(tools)
+        mock_agent = Mock()
+        mock_agent.key = "test_agent"
+        mock_agent.role = "tester"
+        mock_agent.verbose = False
+        mock_agent.fingerprint = None
+        mock_agent.tools_results = []
+
+        mock_task = Mock()
+        mock_task.name = "test"
+        mock_task.description = "test"
+        mock_task.id = "test-id"
+
+        executor = object.__new__(CrewAgentExecutor)
+        executor.agent = mock_agent
+        executor.task = mock_task
+        executor.crew = Mock()
+        executor.tools = structured_tools
+        executor.original_tools = tools
+        executor.tools_handler = None
+        executor._printer = Mock()
+        executor.messages = []
+
+        return executor
+
+    def test_malformed_json_returns_parse_error(self) -> None:
+        """Malformed JSON args must return a descriptive error, not silently become {}."""
+
+        class CodeTool(BaseTool):
+            name: str = "execute_code"
+            description: str = "Run code"
+
+            def _run(self, code: str) -> str:
+                return f"ran: {code}"
+
+        tool = CodeTool()
+        executor = self._make_executor([tool])
+
+        from crewai.utilities.agent_utils import convert_tools_to_openai_schema
+        _, available_functions = convert_tools_to_openai_schema([tool])
+
+        malformed_json = '{"code": "print("hello")"}'
+
+        result = executor._execute_single_native_tool_call(
+            call_id="call_123",
+            func_name="execute_code",
+            func_args=malformed_json,
+            available_functions=available_functions,
+        )
+
+        assert "Failed to parse tool arguments as JSON" in result["result"]
+        assert tool.current_usage_count == 0
+
+    def test_valid_json_still_executes_normally(self) -> None:
+        """Valid JSON args should execute the tool as before."""
+
+        class CodeTool(BaseTool):
+            name: str = "execute_code"
+            description: str = "Run code"
+
+            def _run(self, code: str) -> str:
+                return f"ran: {code}"
+
+        tool = CodeTool()
+        executor = self._make_executor([tool])
+
+        from crewai.utilities.agent_utils import convert_tools_to_openai_schema
+        _, available_functions = convert_tools_to_openai_schema([tool])
+
+        valid_json = '{"code": "print(1)"}'
+
+        result = executor._execute_single_native_tool_call(
+            call_id="call_456",
+            func_name="execute_code",
+            func_args=valid_json,
+            available_functions=available_functions,
+        )
+
+        assert result["result"] == "ran: print(1)"
+
+    def test_dict_args_bypass_json_parsing(self) -> None:
+        """When func_args is already a dict, no JSON parsing occurs."""
+
+        class CodeTool(BaseTool):
+            name: str = "execute_code"
+            description: str = "Run code"
+
+            def _run(self, code: str) -> str:
+                return f"ran: {code}"
+
+        tool = CodeTool()
+        executor = self._make_executor([tool])
+
+        from crewai.utilities.agent_utils import convert_tools_to_openai_schema
+        _, available_functions = convert_tools_to_openai_schema([tool])
+
+        result = executor._execute_single_native_tool_call(
+            call_id="call_789",
+            func_name="execute_code",
+            func_args={"code": "x = 42"},
+            available_functions=available_functions,
+        )
+
+        assert result["result"] == "ran: x = 42"
+
+    def test_schema_validation_catches_missing_args_on_native_path(self) -> None:
+        """The native function calling path should now enforce args_schema,
+        catching missing required fields before _run is called."""
+
+        class StrictTool(BaseTool):
+            name: str = "strict_tool"
+            description: str = "A tool with required args"
+
+            def _run(self, code: str, language: str) -> str:
+                return f"{language}: {code}"
+
+        tool = StrictTool()
+        executor = self._make_executor([tool])
+
+        from crewai.utilities.agent_utils import convert_tools_to_openai_schema
+        _, available_functions = convert_tools_to_openai_schema([tool])
+
+        result = executor._execute_single_native_tool_call(
+            call_id="call_schema",
+            func_name="strict_tool",
+            func_args={"code": "print(1)"},
+            available_functions=available_functions,
+        )
+
+        assert "Error" in result["result"]
+        assert "validation failed" in result["result"].lower() or "missing" in result["result"].lower()
diff --git a/lib/crewai/tests/tools/test_base_tool.py b/lib/crewai/tests/tools/test_base_tool.py
index 4a6850ce1..a9d3a2b6d 100644
--- a/lib/crewai/tests/tools/test_base_tool.py
+++ b/lib/crewai/tests/tools/test_base_tool.py
@@ -3,6 +3,8 @@ from typing import Callable
 from unittest.mock import patch

 import pytest
+from pydantic import BaseModel, Field
+
 from crewai.agent import Agent
 from crewai.crew import Crew
 from crewai.task import Task
@@ -230,3 +232,204 @@ def test_max_usage_count_is_respected():
     crew.kickoff()
     assert tool.max_usage_count == 5
     assert tool.current_usage_count == 5
+
+
+# =============================================================================
+# Schema Validation in run() Tests
+# =============================================================================
+
+
+class CodeExecutorInput(BaseModel):
+    code: str = Field(description="The code to execute")
+    language: str = Field(default="python", description="Programming language")
+
+
+class CodeExecutorTool(BaseTool):
+    name: str = "code_executor"
+    description: str = "Execute code snippets"
+    args_schema: type[BaseModel] = CodeExecutorInput
+
+    def _run(self, code: str, language: str = "python") -> str:
+        return f"Executed {language}: {code}"
+
+
+class TestBaseToolRunValidation:
+    """Tests for args_schema validation in BaseTool.run()."""
+
+    def test_run_with_valid_kwargs_passes_validation(self) -> None:
+        """Valid keyword arguments should pass schema validation and execute."""
+        t = CodeExecutorTool()
+        result = t.run(code="print('hello')")
+        assert result == "Executed python: print('hello')"
+
+    def test_run_with_all_kwargs_passes_validation(self) -> None:
+        """All keyword arguments including optional ones should pass."""
+        t = CodeExecutorTool()
+        result = t.run(code="console.log('hi')", language="javascript")
+        assert result == "Executed javascript: console.log('hi')"
+
+    def test_run_with_missing_required_kwarg_raises(self) -> None:
+        """Missing required kwargs should raise ValueError from schema validation."""
+        t = CodeExecutorTool()
+        with pytest.raises(ValueError, match="validation failed"):
+            t.run(language="python")
+
+    def test_run_with_wrong_field_name_raises(self) -> None:
+        """Kwargs not matching any schema field should trigger validation error
+        for missing required fields."""
+        t = CodeExecutorTool()
+        with pytest.raises(ValueError, match="validation failed"):
+            t.run(wrong_arg="value")
+
+    def test_run_with_positional_args_skips_validation(self) -> None:
+ """Positional-arg calls should bypass schema validation (backwards compat).""" + class SimpleTool(BaseTool): + name: str = "simple" + description: str = "A simple tool" + + def _run(self, question: str) -> str: + return question + + t = SimpleTool() + result = t.run("What is life?") + assert result == "What is life?" + + def test_run_strips_extra_kwargs_from_llm(self) -> None: + """Extra kwargs not in the schema should be silently stripped, + preventing unexpected-keyword crashes in _run.""" + t = CodeExecutorTool() + result = t.run(code="1+1", extra_hallucinated_field="junk") + assert result == "Executed python: 1+1" + + def test_run_increments_usage_after_validation(self) -> None: + """Usage count should still increment after validated execution.""" + t = CodeExecutorTool() + assert t.current_usage_count == 0 + t.run(code="x = 1") + assert t.current_usage_count == 1 + + def test_run_does_not_increment_usage_on_validation_error(self) -> None: + """Usage count should NOT increment when validation fails.""" + t = CodeExecutorTool() + assert t.current_usage_count == 0 + with pytest.raises(ValueError): + t.run(wrong="bad") + assert t.current_usage_count == 0 + + +class TestToolDecoratorRunValidation: + """Tests for args_schema validation in Tool.run() (decorator-based tools).""" + + def test_decorator_tool_run_validates_kwargs(self) -> None: + """Decorator-created tools should also validate kwargs against schema.""" + @tool("execute_code") + def execute_code(code: str, language: str = "python") -> str: + """Execute a code snippet.""" + return f"Executed {language}: {code}" + + result = execute_code.run(code="x = 1") + assert result == "Executed python: x = 1" + + def test_decorator_tool_run_rejects_missing_required(self) -> None: + """Decorator tools should reject missing required args via validation.""" + @tool("execute_code") + def execute_code(code: str) -> str: + """Execute a code snippet.""" + return f"Executed: {code}" + + with pytest.raises(ValueError, match="validation failed"): + execute_code.run(wrong_arg="value") + + def test_decorator_tool_positional_args_still_work(self) -> None: + """Positional args to decorator tools should bypass validation.""" + @tool("greet") + def greet(name: str) -> str: + """Greet someone.""" + return f"Hello, {name}!" + + result = greet.run("World") + assert result == "Hello, World!" 
+
+
+# =============================================================================
+# Async arun() Schema Validation Tests
+# =============================================================================
+
+
+class AsyncCodeExecutorTool(BaseTool):
+    name: str = "async_code_executor"
+    description: str = "Execute code snippets asynchronously"
+    args_schema: type[BaseModel] = CodeExecutorInput
+
+    async def _arun(self, code: str, language: str = "python") -> str:
+        return f"Async executed {language}: {code}"
+
+    def _run(self, code: str, language: str = "python") -> str:
+        return f"Executed {language}: {code}"
+
+
+class TestBaseToolArunValidation:
+    """Tests for args_schema validation in BaseTool.arun()."""
+
+    @pytest.mark.asyncio
+    async def test_arun_with_valid_kwargs_passes_validation(self) -> None:
+        """Valid keyword arguments should pass schema validation in arun."""
+        t = AsyncCodeExecutorTool()
+        result = await t.arun(code="print('hello')")
+        assert result == "Async executed python: print('hello')"
+
+    @pytest.mark.asyncio
+    async def test_arun_with_missing_required_kwarg_raises(self) -> None:
+        """Missing required kwargs should raise ValueError in arun."""
+        t = AsyncCodeExecutorTool()
+        with pytest.raises(ValueError, match="validation failed"):
+            await t.arun(language="python")
+
+    @pytest.mark.asyncio
+    async def test_arun_with_wrong_field_name_raises(self) -> None:
+        """Kwargs not matching schema fields should trigger validation error in arun."""
+        t = AsyncCodeExecutorTool()
+        with pytest.raises(ValueError, match="validation failed"):
+            await t.arun(wrong_arg="value")
+
+    @pytest.mark.asyncio
+    async def test_arun_strips_extra_kwargs(self) -> None:
+        """Extra kwargs not in the schema should be stripped in arun."""
+        t = AsyncCodeExecutorTool()
+        result = await t.arun(code="1+1", extra_field="junk")
+        assert result == "Async executed python: 1+1"
+
+    @pytest.mark.asyncio
+    async def test_arun_does_not_increment_usage_on_validation_error(self) -> None:
+        """Usage count should NOT increment when arun validation fails."""
+        t = AsyncCodeExecutorTool()
+        assert t.current_usage_count == 0
+        with pytest.raises(ValueError):
+            await t.arun(wrong="bad")
+        assert t.current_usage_count == 0
+
+
+class TestToolDecoratorArunValidation:
+    """Tests for args_schema validation in Tool.arun() (decorator-based async tools)."""
+
+    @pytest.mark.asyncio
+    async def test_async_decorator_tool_arun_validates_kwargs(self) -> None:
+        """Async decorator tools should validate kwargs in arun."""
+
+        @tool("async_execute")
+        async def async_execute(code: str, language: str = "python") -> str:
+            """Execute code asynchronously."""
+            return f"Async {language}: {code}"
+
+        result = await async_execute.arun(code="x = 1")
+        assert result == "Async python: x = 1"
+
+    @pytest.mark.asyncio
+    async def test_async_decorator_tool_arun_rejects_missing_required(self) -> None:
+        """Async decorator tools should reject missing required args in arun."""
+
+        @tool("async_execute")
+        async def async_execute(code: str) -> str:
+            """Execute code asynchronously."""
+            return f"Async: {code}"
+
+        with pytest.raises(ValueError, match="validation failed"):
+            await async_execute.arun(wrong_arg="value")
diff --git a/lib/crewai/tests/utilities/test_agent_utils.py b/lib/crewai/tests/utilities/test_agent_utils.py
index 31d7b9705..8e3093219 100644
--- a/lib/crewai/tests/utilities/test_agent_utils.py
+++ b/lib/crewai/tests/utilities/test_agent_utils.py
@@ -17,6 +17,7 @@ from crewai.utilities.agent_utils import (
     _format_messages_for_summary,
     _split_messages_into_chunks,
     convert_tools_to_openai_schema,
+    parse_tool_call_args,
     summarize_messages,
 )

@@ -922,3 +923,56 @@ class TestParallelSummarizationVCR:
         assert summary_msg["role"] == "user"
         assert "files" in summary_msg
         assert "report.pdf" in summary_msg["files"]
+
+
+class TestParseToolCallArgs:
+    """Unit tests for parse_tool_call_args."""
+
+    def test_valid_json_string_returns_dict(self) -> None:
+        args_dict, error = parse_tool_call_args('{"code": "print(1)"}', "run_code", "call_1")
+        assert error is None
+        assert args_dict == {"code": "print(1)"}
+
+    def test_malformed_json_returns_error_dict(self) -> None:
+        args_dict, error = parse_tool_call_args('{"code": "print("hi")"}', "run_code", "call_1")
+        assert args_dict is None
+        assert error is not None
+        assert error["call_id"] == "call_1"
+        assert error["func_name"] == "run_code"
+        assert error["from_cache"] is False
+        assert "Failed to parse tool arguments as JSON" in error["result"]
+        assert "run_code" in error["result"]
+
+    def test_malformed_json_preserves_original_tool(self) -> None:
+        mock_tool = object()
+        _, error = parse_tool_call_args("{bad}", "my_tool", "call_2", original_tool=mock_tool)
+        assert error is not None
+        assert error["original_tool"] is mock_tool
+
+    def test_malformed_json_original_tool_defaults_to_none(self) -> None:
+        _, error = parse_tool_call_args("{bad}", "my_tool", "call_3")
+        assert error is not None
+        assert error["original_tool"] is None
+
+    def test_dict_input_returned_directly(self) -> None:
+        func_args = {"code": "x = 42"}
+        args_dict, error = parse_tool_call_args(func_args, "run_code", "call_4")
+        assert error is None
+        assert args_dict == {"code": "x = 42"}
+
+    def test_empty_dict_input_returned_directly(self) -> None:
+        args_dict, error = parse_tool_call_args({}, "run_code", "call_5")
+        assert error is None
+        assert args_dict == {}
+
+    def test_valid_json_with_nested_values(self) -> None:
+        args_dict, error = parse_tool_call_args(
+            '{"query": "hello", "options": {"limit": 10}}', "search", "call_6"
+        )
+        assert error is None
+        assert args_dict == {"query": "hello", "options": {"limit": 10}}
+
+    def test_error_result_has_correct_keys(self) -> None:
+        _, error = parse_tool_call_args("{bad json}", "tool", "call_7")
+        assert error is not None
+        assert set(error.keys()) == {"call_id", "func_name", "result", "from_cache", "original_tool"}