mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-02-24 06:48:24 +00:00
Compare commits
1 Commits
main
...
joaomdmour
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1cccf0bffc |
@@ -894,11 +894,16 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
|
||||
json_parse_error: str | None = None
|
||||
if isinstance(func_args, str):
|
||||
try:
|
||||
args_dict = json.loads(func_args)
|
||||
except json.JSONDecodeError:
|
||||
except json.JSONDecodeError as e:
|
||||
args_dict = {}
|
||||
json_parse_error = (
|
||||
f"Error: Failed to parse tool arguments as JSON: {e}. "
|
||||
f"Please provide valid JSON arguments for the '{func_name}' tool."
|
||||
)
|
||||
else:
|
||||
args_dict = func_args
|
||||
|
||||
@@ -980,7 +985,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
color="red",
|
||||
)
|
||||
|
||||
if hook_blocked:
|
||||
if json_parse_error:
|
||||
result = json_parse_error
|
||||
elif hook_blocked:
|
||||
result = f"Tool execution blocked by hook. Tool: {func_name}"
|
||||
elif max_usage_reached and original_tool:
|
||||
result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore."
|
||||
|
||||
@@ -155,9 +155,17 @@ class BaseTool(BaseModel, ABC):
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
if kwargs and self.args_schema is not None and self.args_schema.model_fields:
|
||||
try:
|
||||
validated = self.args_schema.model_validate(kwargs)
|
||||
kwargs = validated.model_dump()
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Tool '{self.name}' arguments validation failed: {e}"
|
||||
) from e
|
||||
|
||||
result = self._run(*args, **kwargs)
|
||||
|
||||
# If _run is async, we safely run it
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.run(result)
|
||||
|
||||
@@ -331,6 +339,15 @@ class Tool(BaseTool, Generic[P, R]):
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
if kwargs and self.args_schema is not None and self.args_schema.model_fields:
|
||||
try:
|
||||
validated = self.args_schema.model_validate(kwargs)
|
||||
kwargs = validated.model_dump()
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Tool '{self.name}' arguments validation failed: {e}"
|
||||
) from e
|
||||
|
||||
result = self.func(*args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
|
||||
@@ -11,7 +11,7 @@ import os
|
||||
import threading
|
||||
import time
|
||||
from collections import Counter
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -1129,3 +1129,150 @@ class TestMaxUsageCountWithNativeToolCalling:
|
||||
# Verify the requested calls occurred while keeping usage bounded.
|
||||
assert tool.current_usage_count >= 2
|
||||
assert tool.current_usage_count <= tool.max_usage_count
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# JSON Parse Error Handling Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestNativeToolCallingJsonParseError:
|
||||
"""Tests that malformed JSON tool arguments produce clear errors
|
||||
instead of silently dropping all arguments."""
|
||||
|
||||
def _make_executor(self, tools: list[BaseTool]) -> "CrewAgentExecutor":
|
||||
"""Create a minimal CrewAgentExecutor with mocked dependencies."""
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.tools.base_tool import to_langchain
|
||||
|
||||
structured_tools = to_langchain(tools)
|
||||
mock_agent = Mock()
|
||||
mock_agent.key = "test_agent"
|
||||
mock_agent.role = "tester"
|
||||
mock_agent.verbose = False
|
||||
mock_agent.fingerprint = None
|
||||
mock_agent.tools_results = []
|
||||
|
||||
mock_task = Mock()
|
||||
mock_task.name = "test"
|
||||
mock_task.description = "test"
|
||||
mock_task.id = "test-id"
|
||||
|
||||
executor = object.__new__(CrewAgentExecutor)
|
||||
executor.agent = mock_agent
|
||||
executor.task = mock_task
|
||||
executor.crew = Mock()
|
||||
executor.tools = structured_tools
|
||||
executor.original_tools = tools
|
||||
executor.tools_handler = None
|
||||
executor._printer = Mock()
|
||||
executor.messages = []
|
||||
|
||||
return executor
|
||||
|
||||
def test_malformed_json_returns_parse_error(self) -> None:
|
||||
"""Malformed JSON args must return a descriptive error, not silently become {}."""
|
||||
|
||||
class CodeTool(BaseTool):
|
||||
name: str = "execute_code"
|
||||
description: str = "Run code"
|
||||
|
||||
def _run(self, code: str) -> str:
|
||||
return f"ran: {code}"
|
||||
|
||||
tool = CodeTool()
|
||||
executor = self._make_executor([tool])
|
||||
|
||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema
|
||||
_, available_functions = convert_tools_to_openai_schema([tool])
|
||||
|
||||
malformed_json = '{"code": "print("hello")"}'
|
||||
|
||||
result = executor._execute_single_native_tool_call(
|
||||
call_id="call_123",
|
||||
func_name="execute_code",
|
||||
func_args=malformed_json,
|
||||
available_functions=available_functions,
|
||||
)
|
||||
|
||||
assert "Failed to parse tool arguments as JSON" in result["result"]
|
||||
assert tool.current_usage_count == 0
|
||||
|
||||
def test_valid_json_still_executes_normally(self) -> None:
|
||||
"""Valid JSON args should execute the tool as before."""
|
||||
|
||||
class CodeTool(BaseTool):
|
||||
name: str = "execute_code"
|
||||
description: str = "Run code"
|
||||
|
||||
def _run(self, code: str) -> str:
|
||||
return f"ran: {code}"
|
||||
|
||||
tool = CodeTool()
|
||||
executor = self._make_executor([tool])
|
||||
|
||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema
|
||||
_, available_functions = convert_tools_to_openai_schema([tool])
|
||||
|
||||
valid_json = '{"code": "print(1)"}'
|
||||
|
||||
result = executor._execute_single_native_tool_call(
|
||||
call_id="call_456",
|
||||
func_name="execute_code",
|
||||
func_args=valid_json,
|
||||
available_functions=available_functions,
|
||||
)
|
||||
|
||||
assert result["result"] == "ran: print(1)"
|
||||
|
||||
def test_dict_args_bypass_json_parsing(self) -> None:
|
||||
"""When func_args is already a dict, no JSON parsing occurs."""
|
||||
|
||||
class CodeTool(BaseTool):
|
||||
name: str = "execute_code"
|
||||
description: str = "Run code"
|
||||
|
||||
def _run(self, code: str) -> str:
|
||||
return f"ran: {code}"
|
||||
|
||||
tool = CodeTool()
|
||||
executor = self._make_executor([tool])
|
||||
|
||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema
|
||||
_, available_functions = convert_tools_to_openai_schema([tool])
|
||||
|
||||
result = executor._execute_single_native_tool_call(
|
||||
call_id="call_789",
|
||||
func_name="execute_code",
|
||||
func_args={"code": "x = 42"},
|
||||
available_functions=available_functions,
|
||||
)
|
||||
|
||||
assert result["result"] == "ran: x = 42"
|
||||
|
||||
def test_schema_validation_catches_missing_args_on_native_path(self) -> None:
|
||||
"""The native function calling path should now enforce args_schema,
|
||||
catching missing required fields before _run is called."""
|
||||
|
||||
class StrictTool(BaseTool):
|
||||
name: str = "strict_tool"
|
||||
description: str = "A tool with required args"
|
||||
|
||||
def _run(self, code: str, language: str) -> str:
|
||||
return f"{language}: {code}"
|
||||
|
||||
tool = StrictTool()
|
||||
executor = self._make_executor([tool])
|
||||
|
||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema
|
||||
_, available_functions = convert_tools_to_openai_schema([tool])
|
||||
|
||||
result = executor._execute_single_native_tool_call(
|
||||
call_id="call_schema",
|
||||
func_name="strict_tool",
|
||||
func_args={"code": "print(1)"},
|
||||
available_functions=available_functions,
|
||||
)
|
||||
|
||||
assert "Error" in result["result"]
|
||||
assert "validation failed" in result["result"].lower() or "missing" in result["result"].lower()
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import Callable
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
@@ -230,3 +232,120 @@ def test_max_usage_count_is_respected():
|
||||
crew.kickoff()
|
||||
assert tool.max_usage_count == 5
|
||||
assert tool.current_usage_count == 5
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Schema Validation in run() Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CodeExecutorInput(BaseModel):
|
||||
code: str = Field(description="The code to execute")
|
||||
language: str = Field(default="python", description="Programming language")
|
||||
|
||||
|
||||
class CodeExecutorTool(BaseTool):
|
||||
name: str = "code_executor"
|
||||
description: str = "Execute code snippets"
|
||||
args_schema: type[BaseModel] = CodeExecutorInput
|
||||
|
||||
def _run(self, code: str, language: str = "python") -> str:
|
||||
return f"Executed {language}: {code}"
|
||||
|
||||
|
||||
class TestBaseToolRunValidation:
|
||||
"""Tests for args_schema validation in BaseTool.run()."""
|
||||
|
||||
def test_run_with_valid_kwargs_passes_validation(self) -> None:
|
||||
"""Valid keyword arguments should pass schema validation and execute."""
|
||||
t = CodeExecutorTool()
|
||||
result = t.run(code="print('hello')")
|
||||
assert result == "Executed python: print('hello')"
|
||||
|
||||
def test_run_with_all_kwargs_passes_validation(self) -> None:
|
||||
"""All keyword arguments including optional ones should pass."""
|
||||
t = CodeExecutorTool()
|
||||
result = t.run(code="console.log('hi')", language="javascript")
|
||||
assert result == "Executed javascript: console.log('hi')"
|
||||
|
||||
def test_run_with_missing_required_kwarg_raises(self) -> None:
|
||||
"""Missing required kwargs should raise ValueError from schema validation."""
|
||||
t = CodeExecutorTool()
|
||||
with pytest.raises(ValueError, match="validation failed"):
|
||||
t.run(language="python")
|
||||
|
||||
def test_run_with_wrong_field_name_raises(self) -> None:
|
||||
"""Kwargs not matching any schema field should trigger validation error
|
||||
for missing required fields."""
|
||||
t = CodeExecutorTool()
|
||||
with pytest.raises(ValueError, match="validation failed"):
|
||||
t.run(wrong_arg="value")
|
||||
|
||||
def test_run_with_positional_args_skips_validation(self) -> None:
|
||||
"""Positional-arg calls should bypass schema validation (backwards compat)."""
|
||||
class SimpleTool(BaseTool):
|
||||
name: str = "simple"
|
||||
description: str = "A simple tool"
|
||||
|
||||
def _run(self, question: str) -> str:
|
||||
return question
|
||||
|
||||
t = SimpleTool()
|
||||
result = t.run("What is life?")
|
||||
assert result == "What is life?"
|
||||
|
||||
def test_run_strips_extra_kwargs_from_llm(self) -> None:
|
||||
"""Extra kwargs not in the schema should be silently stripped,
|
||||
preventing unexpected-keyword crashes in _run."""
|
||||
t = CodeExecutorTool()
|
||||
result = t.run(code="1+1", extra_hallucinated_field="junk")
|
||||
assert result == "Executed python: 1+1"
|
||||
|
||||
def test_run_increments_usage_after_validation(self) -> None:
|
||||
"""Usage count should still increment after validated execution."""
|
||||
t = CodeExecutorTool()
|
||||
assert t.current_usage_count == 0
|
||||
t.run(code="x = 1")
|
||||
assert t.current_usage_count == 1
|
||||
|
||||
def test_run_does_not_increment_usage_on_validation_error(self) -> None:
|
||||
"""Usage count should NOT increment when validation fails."""
|
||||
t = CodeExecutorTool()
|
||||
assert t.current_usage_count == 0
|
||||
with pytest.raises(ValueError):
|
||||
t.run(wrong="bad")
|
||||
assert t.current_usage_count == 0
|
||||
|
||||
|
||||
class TestToolDecoratorRunValidation:
|
||||
"""Tests for args_schema validation in Tool.run() (decorator-based tools)."""
|
||||
|
||||
def test_decorator_tool_run_validates_kwargs(self) -> None:
|
||||
"""Decorator-created tools should also validate kwargs against schema."""
|
||||
@tool("execute_code")
|
||||
def execute_code(code: str, language: str = "python") -> str:
|
||||
"""Execute a code snippet."""
|
||||
return f"Executed {language}: {code}"
|
||||
|
||||
result = execute_code.run(code="x = 1")
|
||||
assert result == "Executed python: x = 1"
|
||||
|
||||
def test_decorator_tool_run_rejects_missing_required(self) -> None:
|
||||
"""Decorator tools should reject missing required args via validation."""
|
||||
@tool("execute_code")
|
||||
def execute_code(code: str) -> str:
|
||||
"""Execute a code snippet."""
|
||||
return f"Executed: {code}"
|
||||
|
||||
with pytest.raises(ValueError, match="validation failed"):
|
||||
execute_code.run(wrong_arg="value")
|
||||
|
||||
def test_decorator_tool_positional_args_still_work(self) -> None:
|
||||
"""Positional args to decorator tools should bypass validation."""
|
||||
@tool("greet")
|
||||
def greet(name: str) -> str:
|
||||
"""Greet someone."""
|
||||
return f"Hello, {name}!"
|
||||
|
||||
result = greet.run("World")
|
||||
assert result == "Hello, World!"
|
||||
|
||||
Reference in New Issue
Block a user