Compare commits

...

3 Commits

2 changed files with 206 additions and 1 deletions

View File

@@ -173,11 +173,18 @@ class CrewStructuredTool:
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
"""Parse and validate the input arguments against the schema.
This method handles different input formats from various LLM providers,
including nested dictionaries with 'value' fields that some providers use.
Args:
raw_args: The raw arguments to parse, either as a string or dict
raw_args: The raw arguments to parse, either as a string or dict.
Supports nested dictionaries with 'value' field for LLM provider compatibility.
Returns:
The validated arguments as a dictionary
Raises:
ValueError: If argument parsing or validation fails
"""
if isinstance(raw_args, str):
try:
@@ -187,6 +194,31 @@ class CrewStructuredTool:
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse arguments as JSON: {e}")
# Handle nested dictionaries with 'value' field for all parameter types
if isinstance(raw_args, dict):
schema_fields = self.args_schema.model_fields
for field_name, field_value in list(raw_args.items()):
# Check if this field exists in the schema
if field_name in schema_fields:
# Handle nested dictionaries with 'value' field
if isinstance(field_value, dict):
if 'value' in field_value:
# Extract the value from the nested dictionary
value = field_value['value']
self._logger.debug(f"Extracting value from nested dict for {field_name}")
expected_type = schema_fields[field_name].annotation
if expected_type in (str, int, float, bool) and not isinstance(value, expected_type):
self._logger.warning(
f"Type mismatch for {field_name}: expected {expected_type}, got {type(value)}"
)
raw_args[field_name] = value
else:
self._logger.debug(f"Nested dict for {field_name} has no 'value' key")
try:
validated_args = self.args_schema.model_validate(raw_args)
return validated_args.model_dump()

View File

@@ -0,0 +1,173 @@
import pytest
from pydantic import BaseModel, Field
from crewai.tools.structured_tool import CrewStructuredTool
class StringInputSchema(BaseModel):
"""Schema with a string input field."""
query: str = Field(description="A string input parameter")
class IntInputSchema(BaseModel):
"""Schema with an integer input field."""
number: int = Field(description="An integer input parameter")
class ComplexInputSchema(BaseModel):
"""Schema with multiple fields of different types."""
text: str = Field(description="A string parameter")
number: int = Field(description="An integer parameter")
flag: bool = Field(description="A boolean parameter")
def test_parse_args_with_string_input():
"""Test that string inputs are parsed correctly."""
def test_func(query: str) -> str:
return f"Processed: {query}"
tool = CrewStructuredTool.from_function(
func=test_func,
name="StringTool",
description="A tool that processes string input"
)
# Test with direct string input
result = tool._parse_args({"query": "test string"})
assert result["query"] == "test string"
assert isinstance(result["query"], str)
# Test with JSON string input
result = tool._parse_args('{"query": "json string"}')
assert result["query"] == "json string"
assert isinstance(result["query"], str)
def test_parse_args_with_nested_dict_for_string():
"""Test that nested dictionaries with 'value' field are handled correctly for string fields."""
def test_func(query: str) -> str:
return f"Processed: {query}"
tool = CrewStructuredTool.from_function(
func=test_func,
name="StringTool",
description="A tool that processes string input"
)
# Test with nested dict input (simulating the issue from different LLM providers)
nested_input = {"query": {"description": "A string input parameter", "value": "test value"}}
result = tool._parse_args(nested_input)
assert result["query"] == "test value"
assert isinstance(result["query"], str)
def test_parse_args_with_nested_dict_for_int():
"""Test that nested dictionaries with 'value' field are handled correctly for int fields."""
def test_func(number: int) -> str:
return f"Processed: {number}"
tool = CrewStructuredTool.from_function(
func=test_func,
name="IntTool",
description="A tool that processes integer input"
)
# Test with nested dict input for int field
nested_input = {"number": {"description": "An integer input parameter", "value": 42}}
result = tool._parse_args(nested_input)
assert result["number"] == 42
assert isinstance(result["number"], int)
def test_parse_args_with_complex_input():
"""Test that complex inputs with multiple fields are handled correctly."""
def test_func(text: str, number: int, flag: bool) -> str:
return f"Processed: {text}, {number}, {flag}"
tool = CrewStructuredTool.from_function(
func=test_func,
name="ComplexTool",
description="A tool that processes complex input"
)
# Test with mixed nested dict input
complex_input = {
"text": {"description": "A string parameter", "value": "test text"},
"number": 42,
"flag": True
}
result = tool._parse_args(complex_input)
assert result["text"] == "test text"
assert isinstance(result["text"], str)
assert result["number"] == 42
assert isinstance(result["number"], int)
assert result["flag"] is True
assert isinstance(result["flag"], bool)
def test_invoke_with_nested_dict():
"""Test that invoking a tool with nested dict input works correctly."""
def test_func(query: str) -> str:
return f"Processed: {query}"
tool = CrewStructuredTool.from_function(
func=test_func,
name="StringTool",
description="A tool that processes string input"
)
# Test invoking with nested dict input
nested_input = {"query": {"description": "A string input parameter", "value": "test value"}}
result = tool.invoke(nested_input)
assert result == "Processed: test value"
def test_nested_dict_without_value_key():
"""Test that nested dictionaries without 'value' field raise appropriate errors."""
def test_func(query: str) -> str:
return f"Processed: {query}"
tool = CrewStructuredTool.from_function(
func=test_func,
name="StringTool",
description="A tool that processes string input"
)
# Test with nested dict without 'value' key
invalid_input = {"query": {"description": "A string input parameter", "other_key": "test"}}
with pytest.raises(ValueError):
tool._parse_args(invalid_input)
def test_empty_nested_dict():
"""Test handling of empty nested dictionaries."""
def test_func(query: str) -> str:
return f"Processed: {query}"
tool = CrewStructuredTool.from_function(
func=test_func,
name="StringTool",
description="A tool that processes string input"
)
# Test with empty nested dict
empty_dict_input = {"query": {}}
with pytest.raises(ValueError):
tool._parse_args(empty_dict_input)
def test_deeply_nested_structure():
"""Test handling of deeply nested structures."""
def test_func(query: str) -> str:
return f"Processed: {query}"
tool = CrewStructuredTool.from_function(
func=test_func,
name="StringTool",
description="A tool that processes string input"
)
# Test with deeply nested structure
deeply_nested = {"query": {"nested": {"deeper": {"value": "deep value"}}}}
with pytest.raises(ValueError):
tool._parse_args(deeply_nested)