Compare commits

...

3 Commits

Author SHA1 Message Date
Devin AI
945a1346a3 Enhance invoke method with input sanitization, size limits, and nested dictionary depth validation
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-17 02:55:38 +00:00
Devin AI
c0386b73b9 Enhance invoke method with better error handling, logging, and input validation
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-17 02:51:46 +00:00
Devin AI
3f25e535f4 Fix issue #2383: Add invoke method to BaseTool for models without function calling support
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-03-17 02:49:10 +00:00
3 changed files with 216 additions and 1 deletions

View File

@@ -1,7 +1,11 @@
import json
import logging
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from inspect import signature from inspect import signature
from typing import Any, Callable, Type, get_args, get_origin from typing import Any, Callable, Dict, Optional, Type, Union, get_args, get_origin
logger = logging.getLogger(__name__)
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
@@ -75,6 +79,93 @@ class BaseTool(BaseModel, ABC):
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Here goes the actual implementation of the tool.""" """Here goes the actual implementation of the tool."""
def invoke(
self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
) -> Any:
"""Main method for tool execution.
This method provides a fallback implementation for models that don't support
function calling natively (like QwQ-32B-Preview and deepseek-chat).
It parses the input and calls the _run method with the appropriate arguments.
Args:
input: Either a string (raw or JSON) or a dictionary of arguments
config: Optional configuration dictionary
**kwargs: Additional keyword arguments to pass to _run
Returns:
The result of calling the tool's _run method
Raises:
ValueError: If input is neither a string nor a dictionary
ValueError: If input exceeds the maximum allowed size
ValueError: If input contains nested dictionaries beyond the maximum allowed depth
"""
# Input type validation
if not isinstance(input, (str, dict)):
raise ValueError(f"Input must be string or dict, got {type(input)}")
# Input size validation (limit to 100KB)
MAX_INPUT_SIZE = 100 * 1024 # 100KB
if isinstance(input, str) and len(input.encode('utf-8')) > MAX_INPUT_SIZE:
logger.warning(f"Input string exceeds maximum size of {MAX_INPUT_SIZE} bytes")
raise ValueError(f"Input string exceeds maximum size of {MAX_INPUT_SIZE} bytes")
if isinstance(input, str):
# Try to parse as JSON if it's a string
try:
input = json.loads(input)
logger.debug(f"Successfully parsed JSON input: {input}")
except json.JSONDecodeError as e:
# If not valid JSON, pass as a single argument
logger.debug(f"Input string is not JSON format: {e}")
return self._run(input)
if not isinstance(input, dict):
# If input is not a dict after parsing, pass it directly
logger.debug(f"Using non-dict input directly: {input}")
return self._run(input)
# Validate nested dictionary depth
MAX_DEPTH = 5
def check_depth(obj, current_depth=1):
if current_depth > MAX_DEPTH:
return False
if isinstance(obj, dict):
return all(check_depth(v, current_depth + 1) for v in obj.values())
elif isinstance(obj, (list, tuple)):
return all(check_depth(item, current_depth + 1) for item in obj)
return True
if not check_depth(input):
logger.warning(f"Input contains nested structures beyond maximum depth of {MAX_DEPTH}")
raise ValueError(f"Input contains nested structures beyond maximum depth of {MAX_DEPTH}")
# Get the expected arguments from the schema
if hasattr(self, 'args_schema') and self.args_schema is not None:
try:
# Extract argument names from the schema
arg_names = list(self.args_schema.model_json_schema()["properties"].keys())
# Filter the input to only include valid arguments
filtered_args = {}
for k in input.keys():
if k in arg_names:
filtered_args[k] = input[k]
else:
logger.warning(f"Ignoring unexpected argument: {k}")
logger.debug(f"Calling _run with filtered arguments: {filtered_args}")
# Call _run with the filtered arguments
return self._run(**filtered_args)
except Exception as e:
# Fallback to passing the entire input dict if schema parsing fails
logger.warning(f"Schema parsing failed, using raw input: {e}")
# If we couldn't parse the schema or there was an error, just pass the input dict
logger.debug(f"Calling _run with unfiltered arguments: {input}")
return self._run(**input)
def to_structured_tool(self) -> CrewStructuredTool: def to_structured_tool(self) -> CrewStructuredTool:
"""Convert this tool to a CrewStructuredTool instance.""" """Convert this tool to a CrewStructuredTool instance."""

View File

@@ -0,0 +1,55 @@
from typing import Type
import pytest
from pydantic import BaseModel, Field
from crewai.tools import BaseTool
class TestToolInput(BaseModel):
param: str = Field(description="A test parameter")
class TestTool(BaseTool):
name: str = "Test Tool"
description: str = "A tool for testing the invoke method"
args_schema: Type[BaseModel] = TestToolInput
def _run(self, param: str) -> str:
return f"Tool executed with: {param}"
def test_invoke_with_dict():
"""Test that invoke works with a dictionary input."""
tool = TestTool()
result = tool.invoke(input={"param": "test value"})
assert result == "Tool executed with: test value"
def test_invoke_with_json_string():
"""Test that invoke works with a JSON string input."""
tool = TestTool()
result = tool.invoke(input='{"param": "test value"}')
assert result == "Tool executed with: test value"
def test_invoke_with_raw_string():
"""Test that invoke works with a raw string input."""
tool = TestTool()
result = tool.invoke(input="test value")
assert result == "Tool executed with: test value"
def test_invoke_with_empty_dict():
"""Test that invoke handles empty dict input appropriately."""
tool = TestTool()
with pytest.raises(Exception):
# Should raise an exception since param is required
tool.invoke(input={})
def test_invoke_with_extra_args():
"""Test that invoke filters out extra arguments not in the schema."""
tool = TestTool()
result = tool.invoke(input={"param": "test value", "extra": "ignored"})
assert result == "Tool executed with: test value"

View File

@@ -0,0 +1,69 @@
from typing import Type
import pytest
from pydantic import BaseModel, Field
from crewai.tools import BaseTool
class TestToolInput(BaseModel):
param: str = Field(description="A test parameter")
class TestTool(BaseTool):
name: str = "Test Tool"
description: str = "A tool for testing the invoke method"
args_schema: Type[BaseModel] = TestToolInput
def _run(self, param: str) -> str:
return f"Tool executed with: {param}"
def test_invoke_with_invalid_type():
"""Test that invoke raises ValueError with invalid input types."""
tool = TestTool()
with pytest.raises(ValueError, match="Input must be string or dict"):
tool.invoke(input=123)
with pytest.raises(ValueError, match="Input must be string or dict"):
tool.invoke(input=["list", "not", "allowed"])
with pytest.raises(ValueError, match="Input must be string or dict"):
tool.invoke(input=None)
def test_invoke_with_config():
"""Test that invoke properly handles configuration dictionaries."""
tool = TestTool()
# Config should be passed through to _run but not affect the result
result = tool.invoke(input={"param": "test with config"}, config={"timeout": 30})
assert result == "Tool executed with: test with config"
def test_invoke_with_malformed_json():
"""Test that invoke handles malformed JSON gracefully."""
tool = TestTool()
# Malformed JSON should be treated as a raw string
result = tool.invoke(input="{param: this is not valid JSON}")
assert "this is not valid JSON" in result
def test_invoke_with_nested_dict():
"""Test that invoke handles nested dictionaries properly."""
class NestedToolInput(BaseModel):
config: dict = Field(description="A nested configuration dictionary")
class NestedTool(BaseTool):
name: str = "Nested Tool"
description: str = "A tool for testing nested dictionaries"
args_schema: Type[BaseModel] = NestedToolInput
def _run(self, config: dict) -> str:
return f"Tool executed with nested config: {config}"
tool = NestedTool()
nested_input = {"config": {"key1": "value1", "key2": {"nested": "value"}}}
result = tool.invoke(input=nested_input)
assert "Tool executed with nested config" in result
assert "key1" in result
assert "nested" in result