mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-01 05:08:12 +00:00
Some checks failed
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
Nightly Canary Release / Check for new commits (push) Has been cancelled
Nightly Canary Release / Build nightly packages (push) Has been cancelled
Nightly Canary Release / Publish nightly to PyPI (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Currently, tools have a strong input contract through `args_schema`, but no
output contract. This means that anything a tool outputs is converted to a
string.
Not only is the contract weak, but the "invisible" conversion to a string can
also have unexpected effects when a tool returns complex objects such as dicts
and lists.
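With this PR, a tool can _optionally_ define an output contract with
`output_schema`. CrewAI validates the raw result against that schema and sends
JSON to the agent. The schema can also be inferred from a Pydantic return
annotation, as in this example: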
```python
class ProductResult(BaseModel):
    sku: str
    name: str
    in_stock: bool


class ProductLookupTool(BaseTool):
    name: str = "Product Lookup"
    description: str = "Look up product availability by SKU."

    def _run(self, sku: str) -> ProductResult:
        return ProductResult(sku=sku, name="USB-C dock", in_stock=True)
```
If the result does not match the schema, CrewAI warns and falls back to
`str(raw_result)` instead of failing the run:
```python
@tool("Bad Product Lookup", output_schema=ProductResult)
def product_lookup(sku: str) -> dict[str, object]:
    """Look up product availability by SKU."""
    return {"sku": sku, "in_stock": True}  # missing the required "name" field
#=> RuntimeWarning: Failed to validate or serialize output from tool 'Bad Product Lookup' using output_schema 'ProductResult'... Falling back to str(raw_result).
```
This is additive and non-breaking. Existing tools do not need to change. Tools
without `output_schema` keep the old string behavior. Invalid typed outputs
warn and fall back to the old formatting path.
522 lines
15 KiB
Python
522 lines
15 KiB
Python
import json
|
|
|
|
from crewai.tools.structured_tool import CrewStructuredTool
|
|
from pydantic import BaseModel, Field, RootModel
|
|
import pytest
|
|
|
|
|
|
@pytest.fixture
def basic_function():
    """Provide a plain two-parameter function that tools can be built from."""

    def test_func(param1: str, param2: int = 0) -> str:
        """Test function with basic params."""
        return "{} {}".format(param1, param2)

    return test_func
|
|
|
|
|
|
@pytest.fixture
def schema_class():
    """Provide an args schema that mirrors ``basic_function``'s signature."""

    class TestSchema(BaseModel):
        param1: str
        param2: int = Field(0)

    return TestSchema
|
|
|
|
|
|
def test_initialization(basic_function, schema_class):
    """Every constructor argument is stored unchanged on the tool."""
    constructor_kwargs = {
        "name": "test_tool",
        "description": "Test tool description",
        "func": basic_function,
        "args_schema": schema_class,
    }
    tool = CrewStructuredTool(**constructor_kwargs)

    for attribute, expected in constructor_kwargs.items():
        assert getattr(tool, attribute) == expected, attribute
|
|
|
|
|
|
def test_cache_function_passed_through(basic_function, schema_class):
    """A cache_function given to the constructor is kept by identity."""

    def no_cache(_args: dict, _result: str) -> bool:
        return False

    structured_tool = CrewStructuredTool(
        name="test_tool",
        description="Test tool description",
        func=basic_function,
        args_schema=schema_class,
        cache_function=no_cache,
    )

    assert structured_tool.cache_function is no_cache
|
|
|
|
|
|
def test_base_tool_passes_cache_function_to_structured_tool():
    """BaseTool.to_structured_tool() carries the tool's cache_function over."""
    from crewai.tools import BaseTool

    def never_cache(_args: dict, _result: str) -> bool:
        return False

    class MyCacheTool(BaseTool):
        name: str = "cache_test"
        description: str = "tool for testing cache passthrough"

        def _run(self, query: str = "") -> str:
            return "result"

    base_tool = MyCacheTool()
    base_tool.cache_function = never_cache  # type: ignore[assignment]

    assert base_tool.to_structured_tool().cache_function is never_cache
|
|
|
|
|
|
def test_from_function(basic_function):
    """from_function keeps name/description/func and builds a pydantic args schema."""
    tool = CrewStructuredTool.from_function(
        func=basic_function, name="test_tool", description="Test description"
    )

    assert tool.name == "test_tool"
    assert tool.description == "Test description"
    assert tool.func == basic_function
    # The generated schema must be a pydantic model *class*. Checking
    # isinstance against type(BaseModel) (the metaclass) hid that intent.
    assert isinstance(tool.args_schema, type)
    assert issubclass(tool.args_schema, BaseModel)
    # ...and it must describe the wrapped function's parameters.
    assert {"param1", "param2"} <= set(tool.args_schema.model_fields)
|
|
|
|
|
|
class StructuredOutput(BaseModel):
    # Typed tool result used by the structured-output tests. Field order
    # matters: it fixes the key order of the JSON sent to the agent.
    value: str
    count: int
|
|
|
|
|
|
class StructuredOutputList(RootModel[list[StructuredOutput]]):
    # Root model whose payload is a bare list of StructuredOutput items.
    ...
|
|
|
|
|
|
def _build_explicit_structured_value(value: str) -> dict[str, object]:
|
|
"""Build a value."""
|
|
return {"value": value, "count": 1}
|
|
|
|
|
|
def _build_inferred_structured_value(value: str) -> StructuredOutput:
    """Build a value."""
    return StructuredOutput(count=1, value=value)
|
|
|
|
|
|
def _build_structured_values(value: str) -> StructuredOutputList:
    """Build values."""
    single_item = StructuredOutput(value=value, count=1)
    return StructuredOutputList([single_item])
|
|
|
|
|
|
def _build_plain_structured_value(value: str) -> dict[str, object]:
|
|
"""Build a value."""
|
|
return {"value": value, "count": 1}
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("func", "result_schema", "expected_raw", "expected_agent_payload"),
    [
        pytest.param(
            _build_explicit_structured_value,
            StructuredOutput,
            {"value": "crew", "count": 1},
            {"value": "crew", "count": 1},
            id="explicit-schema",
        ),
        pytest.param(
            _build_inferred_structured_value,
            None,
            StructuredOutput(value="crew", count=1),
            {"value": "crew", "count": 1},
            id="inferred-base-model",
        ),
        pytest.param(
            _build_structured_values,
            None,
            StructuredOutputList([StructuredOutput(value="crew", count=1)]),
            [{"value": "crew", "count": 1}],
            id="inferred-root-model",
        ),
    ],
)
def test_from_function_returns_raw_result_and_json_agent_text(
    func,
    result_schema,
    expected_raw,
    expected_agent_payload,
):
    """invoke() returns the untouched result, while the agent receives it as JSON."""
    # Only forward result_schema when the case sets one explicitly, so the
    # inferred cases exercise from_function's own inference.
    schema_kwargs = {} if result_schema is None else {"result_schema": result_schema}
    tool = CrewStructuredTool.from_function(
        func=func, name="build_value", **schema_kwargs
    )

    raw_result = tool.invoke({"value": "crew"})
    assert raw_result == expected_raw

    agent_payload = json.loads(tool.format_output_for_agent(raw_result))
    assert agent_payload == expected_agent_payload
|
|
|
|
|
|
def test_from_function_does_not_infer_non_pydantic_result_schema():
    """A plain dict return annotation must not be treated as a result schema."""
    tool = CrewStructuredTool.from_function(
        func=_build_plain_structured_value, name="build_value"
    )

    raw_result = tool.invoke({"value": "crew"})

    assert raw_result == {"value": "crew", "count": 1}
    # With no schema, the agent gets the legacy str() rendering.
    assert tool.format_output_for_agent(raw_result) == str(raw_result)
|
|
|
|
|
|
def test_invalid_typed_output_warns_and_uses_string_agent_text():
    """An output that fails result_schema validation warns and falls back to str().

    The fallback warning must name the validation error type without echoing
    the offending value back.
    """

    def build_value(value: str) -> dict[str, object]:
        """Build a value."""
        return {"value": value, "count": "wrong"}

    tool = CrewStructuredTool.from_function(
        func=build_value,
        name="build_value",
        result_schema=StructuredOutput,
    )
    raw_result = tool.invoke({"value": "crew"})

    # Named `record` so it does not shadow the stdlib `warnings` module.
    with pytest.warns(
        RuntimeWarning, match="Failed to validate or serialize"
    ) as record:
        agent_text = tool.format_output_for_agent(raw_result)

    assert raw_result == {"value": "crew", "count": "wrong"}
    assert agent_text == str(raw_result)

    # pytest.warns records *every* warning raised inside the block, so select
    # the fallback warning(s) explicitly instead of assuming index 0.
    fallback_messages = [
        str(entry.message)
        for entry in record
        if issubclass(entry.category, RuntimeWarning)
        and "Failed to validate or serialize" in str(entry.message)
    ]
    assert fallback_messages
    assert "ValidationError" in fallback_messages[0]
    assert all("wrong" not in message for message in fallback_messages)
|
|
|
|
|
|
def test_validate_function_signature(basic_function, schema_class):
    """A function whose signature matches its schema validates without raising."""
    CrewStructuredTool(
        name="test_tool",
        description="Test tool",
        func=basic_function,
        args_schema=schema_class,
    )._validate_function_signature()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ainvoke(basic_function):
    """ainvoke() runs the function and applies param2's default."""
    tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")

    assert await tool.ainvoke(input={"param1": "test"}) == "test 0"
|
|
|
|
|
|
def test_parse_args_dict(basic_function):
    """_parse_args accepts a dict and returns the validated values."""
    tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")

    parsed = tool._parse_args({"param1": "test", "param2": 42})

    assert (parsed["param1"], parsed["param2"]) == ("test", 42)
|
|
|
|
|
|
def test_parse_args_string(basic_function):
    """_parse_args accepts a JSON string and returns the validated values."""
    tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")

    parsed = tool._parse_args('{"param1": "test", "param2": 42}')

    assert (parsed["param1"], parsed["param2"]) == ("test", 42)
|
|
|
|
|
|
def test_complex_types():
    """dict and list parameters are passed through to the function intact."""

    def complex_func(nested: dict, items: list) -> str:
        """Process complex types."""
        return f"Processed {len(items)} items with {len(nested)} nested keys"

    tool = CrewStructuredTool.from_function(
        func=complex_func, name="test_tool", description="Test complex types"
    )
    arguments = {"nested": {"key": "value"}, "items": [1, 2, 3]}

    assert tool.invoke(arguments) == "Processed 3 items with 1 nested keys"
|
|
|
|
|
|
def test_schema_inheritance():
    """Fields inherited from a parent schema are accepted by invoke()."""

    class BaseSchema(BaseModel):
        base_param: str

    class ExtendedSchema(BaseSchema):
        extra_param: int

    def extended_func(base_param: str, extra_param: int) -> str:
        """Test function with inherited schema."""
        return f"{base_param} {extra_param}"

    tool = CrewStructuredTool.from_function(
        func=extended_func, name="test_tool", args_schema=ExtendedSchema
    )

    assert tool.invoke({"base_param": "test", "extra_param": 42}) == "test 42"
|
|
|
|
|
|
def test_default_values_in_schema():
    """Omitted optional params fall back to defaults; supplied ones override them."""

    def default_func(
        required_param: str,
        optional_param: str = "default",
        nullable_param: int | None = None,
    ) -> str:
        """Test function with default values."""
        return f"{required_param} {optional_param} {nullable_param}"

    tool = CrewStructuredTool.from_function(
        func=default_func, name="test_tool", description="Test defaults"
    )

    cases = [
        ({"required_param": "test"}, "test default None"),
        (
            {
                "required_param": "test",
                "optional_param": "custom",
                "nullable_param": 42,
            },
            "test custom 42",
        ),
    ]
    for arguments, expected in cases:
        assert tool.invoke(arguments) == expected
|
|
|
|
|
|
@pytest.fixture
def custom_tool_decorator():
    """Provide an async tool built with the @tool decorator (result_as_answer)."""
    from crewai.tools import tool as make_tool

    @make_tool("custom_tool", result_as_answer=True)
    async def custom_tool():
        """This is a tool that does something"""
        return "Hello World from Custom Tool"

    return custom_tool
|
|
|
|
|
|
@pytest.fixture
def custom_tool():
    """Provide a BaseTool subclass instance with an async _run (result_as_answer)."""
    from crewai.tools import BaseTool

    class CustomTool(BaseTool):
        name: str = "my_tool"
        description: str = "This is a tool that does something"
        result_as_answer: bool = True

        async def _run(self):
            return "Hello World from Custom Tool"

    instance = CustomTool()
    return instance
|
|
|
|
|
|
def build_simple_crew(tool):
    """Build a one-agent, one-task crew whose agent is given ``tool``."""
    from crewai import Agent, Crew, Task

    agent = Agent(
        role="Simple role",
        goal="Simple goal",
        backstory="Simple backstory",
        tools=[tool],
    )
    task = Task(
        description="Use the custom tool result as answer.",
        agent=agent,
        expected_output="Use the tool result",
    )
    return Crew(agents=[agent], tasks=[task])
|
|
|
|
|
|
@pytest.mark.vcr()
def test_async_tool_using_within_isolated_crew(custom_tool):
    """An async BaseTool with result_as_answer becomes the crew's raw output."""
    crew_output = build_simple_crew(custom_tool).kickoff()

    assert crew_output.raw == "Hello World from Custom Tool"
|
|
|
|
|
|
@pytest.mark.vcr()
def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
    """An async @tool function with result_as_answer becomes the crew's raw output."""
    crew_output = build_simple_crew(custom_tool_decorator).kickoff()

    assert crew_output.raw == "Hello World from Custom Tool"
|
|
|
|
|
|
@pytest.mark.vcr()
def test_async_tool_within_flow(custom_tool):
    """An async BaseTool works in a crew kicked off from an async Flow step."""
    from crewai.flow.flow import Flow

    class StructuredExampleFlow(Flow):
        # Imported inside the class body on purpose: the method below is also
        # named `start`, which makes `start` a class-local name, so the
        # decorator has to be bound in the class namespace for `@start()`.
        from crewai.flow.flow import start

        @start()
        async def start(self):
            simple_crew = build_simple_crew(custom_tool)
            return await simple_crew.kickoff_async()

    flow_output = StructuredExampleFlow().kickoff()
    assert flow_output.raw == "Hello World from Custom Tool"
|
|
|
|
|
|
@pytest.mark.vcr()
def test_async_tool_using_decorator_within_flow(custom_tool_decorator):
    """An async @tool function works in a crew kicked off from an async Flow step."""
    from crewai.flow.flow import Flow

    class StructuredExampleFlow(Flow):
        # Imported inside the class body on purpose: the method below is also
        # named `start`, which makes `start` a class-local name, so the
        # decorator has to be bound in the class namespace for `@start()`.
        from crewai.flow.flow import start

        @start()
        async def start(self):
            simple_crew = build_simple_crew(custom_tool_decorator)
            return await simple_crew.kickoff_async()

    flow_output = StructuredExampleFlow().kickoff()
    assert flow_output.raw == "Hello World from Custom Tool"
|
|
|
|
|
|
def test_structured_tool_invoke_calls_func_only_once():
    """Test that CrewStructuredTool.invoke() calls the underlying function exactly once."""
    call_history: list[str] = []

    def counting_function(param: str) -> str:
        """Function that tracks how many times it's called."""
        # The history length doubles as the call counter.
        call_number = len(call_history) + 1
        call_history.append(f"Call #{call_number} with param: {param}")
        return f"Result from call #{call_number}: {param}"

    tool = CrewStructuredTool.from_function(
        func=counting_function,
        name="direct_test_tool",
        description="Tool to test direct invoke() method",
    )

    result = tool.invoke({"param": "test_value"})
    call_count = len(call_history)

    # Critical assertions that would catch the duplicate execution bug
    assert call_count == 1, (
        f"DUPLICATE EXECUTION BUG: Function was called {call_count} times instead of 1. "
        f"This means CrewStructuredTool.invoke() has duplicate function calls. "
        f"Call history: {call_history}"
    )
    assert len(call_history) == 1, (
        f"Expected 1 call in history, got {len(call_history)}: {call_history}"
    )
    assert call_history[0] == "Call #1 with param: test_value", (
        f"Expected 'Call #1 with param: test_value', got: {call_history[0]}"
    )
    assert result == "Result from call #1: test_value", (
        f"Expected result from first call, got: {result}"
    )
|
|
|
|
|
|
def test_structured_tool_invoke_multiple_calls_increment_correctly():
    """Each invoke() advances the call counter by exactly one."""
    call_count = 0

    def incrementing_function(value: int) -> int:
        nonlocal call_count
        call_count += 1
        return value + call_count

    tool = CrewStructuredTool.from_function(
        func=incrementing_function,
        name="incrementing_tool",
        description="Tool that increments on each call",
    )

    rounds = [("first", 10, 11), ("second", 20, 22), ("third", 30, 33)]
    for ordinal, (label, value, expected) in enumerate(rounds, start=1):
        result = tool.invoke({"value": value})
        assert call_count == ordinal, (
            f"After {label} invoke, expected call_count={ordinal}, got {call_count}"
        )
        assert result == expected, (
            f"Expected {expected} ({value}+{ordinal}), got {result}"
        )
|
|
|
|
|
|
def test_structured_tool_invoke_with_side_effects():
    """A single invoke() produces exactly one observable side effect."""
    side_effects: list[str] = []

    def side_effect_function(action: str) -> str:
        side_effects.append(f"SIDE_EFFECT: {action} executed at call")
        return f"Action {action} completed"

    tool = CrewStructuredTool.from_function(
        func=side_effect_function,
        name="side_effect_tool",
        description="Tool with observable side effects",
    )

    outcome = tool.invoke({"action": "write_file"})
    effect_total = len(side_effects)

    assert effect_total == 1, (
        f"SIDE EFFECT BUG: Expected 1 side effect, got {effect_total}. "
        f"This indicates the function was called multiple times. "
        f"Side effects: {side_effects}"
    )
    assert side_effects[0] == "SIDE_EFFECT: write_file executed at call"
    assert outcome == "Action write_file completed"
|
|
|
|
|
|
def test_structured_tool_invoke_exception_handling():
    """A raising function is still called only once per invoke()."""
    call_count = 0

    def failing_function(should_fail: bool) -> str:
        nonlocal call_count
        call_count += 1
        if not should_fail:
            return f"Success on call #{call_count}"
        raise ValueError(f"Intentional failure on call #{call_count}")

    tool = CrewStructuredTool.from_function(
        func=failing_function, name="failing_tool", description="Tool that can fail"
    )

    outcome = tool.invoke({"should_fail": False})
    assert call_count == 1, f"Expected 1 call for success case, got {call_count}"
    assert outcome == "Success on call #1"

    # Reset so the failure message below is expected to report call #1 again.
    call_count = 0
    with pytest.raises(ValueError, match="Intentional failure on call #1"):
        tool.invoke({"should_fail": True})
    assert call_count == 1
|