mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Enhance invoke method with input sanitization, size limits, and nested dictionary depth validation
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -99,10 +99,19 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
Raises:
|
||||
ValueError: If input is neither a string nor a dictionary
|
||||
ValueError: If input exceeds the maximum allowed size
|
||||
ValueError: If input contains nested dictionaries beyond the maximum allowed depth
|
||||
"""
|
||||
# Input type validation
|
||||
if not isinstance(input, (str, dict)):
|
||||
raise ValueError(f"Input must be string or dict, got {type(input)}")
|
||||
|
||||
# Input size validation (limit to 100KB)
|
||||
MAX_INPUT_SIZE = 100 * 1024 # 100KB
|
||||
if isinstance(input, str) and len(input.encode('utf-8')) > MAX_INPUT_SIZE:
|
||||
logger.warning(f"Input string exceeds maximum size of {MAX_INPUT_SIZE} bytes")
|
||||
raise ValueError(f"Input string exceeds maximum size of {MAX_INPUT_SIZE} bytes")
|
||||
|
||||
if isinstance(input, str):
|
||||
# Try to parse as JSON if it's a string
|
||||
try:
|
||||
@@ -118,6 +127,21 @@ class BaseTool(BaseModel, ABC):
|
||||
logger.debug(f"Using non-dict input directly: {input}")
|
||||
return self._run(input)
|
||||
|
||||
# Validate nested dictionary depth
|
||||
MAX_DEPTH = 5
|
||||
def check_depth(obj, current_depth=1):
|
||||
if current_depth > MAX_DEPTH:
|
||||
return False
|
||||
if isinstance(obj, dict):
|
||||
return all(check_depth(v, current_depth + 1) for v in obj.values())
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return all(check_depth(item, current_depth + 1) for item in obj)
|
||||
return True
|
||||
|
||||
if not check_depth(input):
|
||||
logger.warning(f"Input contains nested structures beyond maximum depth of {MAX_DEPTH}")
|
||||
raise ValueError(f"Input contains nested structures beyond maximum depth of {MAX_DEPTH}")
|
||||
|
||||
# Get the expected arguments from the schema
|
||||
if hasattr(self, 'args_schema') and self.args_schema is not None:
|
||||
try:
|
||||
|
||||
69
tests/tools/test_invoke_method_additional.py
Normal file
69
tests/tools/test_invoke_method_additional.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
class TestToolInput(BaseModel):
|
||||
param: str = Field(description="A test parameter")
|
||||
|
||||
|
||||
class TestTool(BaseTool):
|
||||
name: str = "Test Tool"
|
||||
description: str = "A tool for testing the invoke method"
|
||||
args_schema: Type[BaseModel] = TestToolInput
|
||||
|
||||
def _run(self, param: str) -> str:
|
||||
return f"Tool executed with: {param}"
|
||||
|
||||
|
||||
def test_invoke_with_invalid_type():
|
||||
"""Test that invoke raises ValueError with invalid input types."""
|
||||
tool = TestTool()
|
||||
with pytest.raises(ValueError, match="Input must be string or dict"):
|
||||
tool.invoke(input=123)
|
||||
|
||||
with pytest.raises(ValueError, match="Input must be string or dict"):
|
||||
tool.invoke(input=["list", "not", "allowed"])
|
||||
|
||||
with pytest.raises(ValueError, match="Input must be string or dict"):
|
||||
tool.invoke(input=None)
|
||||
|
||||
|
||||
def test_invoke_with_config():
|
||||
"""Test that invoke properly handles configuration dictionaries."""
|
||||
tool = TestTool()
|
||||
# Config should be passed through to _run but not affect the result
|
||||
result = tool.invoke(input={"param": "test with config"}, config={"timeout": 30})
|
||||
assert result == "Tool executed with: test with config"
|
||||
|
||||
|
||||
def test_invoke_with_malformed_json():
|
||||
"""Test that invoke handles malformed JSON gracefully."""
|
||||
tool = TestTool()
|
||||
# Malformed JSON should be treated as a raw string
|
||||
result = tool.invoke(input="{param: this is not valid JSON}")
|
||||
assert "this is not valid JSON" in result
|
||||
|
||||
|
||||
def test_invoke_with_nested_dict():
|
||||
"""Test that invoke handles nested dictionaries properly."""
|
||||
class NestedToolInput(BaseModel):
|
||||
config: dict = Field(description="A nested configuration dictionary")
|
||||
|
||||
class NestedTool(BaseTool):
|
||||
name: str = "Nested Tool"
|
||||
description: str = "A tool for testing nested dictionaries"
|
||||
args_schema: Type[BaseModel] = NestedToolInput
|
||||
|
||||
def _run(self, config: dict) -> str:
|
||||
return f"Tool executed with nested config: {config}"
|
||||
|
||||
tool = NestedTool()
|
||||
nested_input = {"config": {"key1": "value1", "key2": {"nested": "value"}}}
|
||||
result = tool.invoke(input=nested_input)
|
||||
assert "Tool executed with nested config" in result
|
||||
assert "key1" in result
|
||||
assert "nested" in result
|
||||
Reference in New Issue
Block a user