diff --git a/src/crewai/tools/base_tool.py b/src/crewai/tools/base_tool.py
index 3571eda30..5cd17d832 100644
--- a/src/crewai/tools/base_tool.py
+++ b/src/crewai/tools/base_tool.py
@@ -99,10 +99,24 @@ class BaseTool(BaseModel, ABC):
 
         Raises:
             ValueError: If input is neither a string nor a dictionary
+            ValueError: If a string input exceeds the maximum allowed size
+            ValueError: If input contains nested dictionaries, lists or tuples beyond the maximum allowed depth
         """
+        # Input type validation
         if not isinstance(input, (str, dict)):
             raise ValueError(f"Input must be string or dict, got {type(input)}")
 
+        # Input size validation (limit to 100KB)
+        MAX_INPUT_SIZE = 100 * 1024  # 100KB
+        # Every character encodes to at least one UTF-8 byte, so an over-long
+        # string is rejected before an encoded copy of it is allocated.
+        if isinstance(input, str) and (
+            len(input) > MAX_INPUT_SIZE
+            or len(input.encode('utf-8')) > MAX_INPUT_SIZE
+        ):
+            logger.warning(f"Input string exceeds maximum size of {MAX_INPUT_SIZE} bytes")
+            raise ValueError(f"Input string exceeds maximum size of {MAX_INPUT_SIZE} bytes")
+
         if isinstance(input, str):
             # Try to parse as JSON if it's a string
             try:
@@ -118,6 +132,26 @@ class BaseTool(BaseModel, ABC):
             logger.debug(f"Using non-dict input directly: {input}")
             return self._run(input)
 
+        # Validate nesting depth. Only containers occupy a level, so scalar
+        # values inside the deepest permitted container are still accepted.
+        MAX_DEPTH = 5
+
+        def check_depth(obj, current_depth=1):
+            """Return True when no container within obj is deeper than MAX_DEPTH."""
+            if isinstance(obj, dict):
+                children = obj.values()
+            elif isinstance(obj, (list, tuple)):
+                children = obj
+            else:
+                return True
+            if current_depth > MAX_DEPTH:
+                return False
+            return all(check_depth(child, current_depth + 1) for child in children)
+
+        if not check_depth(input):
+            logger.warning(f"Input contains nested structures beyond maximum depth of {MAX_DEPTH}")
+            raise ValueError(f"Input contains nested structures beyond maximum depth of {MAX_DEPTH}")
+
         # Get the expected arguments from the schema
         if hasattr(self, 'args_schema') and self.args_schema is not None:
             try:
diff --git a/tests/tools/test_invoke_method_additional.py b/tests/tools/test_invoke_method_additional.py
new file mode 100644
index 000000000..2c1bfe76d
--- /dev/null
+++ 
b/tests/tools/test_invoke_method_additional.py
@@ -0,0 +1,96 @@
+from typing import Type
+
+import pytest
+from pydantic import BaseModel, Field
+
+from crewai.tools import BaseTool
+
+
+# Helper classes deliberately avoid a "Test" name prefix so that pytest does
+# not try to collect them as test classes.
+class EchoToolInput(BaseModel):
+    """Argument schema for EchoTool: a single string parameter."""
+
+    param: str = Field(description="A test parameter")
+
+
+class EchoTool(BaseTool):
+    """Minimal tool that echoes its ``param`` argument in the result."""
+
+    name: str = "Test Tool"
+    description: str = "A tool for testing the invoke method"
+    args_schema: Type[BaseModel] = EchoToolInput
+
+    def _run(self, param: str) -> str:
+        return f"Tool executed with: {param}"
+
+
+@pytest.mark.parametrize("bad_input", [123, ["list", "not", "allowed"], None])
+def test_invoke_with_invalid_type(bad_input):
+    """Test that invoke raises ValueError with invalid input types."""
+    tool = EchoTool()
+    with pytest.raises(ValueError, match="Input must be string or dict"):
+        tool.invoke(input=bad_input)
+
+
+def test_invoke_with_config():
+    """Test that passing a config dictionary does not change the result."""
+    tool = EchoTool()
+    # _run only accepts ``param``, so the config must not reach its arguments
+    result = tool.invoke(input={"param": "test with config"}, config={"timeout": 30})
+    assert result == "Tool executed with: test with config"
+
+
+def test_invoke_with_malformed_json():
+    """Test that invoke handles malformed JSON gracefully."""
+    tool = EchoTool()
+    # Malformed JSON should be treated as a raw string
+    result = tool.invoke(input="{param: this is not valid JSON}")
+    assert "this is not valid JSON" in result
+
+
+def test_invoke_with_nested_dict():
+    """Test that invoke handles nested dictionaries properly."""
+
+    class NestedToolInput(BaseModel):
+        config: dict = Field(description="A nested configuration dictionary")
+
+    class NestedTool(BaseTool):
+        name: str = "Nested Tool"
+        description: str = "A tool for testing nested dictionaries"
+        args_schema: Type[BaseModel] = NestedToolInput
+
+        def _run(self, config: dict) -> str:
+            return f"Tool executed with nested config: {config}"
+
+    tool = NestedTool()
+    nested_input = {"config": {"key1": "value1", "key2": {"nested": "value"}}}
+    result = tool.invoke(input=nested_input)
+    assert "Tool executed with nested config" in result
+    assert "key1" in result
+    assert "nested" in result
+
+
+def test_invoke_with_oversized_string():
+    """Test that invoke rejects string inputs larger than the 100KB limit."""
+    tool = EchoTool()
+    with pytest.raises(ValueError, match="exceeds maximum size"):
+        tool.invoke(input="x" * (100 * 1024 + 1))
+
+
+def test_invoke_size_limit_counts_bytes():
+    """Test that the size limit is measured in UTF-8 bytes, not characters."""
+    tool = EchoTool()
+    # 60,000 two-byte characters: below 100KB as characters, above it as bytes
+    with pytest.raises(ValueError, match="exceeds maximum size"):
+        tool.invoke(input="\u00e9" * 60_000)
+
+
+def test_invoke_with_excessive_nesting():
+    """Test that invoke rejects dictionaries nested beyond the depth limit."""
+    tool = EchoTool()
+    deep_value = "leaf"
+    for _ in range(10):
+        deep_value = {"level": deep_value}
+    with pytest.raises(ValueError, match="beyond maximum depth"):
+        tool.invoke(input={"param": deep_value})