mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 17:48:13 +00:00
Compare commits
3 Commits
llm-event-
...
devin/1742
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
945a1346a3 | ||
|
|
c0386b73b9 | ||
|
|
3f25e535f4 |
@@ -1,7 +1,11 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Any, Callable, Type, get_args, get_origin
|
from typing import Any, Callable, Dict, Optional, Type, Union, get_args, get_origin
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
@@ -75,6 +79,93 @@ class BaseTool(BaseModel, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Here goes the actual implementation of the tool."""
|
"""Here goes the actual implementation of the tool."""
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
|
||||||
|
) -> Any:
|
||||||
|
"""Main method for tool execution.
|
||||||
|
|
||||||
|
This method provides a fallback implementation for models that don't support
|
||||||
|
function calling natively (like QwQ-32B-Preview and deepseek-chat).
|
||||||
|
It parses the input and calls the _run method with the appropriate arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: Either a string (raw or JSON) or a dictionary of arguments
|
||||||
|
config: Optional configuration dictionary
|
||||||
|
**kwargs: Additional keyword arguments to pass to _run
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of calling the tool's _run method
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If input is neither a string nor a dictionary
|
||||||
|
ValueError: If input exceeds the maximum allowed size
|
||||||
|
ValueError: If input contains nested dictionaries beyond the maximum allowed depth
|
||||||
|
"""
|
||||||
|
# Input type validation
|
||||||
|
if not isinstance(input, (str, dict)):
|
||||||
|
raise ValueError(f"Input must be string or dict, got {type(input)}")
|
||||||
|
|
||||||
|
# Input size validation (limit to 100KB)
|
||||||
|
MAX_INPUT_SIZE = 100 * 1024 # 100KB
|
||||||
|
if isinstance(input, str) and len(input.encode('utf-8')) > MAX_INPUT_SIZE:
|
||||||
|
logger.warning(f"Input string exceeds maximum size of {MAX_INPUT_SIZE} bytes")
|
||||||
|
raise ValueError(f"Input string exceeds maximum size of {MAX_INPUT_SIZE} bytes")
|
||||||
|
|
||||||
|
if isinstance(input, str):
|
||||||
|
# Try to parse as JSON if it's a string
|
||||||
|
try:
|
||||||
|
input = json.loads(input)
|
||||||
|
logger.debug(f"Successfully parsed JSON input: {input}")
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
# If not valid JSON, pass as a single argument
|
||||||
|
logger.debug(f"Input string is not JSON format: {e}")
|
||||||
|
return self._run(input)
|
||||||
|
|
||||||
|
if not isinstance(input, dict):
|
||||||
|
# If input is not a dict after parsing, pass it directly
|
||||||
|
logger.debug(f"Using non-dict input directly: {input}")
|
||||||
|
return self._run(input)
|
||||||
|
|
||||||
|
# Validate nested dictionary depth
|
||||||
|
MAX_DEPTH = 5
|
||||||
|
def check_depth(obj, current_depth=1):
|
||||||
|
if current_depth > MAX_DEPTH:
|
||||||
|
return False
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return all(check_depth(v, current_depth + 1) for v in obj.values())
|
||||||
|
elif isinstance(obj, (list, tuple)):
|
||||||
|
return all(check_depth(item, current_depth + 1) for item in obj)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not check_depth(input):
|
||||||
|
logger.warning(f"Input contains nested structures beyond maximum depth of {MAX_DEPTH}")
|
||||||
|
raise ValueError(f"Input contains nested structures beyond maximum depth of {MAX_DEPTH}")
|
||||||
|
|
||||||
|
# Get the expected arguments from the schema
|
||||||
|
if hasattr(self, 'args_schema') and self.args_schema is not None:
|
||||||
|
try:
|
||||||
|
# Extract argument names from the schema
|
||||||
|
arg_names = list(self.args_schema.model_json_schema()["properties"].keys())
|
||||||
|
|
||||||
|
# Filter the input to only include valid arguments
|
||||||
|
filtered_args = {}
|
||||||
|
for k in input.keys():
|
||||||
|
if k in arg_names:
|
||||||
|
filtered_args[k] = input[k]
|
||||||
|
else:
|
||||||
|
logger.warning(f"Ignoring unexpected argument: {k}")
|
||||||
|
|
||||||
|
logger.debug(f"Calling _run with filtered arguments: {filtered_args}")
|
||||||
|
# Call _run with the filtered arguments
|
||||||
|
return self._run(**filtered_args)
|
||||||
|
except Exception as e:
|
||||||
|
# Fallback to passing the entire input dict if schema parsing fails
|
||||||
|
logger.warning(f"Schema parsing failed, using raw input: {e}")
|
||||||
|
|
||||||
|
# If we couldn't parse the schema or there was an error, just pass the input dict
|
||||||
|
logger.debug(f"Calling _run with unfiltered arguments: {input}")
|
||||||
|
return self._run(**input)
|
||||||
|
|
||||||
def to_structured_tool(self) -> CrewStructuredTool:
|
def to_structured_tool(self) -> CrewStructuredTool:
|
||||||
"""Convert this tool to a CrewStructuredTool instance."""
|
"""Convert this tool to a CrewStructuredTool instance."""
|
||||||
|
|||||||
55
tests/tools/test_invoke_method.py
Normal file
55
tests/tools/test_invoke_method.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolInput(BaseModel):
|
||||||
|
param: str = Field(description="A test parameter")
|
||||||
|
|
||||||
|
|
||||||
|
class TestTool(BaseTool):
|
||||||
|
name: str = "Test Tool"
|
||||||
|
description: str = "A tool for testing the invoke method"
|
||||||
|
args_schema: Type[BaseModel] = TestToolInput
|
||||||
|
|
||||||
|
def _run(self, param: str) -> str:
|
||||||
|
return f"Tool executed with: {param}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_with_dict():
|
||||||
|
"""Test that invoke works with a dictionary input."""
|
||||||
|
tool = TestTool()
|
||||||
|
result = tool.invoke(input={"param": "test value"})
|
||||||
|
assert result == "Tool executed with: test value"
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_with_json_string():
|
||||||
|
"""Test that invoke works with a JSON string input."""
|
||||||
|
tool = TestTool()
|
||||||
|
result = tool.invoke(input='{"param": "test value"}')
|
||||||
|
assert result == "Tool executed with: test value"
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_with_raw_string():
|
||||||
|
"""Test that invoke works with a raw string input."""
|
||||||
|
tool = TestTool()
|
||||||
|
result = tool.invoke(input="test value")
|
||||||
|
assert result == "Tool executed with: test value"
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_with_empty_dict():
|
||||||
|
"""Test that invoke handles empty dict input appropriately."""
|
||||||
|
tool = TestTool()
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
# Should raise an exception since param is required
|
||||||
|
tool.invoke(input={})
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_with_extra_args():
|
||||||
|
"""Test that invoke filters out extra arguments not in the schema."""
|
||||||
|
tool = TestTool()
|
||||||
|
result = tool.invoke(input={"param": "test value", "extra": "ignored"})
|
||||||
|
assert result == "Tool executed with: test value"
|
||||||
69
tests/tools/test_invoke_method_additional.py
Normal file
69
tests/tools/test_invoke_method_additional.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
from typing import Type
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolInput(BaseModel):
|
||||||
|
param: str = Field(description="A test parameter")
|
||||||
|
|
||||||
|
|
||||||
|
class TestTool(BaseTool):
|
||||||
|
name: str = "Test Tool"
|
||||||
|
description: str = "A tool for testing the invoke method"
|
||||||
|
args_schema: Type[BaseModel] = TestToolInput
|
||||||
|
|
||||||
|
def _run(self, param: str) -> str:
|
||||||
|
return f"Tool executed with: {param}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_with_invalid_type():
|
||||||
|
"""Test that invoke raises ValueError with invalid input types."""
|
||||||
|
tool = TestTool()
|
||||||
|
with pytest.raises(ValueError, match="Input must be string or dict"):
|
||||||
|
tool.invoke(input=123)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Input must be string or dict"):
|
||||||
|
tool.invoke(input=["list", "not", "allowed"])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Input must be string or dict"):
|
||||||
|
tool.invoke(input=None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_with_config():
|
||||||
|
"""Test that invoke properly handles configuration dictionaries."""
|
||||||
|
tool = TestTool()
|
||||||
|
# Config should be passed through to _run but not affect the result
|
||||||
|
result = tool.invoke(input={"param": "test with config"}, config={"timeout": 30})
|
||||||
|
assert result == "Tool executed with: test with config"
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_with_malformed_json():
|
||||||
|
"""Test that invoke handles malformed JSON gracefully."""
|
||||||
|
tool = TestTool()
|
||||||
|
# Malformed JSON should be treated as a raw string
|
||||||
|
result = tool.invoke(input="{param: this is not valid JSON}")
|
||||||
|
assert "this is not valid JSON" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_with_nested_dict():
|
||||||
|
"""Test that invoke handles nested dictionaries properly."""
|
||||||
|
class NestedToolInput(BaseModel):
|
||||||
|
config: dict = Field(description="A nested configuration dictionary")
|
||||||
|
|
||||||
|
class NestedTool(BaseTool):
|
||||||
|
name: str = "Nested Tool"
|
||||||
|
description: str = "A tool for testing nested dictionaries"
|
||||||
|
args_schema: Type[BaseModel] = NestedToolInput
|
||||||
|
|
||||||
|
def _run(self, config: dict) -> str:
|
||||||
|
return f"Tool executed with nested config: {config}"
|
||||||
|
|
||||||
|
tool = NestedTool()
|
||||||
|
nested_input = {"config": {"key1": "value1", "key2": {"nested": "value"}}}
|
||||||
|
result = tool.invoke(input=nested_input)
|
||||||
|
assert "Tool executed with nested config" in result
|
||||||
|
assert "key1" in result
|
||||||
|
assert "nested" in result
|
||||||
Reference in New Issue
Block a user