Add typed output schemas for tools

Tools can now declare an `output_schema`, set explicitly or inferred
from a Pydantic return annotation. `format_output_for_agent` validates
the raw result against it and serializes to JSON for the agent, while
`run` keeps returning the raw value. Falls back to `str(raw_result)`
with a warning when validation or serialization fails.
This commit is contained in:
Vinicius Brasil
2026-06-18 20:36:45 -07:00
parent 854c67d21c
commit ba7533ed9d
4 changed files with 450 additions and 3 deletions

View File

@@ -33,6 +33,8 @@ from typing_extensions import TypeIs
from crewai.tools.structured_tool import (
CrewStructuredTool,
_deserialize_schema,
_format_tool_output_for_agent,
_infer_output_schema_from_callable,
_serialize_schema,
build_schema_hint,
)
@@ -149,6 +151,11 @@ class BaseTool(BaseModel, ABC):
validate_default=True,
description="The schema for the arguments that the tool accepts.",
)
output_schema: type[PydanticBaseModel] | None = Field(
default=None,
validate_default=True,
description="The schema for the output that the tool returns.",
)
@field_serializer("args_schema", when_used="json")
def _serialize_args_schema(
@@ -156,6 +163,12 @@ class BaseTool(BaseModel, ABC):
) -> dict[str, Any] | None:
return _serialize_schema(schema)
@field_serializer("output_schema", when_used="json")
def _serialize_output_schema(
self, schema: type[PydanticBaseModel] | None
) -> dict[str, Any] | None:
return _serialize_schema(schema)
description_updated: bool = Field(
default=False, description="Flag to check if the description has been updated."
)
@@ -233,6 +246,17 @@ class BaseTool(BaseModel, ABC):
return create_model(f"{cls.__name__}Schema", **fields)
@field_validator("output_schema", mode="before")
@classmethod
def _default_output_schema(
cls, v: type[PydanticBaseModel] | dict[str, Any] | None
) -> type[PydanticBaseModel] | None:
if isinstance(v, dict):
return _deserialize_schema(v)
if v is not None:
return v
return _infer_output_schema_from_callable(cls._run)
@field_validator("max_usage_count", mode="before")
@classmethod
def validate_max_usage_count(cls, v: int | None) -> int | None:
@@ -340,6 +364,10 @@ class BaseTool(BaseModel, ABC):
"Override _arun for async support or use run() for sync execution."
)
def format_output_for_agent(self, raw_result: Any) -> str:
"""Format a raw tool result into the string representation sent to an agent."""
return _format_tool_output_for_agent(self, raw_result)
def reset_usage_count(self) -> None:
"""Reset the current usage count to zero."""
self.current_usage_count = 0
@@ -369,6 +397,7 @@ class BaseTool(BaseModel, ABC):
name=self.name,
description=self.description,
args_schema=self.args_schema,
output_schema=self.output_schema,
func=self._run,
result_as_answer=self.result_as_answer,
max_usage_count=self.max_usage_count,
@@ -390,6 +419,9 @@ class BaseTool(BaseModel, ABC):
raise ValueError("The provided tool must have a callable 'func' attribute.")
args_schema = getattr(tool, "args_schema", None)
output_schema = getattr(tool, "output_schema", None)
if output_schema is None:
output_schema = _infer_output_schema_from_callable(tool.func)
if args_schema is None:
func_signature = signature(tool.func)
@@ -420,6 +452,7 @@ class BaseTool(BaseModel, ABC):
description=getattr(tool, "description", ""),
func=tool.func,
args_schema=args_schema,
output_schema=output_schema,
)
def _set_args_schema(self) -> None:
@@ -568,6 +601,9 @@ class Tool(BaseTool, Generic[P, R]):
raise ValueError("The provided tool must have a callable 'func' attribute.")
args_schema = getattr(tool, "args_schema", None)
output_schema = getattr(tool, "output_schema", None)
if output_schema is None:
output_schema = _infer_output_schema_from_callable(tool.func)
if args_schema is None:
func_signature = signature(tool.func)
@@ -598,6 +634,7 @@ class Tool(BaseTool, Generic[P, R]):
description=getattr(tool, "description", ""),
func=tool.func,
args_schema=args_schema,
output_schema=output_schema,
)
@@ -621,6 +658,7 @@ def tool(
name: str,
/,
*,
output_schema: type[BaseModel] | None = ...,
result_as_answer: bool = ...,
max_usage_count: int | None = ...,
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
@@ -629,6 +667,7 @@ def tool(
@overload
def tool(
*,
output_schema: type[BaseModel] | None = ...,
result_as_answer: bool = ...,
max_usage_count: int | None = ...,
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
@@ -636,6 +675,7 @@ def tool(
def tool(
*args: Callable[P2, R2] | str,
output_schema: type[BaseModel] | None = None,
result_as_answer: bool = False,
max_usage_count: int | None = None,
) -> Tool[P2, R2] | Callable[[Callable[P2, R2]], Tool[P2, R2]]:
@@ -649,6 +689,7 @@ def tool(
Args:
*args: Either the function to decorate or a custom tool name.
result_as_answer: If True, the tool result becomes the final agent answer.
output_schema: Optional schema for the output that the tool returns.
max_usage_count: Maximum times this tool can be used. None means unlimited.
Returns:
@@ -690,12 +731,16 @@ def tool(
class_name = "".join(tool_name.split()).title()
args_schema = create_model(class_name, **fields)
resolved_output_schema = (
output_schema or _infer_output_schema_from_callable(f)
)
return Tool(
name=tool_name,
description=f.__doc__,
func=f,
args_schema=args_schema,
output_schema=resolved_output_schema,
result_as_answer=result_as_answer,
max_usage_count=max_usage_count,
current_usage_count=0,

View File

@@ -5,7 +5,8 @@ from collections.abc import Callable
import inspect
import json
import textwrap
from typing import TYPE_CHECKING, Annotated, Any, get_type_hints
from typing import TYPE_CHECKING, Annotated, Any, cast, get_type_hints
import warnings
from pydantic import (
BaseModel,
@@ -36,6 +37,47 @@ def _deserialize_schema(v: Any) -> type[BaseModel] | None:
return None
def _infer_output_schema_from_callable(
func: Callable[..., Any],
) -> type[BaseModel] | None:
try:
return_annotation = get_type_hints(func).get("return", inspect.Signature.empty)
except Exception:
return_annotation = inspect.signature(func).return_annotation
if isinstance(return_annotation, type) and issubclass(return_annotation, BaseModel):
return return_annotation
return None
def _format_tool_output_for_agent(tool: Any, raw_result: Any) -> str:
output_schema = getattr(tool, "output_schema", None)
if output_schema is None:
return str(raw_result)
try:
validation_input = raw_result
if isinstance(raw_result, BaseModel) and not isinstance(
raw_result, output_schema
):
validation_input = raw_result.model_dump()
validated = output_schema.model_validate(validation_input)
return cast(str, validated.model_dump_json())
except Exception as exc:
warnings.warn(
(
f"Failed to validate or serialize output from tool "
f"'{getattr(tool, 'name', '<unknown>')}' using output_schema "
f"'{output_schema.__name__}': {exc}. Falling back to str(raw_result)."
),
RuntimeWarning,
stacklevel=2,
)
return str(raw_result)
if TYPE_CHECKING:
pass
@@ -81,6 +123,11 @@ class CrewStructuredTool(BaseModel):
BeforeValidator(_deserialize_schema),
PlainSerializer(_serialize_schema),
] = Field(default=None)
output_schema: Annotated[
type[BaseModel] | None,
BeforeValidator(_deserialize_schema),
PlainSerializer(_serialize_schema),
] = Field(default=None)
func: Any = Field(default=None, exclude=True)
result_as_answer: bool = Field(default=False)
max_usage_count: int | None = Field(default=None)
@@ -103,6 +150,7 @@ class CrewStructuredTool(BaseModel):
description: str | None = None,
return_direct: bool = False,
args_schema: type[BaseModel] | None = None,
output_schema: type[BaseModel] | None = None,
infer_schema: bool = True,
**kwargs: Any,
) -> CrewStructuredTool:
@@ -114,6 +162,7 @@ class CrewStructuredTool(BaseModel):
description: The description of the tool. Defaults to the function docstring
return_direct: Whether to return the output directly
args_schema: Optional schema for the function arguments
output_schema: Optional schema for the function output
infer_schema: Whether to infer the schema from the function signature
**kwargs: Additional arguments to pass to the tool
@@ -149,10 +198,16 @@ class CrewStructuredTool(BaseModel):
name=name,
description=description,
args_schema=schema,
output_schema=output_schema or _infer_output_schema_from_callable(func),
func=func,
result_as_answer=return_direct,
**kwargs,
)
def format_output_for_agent(self, raw_result: Any) -> str:
"""Format a raw tool result into the string representation sent to an agent."""
return _format_tool_output_for_agent(self, raw_result)
@staticmethod
def _create_schema_from_function(
name: str,

View File

@@ -1,12 +1,13 @@
import asyncio
from collections.abc import Callable
import json
from unittest.mock import patch
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task
from crewai.tools import BaseTool, tool
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, RootModel
import pytest
@@ -351,6 +352,240 @@ class TestToolDecoratorRunValidation:
assert result == "Hello, World!"
class SearchOutput(BaseModel):
query: str
score: float
class SearchResults(RootModel[list[SearchOutput]]):
pass
class ExplicitSearchTool(BaseTool):
name: str = "search"
description: str = "Search for a query"
output_schema: type[BaseModel] = SearchOutput
def _run(self, query: str) -> dict[str, object]:
return {"query": query, "score": 0.8}
class InferredSearchTool(BaseTool):
name: str = "search"
description: str = "Search for a query"
def _run(self, query: str) -> SearchOutput:
return SearchOutput(query=query, score=0.7)
class RootSearchTool(BaseTool):
name: str = "search"
description: str = "Search for a query"
def _run(self, query: str) -> SearchResults:
return SearchResults([SearchOutput(query=query, score=1.0)])
class DictAnnotatedSearchTool(BaseTool):
name: str = "search"
description: str = "Search for a query"
def _run(self, query: str) -> dict[str, object]:
return {"query": query, "score": 0.5}
def _make_explicit_decorator_tool() -> BaseTool:
@tool("search", output_schema=SearchOutput)
def search(query: str) -> dict[str, object]:
"""Search for a query."""
return {"query": query, "score": 0.8}
return search
def _make_inferred_decorator_tool() -> BaseTool:
@tool("search")
def search(query: str) -> SearchOutput:
"""Search for a query."""
return SearchOutput(query=query, score=0.6)
return search
def _make_root_decorator_tool() -> BaseTool:
@tool("search")
def search(query: str) -> SearchResults:
"""Search for a query."""
return SearchResults([SearchOutput(query=query, score=1.0)])
return search
class TestToolOutputSchema:
"""Tests for typed tool output behavior."""
@pytest.mark.parametrize(
("tool_cls", "expected_raw", "expected_agent_payload"),
[
pytest.param(
ExplicitSearchTool,
{"query": "crew", "score": 0.8},
{"query": "crew", "score": 0.8},
id="explicit-schema",
),
pytest.param(
InferredSearchTool,
SearchOutput(query="crew", score=0.7),
{"query": "crew", "score": 0.7},
id="inferred-base-model",
),
pytest.param(
RootSearchTool,
SearchResults([SearchOutput(query="crew", score=1.0)]),
[{"query": "crew", "score": 1.0}],
id="inferred-root-model",
),
],
)
def test_base_tools_return_raw_result_and_json_agent_text(
self,
tool_cls: type[BaseTool],
expected_raw: object,
expected_agent_payload: object,
) -> None:
t = tool_cls()
raw_result = t.run(query="crew")
assert raw_result == expected_raw
assert json.loads(t.format_output_for_agent(raw_result)) == (
expected_agent_payload
)
def test_base_tool_does_not_infer_non_pydantic_return_annotation(self) -> None:
t = DictAnnotatedSearchTool()
raw_result = t.run(query="crew")
assert raw_result == {"query": "crew", "score": 0.5}
assert t.format_output_for_agent(raw_result) == str(raw_result)
@pytest.mark.parametrize(
("make_tool", "expected_raw", "expected_agent_payload"),
[
pytest.param(
_make_explicit_decorator_tool,
{"query": "crew", "score": 0.8},
{"query": "crew", "score": 0.8},
id="explicit-schema",
),
pytest.param(
_make_inferred_decorator_tool,
SearchOutput(query="crew", score=0.6),
{"query": "crew", "score": 0.6},
id="inferred-base-model",
),
pytest.param(
_make_root_decorator_tool,
SearchResults([SearchOutput(query="crew", score=1.0)]),
[{"query": "crew", "score": 1.0}],
id="inferred-root-model",
),
],
)
def test_decorator_tools_return_raw_result_and_json_agent_text(
self,
make_tool: Callable[[], BaseTool],
expected_raw: object,
expected_agent_payload: object,
) -> None:
search = make_tool()
raw_result = search.run(query="crew")
assert raw_result == expected_raw
assert json.loads(search.format_output_for_agent(raw_result)) == (
expected_agent_payload
)
def test_decorator_tool_does_not_infer_non_pydantic_return_annotation(
self,
) -> None:
@tool("search")
def search(query: str) -> dict[str, object]:
"""Search for a query."""
return {"query": query, "score": 0.5}
raw_result = search.run(query="crew")
assert raw_result == {"query": "crew", "score": 0.5}
assert search.format_output_for_agent(raw_result) == str(raw_result)
def test_explicit_output_schema_wins_over_return_annotation(self) -> None:
class AlternateOutput(BaseModel):
value: str
@tool("search", output_schema=AlternateOutput)
def search(query: str) -> SearchOutput:
"""Search for a query."""
return SearchOutput(query=query, score=0.6)
raw_result = search.run(query="crew")
with pytest.warns(RuntimeWarning, match="AlternateOutput"):
agent_text = search.format_output_for_agent(raw_result)
assert raw_result == SearchOutput(query="crew", score=0.6)
assert agent_text == str(raw_result)
def test_invalid_typed_output_warns_and_uses_string_agent_text(
self,
) -> None:
@tool("search", output_schema=SearchOutput)
def search(query: str) -> dict[str, object]:
"""Search for a query."""
return {"query": query, "score": "not-a-float"}
raw_result = search.run(query="crew")
with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"):
agent_text = search.format_output_for_agent(raw_result)
assert raw_result == {"query": "crew", "score": "not-a-float"}
assert agent_text == str(raw_result)
def test_unserializable_typed_output_warns_and_uses_string_agent_text(
self,
) -> None:
class OpaqueOutput(BaseModel):
value: object
raw_result = OpaqueOutput(value=object())
@tool("opaque", output_schema=OpaqueOutput)
def opaque() -> OpaqueOutput:
"""Return an opaque object."""
return raw_result
result = opaque.run()
with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"):
agent_text = opaque.format_output_for_agent(result)
assert result is raw_result
assert agent_text == str(raw_result)
def test_output_schema_behavior_carries_over_to_structured_tool(self) -> None:
structured = ExplicitSearchTool().to_structured_tool()
raw_result = structured.invoke({"query": "crew"})
assert raw_result == {"query": "crew", "score": 0.8}
assert json.loads(structured.format_output_for_agent(raw_result)) == {
"query": "crew",
"score": 0.8,
}
# Async arun() Schema Validation Tests

View File

@@ -1,5 +1,7 @@
import json
from crewai.tools.structured_tool import CrewStructuredTool
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, RootModel
import pytest
@@ -86,6 +88,116 @@ def test_from_function(basic_function):
assert isinstance(tool.args_schema, type(BaseModel))
class StructuredOutput(BaseModel):
value: str
count: int
class StructuredOutputList(RootModel[list[StructuredOutput]]):
pass
def _build_explicit_structured_value(value: str) -> dict[str, object]:
"""Build a value."""
return {"value": value, "count": 1}
def _build_inferred_structured_value(value: str) -> StructuredOutput:
"""Build a value."""
return StructuredOutput(value=value, count=1)
def _build_structured_values(value: str) -> StructuredOutputList:
"""Build values."""
return StructuredOutputList([StructuredOutput(value=value, count=1)])
def _build_plain_structured_value(value: str) -> dict[str, object]:
"""Build a value."""
return {"value": value, "count": 1}
@pytest.mark.parametrize(
("func", "output_schema", "expected_raw", "expected_agent_payload"),
[
pytest.param(
_build_explicit_structured_value,
StructuredOutput,
{"value": "crew", "count": 1},
{"value": "crew", "count": 1},
id="explicit-schema",
),
pytest.param(
_build_inferred_structured_value,
None,
StructuredOutput(value="crew", count=1),
{"value": "crew", "count": 1},
id="inferred-base-model",
),
pytest.param(
_build_structured_values,
None,
StructuredOutputList([StructuredOutput(value="crew", count=1)]),
[{"value": "crew", "count": 1}],
id="inferred-root-model",
),
],
)
def test_from_function_returns_raw_result_and_json_agent_text(
func,
output_schema,
expected_raw,
expected_agent_payload,
):
"""Typed structured tools return raw values and format JSON for the agent."""
kwargs = {"output_schema": output_schema} if output_schema is not None else {}
tool = CrewStructuredTool.from_function(
func=func,
name="build_value",
**kwargs,
)
raw_result = tool.invoke({"value": "crew"})
assert raw_result == expected_raw
assert json.loads(tool.format_output_for_agent(raw_result)) == (
expected_agent_payload
)
def test_from_function_does_not_infer_non_pydantic_output_schema():
"""Non-Pydantic return annotations use the plain string formatter."""
tool = CrewStructuredTool.from_function(
func=_build_plain_structured_value,
name="build_value",
)
raw_result = tool.invoke({"value": "crew"})
assert raw_result == {"value": "crew", "count": 1}
assert tool.format_output_for_agent(raw_result) == str(raw_result)
def test_invalid_typed_output_warns_and_uses_string_agent_text():
"""Invalid structured output leaves the raw result unchanged."""
def build_value(value: str) -> dict[str, object]:
"""Build a value."""
return {"value": value, "count": "wrong"}
tool = CrewStructuredTool.from_function(
func=build_value,
name="build_value",
output_schema=StructuredOutput,
)
raw_result = tool.invoke({"value": "crew"})
with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"):
agent_text = tool.format_output_for_agent(raw_result)
assert raw_result == {"value": "crew", "count": "wrong"}
assert agent_text == str(raw_result)
def test_validate_function_signature(basic_function, schema_class):
"""Test function signature validation"""
tool = CrewStructuredTool(