mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 13:48:09 +00:00
Add typed output schemas for tools
Tools can now declare an `output_schema`, set explicitly or inferred from a Pydantic return annotation. `format_output_for_agent` validates the raw result against it and serializes to JSON for the agent, while `run` keeps returning the raw value. Falls back to `str(raw_result)` with a warning when validation or serialization fails.
This commit is contained in:
@@ -33,6 +33,8 @@ from typing_extensions import TypeIs
|
||||
from crewai.tools.structured_tool import (
|
||||
CrewStructuredTool,
|
||||
_deserialize_schema,
|
||||
_format_tool_output_for_agent,
|
||||
_infer_output_schema_from_callable,
|
||||
_serialize_schema,
|
||||
build_schema_hint,
|
||||
)
|
||||
@@ -149,6 +151,11 @@ class BaseTool(BaseModel, ABC):
|
||||
validate_default=True,
|
||||
description="The schema for the arguments that the tool accepts.",
|
||||
)
|
||||
output_schema: type[PydanticBaseModel] | None = Field(
|
||||
default=None,
|
||||
validate_default=True,
|
||||
description="The schema for the output that the tool returns.",
|
||||
)
|
||||
|
||||
@field_serializer("args_schema", when_used="json")
|
||||
def _serialize_args_schema(
|
||||
@@ -156,6 +163,12 @@ class BaseTool(BaseModel, ABC):
|
||||
) -> dict[str, Any] | None:
|
||||
return _serialize_schema(schema)
|
||||
|
||||
@field_serializer("output_schema", when_used="json")
|
||||
def _serialize_output_schema(
|
||||
self, schema: type[PydanticBaseModel] | None
|
||||
) -> dict[str, Any] | None:
|
||||
return _serialize_schema(schema)
|
||||
|
||||
description_updated: bool = Field(
|
||||
default=False, description="Flag to check if the description has been updated."
|
||||
)
|
||||
@@ -233,6 +246,17 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
return create_model(f"{cls.__name__}Schema", **fields)
|
||||
|
||||
@field_validator("output_schema", mode="before")
|
||||
@classmethod
|
||||
def _default_output_schema(
|
||||
cls, v: type[PydanticBaseModel] | dict[str, Any] | None
|
||||
) -> type[PydanticBaseModel] | None:
|
||||
if isinstance(v, dict):
|
||||
return _deserialize_schema(v)
|
||||
if v is not None:
|
||||
return v
|
||||
return _infer_output_schema_from_callable(cls._run)
|
||||
|
||||
@field_validator("max_usage_count", mode="before")
|
||||
@classmethod
|
||||
def validate_max_usage_count(cls, v: int | None) -> int | None:
|
||||
@@ -340,6 +364,10 @@ class BaseTool(BaseModel, ABC):
|
||||
"Override _arun for async support or use run() for sync execution."
|
||||
)
|
||||
|
||||
def format_output_for_agent(self, raw_result: Any) -> str:
|
||||
"""Format a raw tool result into the string representation sent to an agent."""
|
||||
return _format_tool_output_for_agent(self, raw_result)
|
||||
|
||||
def reset_usage_count(self) -> None:
|
||||
"""Reset the current usage count to zero."""
|
||||
self.current_usage_count = 0
|
||||
@@ -369,6 +397,7 @@ class BaseTool(BaseModel, ABC):
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
args_schema=self.args_schema,
|
||||
output_schema=self.output_schema,
|
||||
func=self._run,
|
||||
result_as_answer=self.result_as_answer,
|
||||
max_usage_count=self.max_usage_count,
|
||||
@@ -390,6 +419,9 @@ class BaseTool(BaseModel, ABC):
|
||||
raise ValueError("The provided tool must have a callable 'func' attribute.")
|
||||
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
output_schema = getattr(tool, "output_schema", None)
|
||||
if output_schema is None:
|
||||
output_schema = _infer_output_schema_from_callable(tool.func)
|
||||
|
||||
if args_schema is None:
|
||||
func_signature = signature(tool.func)
|
||||
@@ -420,6 +452,7 @@ class BaseTool(BaseModel, ABC):
|
||||
description=getattr(tool, "description", ""),
|
||||
func=tool.func,
|
||||
args_schema=args_schema,
|
||||
output_schema=output_schema,
|
||||
)
|
||||
|
||||
def _set_args_schema(self) -> None:
|
||||
@@ -568,6 +601,9 @@ class Tool(BaseTool, Generic[P, R]):
|
||||
raise ValueError("The provided tool must have a callable 'func' attribute.")
|
||||
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
output_schema = getattr(tool, "output_schema", None)
|
||||
if output_schema is None:
|
||||
output_schema = _infer_output_schema_from_callable(tool.func)
|
||||
|
||||
if args_schema is None:
|
||||
func_signature = signature(tool.func)
|
||||
@@ -598,6 +634,7 @@ class Tool(BaseTool, Generic[P, R]):
|
||||
description=getattr(tool, "description", ""),
|
||||
func=tool.func,
|
||||
args_schema=args_schema,
|
||||
output_schema=output_schema,
|
||||
)
|
||||
|
||||
|
||||
@@ -621,6 +658,7 @@ def tool(
|
||||
name: str,
|
||||
/,
|
||||
*,
|
||||
output_schema: type[BaseModel] | None = ...,
|
||||
result_as_answer: bool = ...,
|
||||
max_usage_count: int | None = ...,
|
||||
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
|
||||
@@ -629,6 +667,7 @@ def tool(
|
||||
@overload
|
||||
def tool(
|
||||
*,
|
||||
output_schema: type[BaseModel] | None = ...,
|
||||
result_as_answer: bool = ...,
|
||||
max_usage_count: int | None = ...,
|
||||
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
|
||||
@@ -636,6 +675,7 @@ def tool(
|
||||
|
||||
def tool(
|
||||
*args: Callable[P2, R2] | str,
|
||||
output_schema: type[BaseModel] | None = None,
|
||||
result_as_answer: bool = False,
|
||||
max_usage_count: int | None = None,
|
||||
) -> Tool[P2, R2] | Callable[[Callable[P2, R2]], Tool[P2, R2]]:
|
||||
@@ -649,6 +689,7 @@ def tool(
|
||||
Args:
|
||||
*args: Either the function to decorate or a custom tool name.
|
||||
result_as_answer: If True, the tool result becomes the final agent answer.
|
||||
output_schema: Optional schema for the output that the tool returns.
|
||||
max_usage_count: Maximum times this tool can be used. None means unlimited.
|
||||
|
||||
Returns:
|
||||
@@ -690,12 +731,16 @@ def tool(
|
||||
|
||||
class_name = "".join(tool_name.split()).title()
|
||||
args_schema = create_model(class_name, **fields)
|
||||
resolved_output_schema = (
|
||||
output_schema or _infer_output_schema_from_callable(f)
|
||||
)
|
||||
|
||||
return Tool(
|
||||
name=tool_name,
|
||||
description=f.__doc__,
|
||||
func=f,
|
||||
args_schema=args_schema,
|
||||
output_schema=resolved_output_schema,
|
||||
result_as_answer=result_as_answer,
|
||||
max_usage_count=max_usage_count,
|
||||
current_usage_count=0,
|
||||
|
||||
@@ -5,7 +5,8 @@ from collections.abc import Callable
|
||||
import inspect
|
||||
import json
|
||||
import textwrap
|
||||
from typing import TYPE_CHECKING, Annotated, Any, get_type_hints
|
||||
from typing import TYPE_CHECKING, Annotated, Any, cast, get_type_hints
|
||||
import warnings
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -36,6 +37,47 @@ def _deserialize_schema(v: Any) -> type[BaseModel] | None:
|
||||
return None
|
||||
|
||||
|
||||
def _infer_output_schema_from_callable(
|
||||
func: Callable[..., Any],
|
||||
) -> type[BaseModel] | None:
|
||||
try:
|
||||
return_annotation = get_type_hints(func).get("return", inspect.Signature.empty)
|
||||
except Exception:
|
||||
return_annotation = inspect.signature(func).return_annotation
|
||||
|
||||
if isinstance(return_annotation, type) and issubclass(return_annotation, BaseModel):
|
||||
return return_annotation
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _format_tool_output_for_agent(tool: Any, raw_result: Any) -> str:
|
||||
output_schema = getattr(tool, "output_schema", None)
|
||||
if output_schema is None:
|
||||
return str(raw_result)
|
||||
|
||||
try:
|
||||
validation_input = raw_result
|
||||
if isinstance(raw_result, BaseModel) and not isinstance(
|
||||
raw_result, output_schema
|
||||
):
|
||||
validation_input = raw_result.model_dump()
|
||||
|
||||
validated = output_schema.model_validate(validation_input)
|
||||
return cast(str, validated.model_dump_json())
|
||||
except Exception as exc:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Failed to validate or serialize output from tool "
|
||||
f"'{getattr(tool, 'name', '<unknown>')}' using output_schema "
|
||||
f"'{output_schema.__name__}': {exc}. Falling back to str(raw_result)."
|
||||
),
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return str(raw_result)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
@@ -81,6 +123,11 @@ class CrewStructuredTool(BaseModel):
|
||||
BeforeValidator(_deserialize_schema),
|
||||
PlainSerializer(_serialize_schema),
|
||||
] = Field(default=None)
|
||||
output_schema: Annotated[
|
||||
type[BaseModel] | None,
|
||||
BeforeValidator(_deserialize_schema),
|
||||
PlainSerializer(_serialize_schema),
|
||||
] = Field(default=None)
|
||||
func: Any = Field(default=None, exclude=True)
|
||||
result_as_answer: bool = Field(default=False)
|
||||
max_usage_count: int | None = Field(default=None)
|
||||
@@ -103,6 +150,7 @@ class CrewStructuredTool(BaseModel):
|
||||
description: str | None = None,
|
||||
return_direct: bool = False,
|
||||
args_schema: type[BaseModel] | None = None,
|
||||
output_schema: type[BaseModel] | None = None,
|
||||
infer_schema: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> CrewStructuredTool:
|
||||
@@ -114,6 +162,7 @@ class CrewStructuredTool(BaseModel):
|
||||
description: The description of the tool. Defaults to the function docstring
|
||||
return_direct: Whether to return the output directly
|
||||
args_schema: Optional schema for the function arguments
|
||||
output_schema: Optional schema for the function output
|
||||
infer_schema: Whether to infer the schema from the function signature
|
||||
**kwargs: Additional arguments to pass to the tool
|
||||
|
||||
@@ -149,10 +198,16 @@ class CrewStructuredTool(BaseModel):
|
||||
name=name,
|
||||
description=description,
|
||||
args_schema=schema,
|
||||
output_schema=output_schema or _infer_output_schema_from_callable(func),
|
||||
func=func,
|
||||
result_as_answer=return_direct,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def format_output_for_agent(self, raw_result: Any) -> str:
|
||||
"""Format a raw tool result into the string representation sent to an agent."""
|
||||
return _format_tool_output_for_agent(self, raw_result)
|
||||
|
||||
@staticmethod
|
||||
def _create_schema_from_function(
|
||||
name: str,
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
from crewai.tools import BaseTool, tool
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -351,6 +352,240 @@ class TestToolDecoratorRunValidation:
|
||||
assert result == "Hello, World!"
|
||||
|
||||
|
||||
class SearchOutput(BaseModel):
|
||||
query: str
|
||||
score: float
|
||||
|
||||
|
||||
class SearchResults(RootModel[list[SearchOutput]]):
|
||||
pass
|
||||
|
||||
|
||||
class ExplicitSearchTool(BaseTool):
|
||||
name: str = "search"
|
||||
description: str = "Search for a query"
|
||||
output_schema: type[BaseModel] = SearchOutput
|
||||
|
||||
def _run(self, query: str) -> dict[str, object]:
|
||||
return {"query": query, "score": 0.8}
|
||||
|
||||
|
||||
class InferredSearchTool(BaseTool):
|
||||
name: str = "search"
|
||||
description: str = "Search for a query"
|
||||
|
||||
def _run(self, query: str) -> SearchOutput:
|
||||
return SearchOutput(query=query, score=0.7)
|
||||
|
||||
|
||||
class RootSearchTool(BaseTool):
|
||||
name: str = "search"
|
||||
description: str = "Search for a query"
|
||||
|
||||
def _run(self, query: str) -> SearchResults:
|
||||
return SearchResults([SearchOutput(query=query, score=1.0)])
|
||||
|
||||
|
||||
class DictAnnotatedSearchTool(BaseTool):
|
||||
name: str = "search"
|
||||
description: str = "Search for a query"
|
||||
|
||||
def _run(self, query: str) -> dict[str, object]:
|
||||
return {"query": query, "score": 0.5}
|
||||
|
||||
|
||||
def _make_explicit_decorator_tool() -> BaseTool:
|
||||
@tool("search", output_schema=SearchOutput)
|
||||
def search(query: str) -> dict[str, object]:
|
||||
"""Search for a query."""
|
||||
return {"query": query, "score": 0.8}
|
||||
|
||||
return search
|
||||
|
||||
|
||||
def _make_inferred_decorator_tool() -> BaseTool:
|
||||
@tool("search")
|
||||
def search(query: str) -> SearchOutput:
|
||||
"""Search for a query."""
|
||||
return SearchOutput(query=query, score=0.6)
|
||||
|
||||
return search
|
||||
|
||||
|
||||
def _make_root_decorator_tool() -> BaseTool:
|
||||
@tool("search")
|
||||
def search(query: str) -> SearchResults:
|
||||
"""Search for a query."""
|
||||
return SearchResults([SearchOutput(query=query, score=1.0)])
|
||||
|
||||
return search
|
||||
|
||||
|
||||
class TestToolOutputSchema:
|
||||
"""Tests for typed tool output behavior."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tool_cls", "expected_raw", "expected_agent_payload"),
|
||||
[
|
||||
pytest.param(
|
||||
ExplicitSearchTool,
|
||||
{"query": "crew", "score": 0.8},
|
||||
{"query": "crew", "score": 0.8},
|
||||
id="explicit-schema",
|
||||
),
|
||||
pytest.param(
|
||||
InferredSearchTool,
|
||||
SearchOutput(query="crew", score=0.7),
|
||||
{"query": "crew", "score": 0.7},
|
||||
id="inferred-base-model",
|
||||
),
|
||||
pytest.param(
|
||||
RootSearchTool,
|
||||
SearchResults([SearchOutput(query="crew", score=1.0)]),
|
||||
[{"query": "crew", "score": 1.0}],
|
||||
id="inferred-root-model",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_base_tools_return_raw_result_and_json_agent_text(
|
||||
self,
|
||||
tool_cls: type[BaseTool],
|
||||
expected_raw: object,
|
||||
expected_agent_payload: object,
|
||||
) -> None:
|
||||
t = tool_cls()
|
||||
|
||||
raw_result = t.run(query="crew")
|
||||
|
||||
assert raw_result == expected_raw
|
||||
assert json.loads(t.format_output_for_agent(raw_result)) == (
|
||||
expected_agent_payload
|
||||
)
|
||||
|
||||
def test_base_tool_does_not_infer_non_pydantic_return_annotation(self) -> None:
|
||||
t = DictAnnotatedSearchTool()
|
||||
|
||||
raw_result = t.run(query="crew")
|
||||
|
||||
assert raw_result == {"query": "crew", "score": 0.5}
|
||||
assert t.format_output_for_agent(raw_result) == str(raw_result)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("make_tool", "expected_raw", "expected_agent_payload"),
|
||||
[
|
||||
pytest.param(
|
||||
_make_explicit_decorator_tool,
|
||||
{"query": "crew", "score": 0.8},
|
||||
{"query": "crew", "score": 0.8},
|
||||
id="explicit-schema",
|
||||
),
|
||||
pytest.param(
|
||||
_make_inferred_decorator_tool,
|
||||
SearchOutput(query="crew", score=0.6),
|
||||
{"query": "crew", "score": 0.6},
|
||||
id="inferred-base-model",
|
||||
),
|
||||
pytest.param(
|
||||
_make_root_decorator_tool,
|
||||
SearchResults([SearchOutput(query="crew", score=1.0)]),
|
||||
[{"query": "crew", "score": 1.0}],
|
||||
id="inferred-root-model",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_decorator_tools_return_raw_result_and_json_agent_text(
|
||||
self,
|
||||
make_tool: Callable[[], BaseTool],
|
||||
expected_raw: object,
|
||||
expected_agent_payload: object,
|
||||
) -> None:
|
||||
search = make_tool()
|
||||
|
||||
raw_result = search.run(query="crew")
|
||||
|
||||
assert raw_result == expected_raw
|
||||
assert json.loads(search.format_output_for_agent(raw_result)) == (
|
||||
expected_agent_payload
|
||||
)
|
||||
|
||||
def test_decorator_tool_does_not_infer_non_pydantic_return_annotation(
|
||||
self,
|
||||
) -> None:
|
||||
@tool("search")
|
||||
def search(query: str) -> dict[str, object]:
|
||||
"""Search for a query."""
|
||||
return {"query": query, "score": 0.5}
|
||||
|
||||
raw_result = search.run(query="crew")
|
||||
|
||||
assert raw_result == {"query": "crew", "score": 0.5}
|
||||
assert search.format_output_for_agent(raw_result) == str(raw_result)
|
||||
|
||||
def test_explicit_output_schema_wins_over_return_annotation(self) -> None:
|
||||
class AlternateOutput(BaseModel):
|
||||
value: str
|
||||
|
||||
@tool("search", output_schema=AlternateOutput)
|
||||
def search(query: str) -> SearchOutput:
|
||||
"""Search for a query."""
|
||||
return SearchOutput(query=query, score=0.6)
|
||||
|
||||
raw_result = search.run(query="crew")
|
||||
|
||||
with pytest.warns(RuntimeWarning, match="AlternateOutput"):
|
||||
agent_text = search.format_output_for_agent(raw_result)
|
||||
|
||||
assert raw_result == SearchOutput(query="crew", score=0.6)
|
||||
assert agent_text == str(raw_result)
|
||||
|
||||
def test_invalid_typed_output_warns_and_uses_string_agent_text(
|
||||
self,
|
||||
) -> None:
|
||||
@tool("search", output_schema=SearchOutput)
|
||||
def search(query: str) -> dict[str, object]:
|
||||
"""Search for a query."""
|
||||
return {"query": query, "score": "not-a-float"}
|
||||
|
||||
raw_result = search.run(query="crew")
|
||||
|
||||
with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"):
|
||||
agent_text = search.format_output_for_agent(raw_result)
|
||||
|
||||
assert raw_result == {"query": "crew", "score": "not-a-float"}
|
||||
assert agent_text == str(raw_result)
|
||||
|
||||
def test_unserializable_typed_output_warns_and_uses_string_agent_text(
|
||||
self,
|
||||
) -> None:
|
||||
class OpaqueOutput(BaseModel):
|
||||
value: object
|
||||
|
||||
raw_result = OpaqueOutput(value=object())
|
||||
|
||||
@tool("opaque", output_schema=OpaqueOutput)
|
||||
def opaque() -> OpaqueOutput:
|
||||
"""Return an opaque object."""
|
||||
return raw_result
|
||||
|
||||
result = opaque.run()
|
||||
|
||||
with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"):
|
||||
agent_text = opaque.format_output_for_agent(result)
|
||||
|
||||
assert result is raw_result
|
||||
assert agent_text == str(raw_result)
|
||||
|
||||
def test_output_schema_behavior_carries_over_to_structured_tool(self) -> None:
|
||||
structured = ExplicitSearchTool().to_structured_tool()
|
||||
|
||||
raw_result = structured.invoke({"query": "crew"})
|
||||
|
||||
assert raw_result == {"query": "crew", "score": 0.8}
|
||||
assert json.loads(structured.format_output_for_agent(raw_result)) == {
|
||||
"query": "crew",
|
||||
"score": 0.8,
|
||||
}
|
||||
|
||||
# Async arun() Schema Validation Tests
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import json
|
||||
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -86,6 +88,116 @@ def test_from_function(basic_function):
|
||||
assert isinstance(tool.args_schema, type(BaseModel))
|
||||
|
||||
|
||||
class StructuredOutput(BaseModel):
|
||||
value: str
|
||||
count: int
|
||||
|
||||
|
||||
class StructuredOutputList(RootModel[list[StructuredOutput]]):
|
||||
pass
|
||||
|
||||
|
||||
def _build_explicit_structured_value(value: str) -> dict[str, object]:
|
||||
"""Build a value."""
|
||||
return {"value": value, "count": 1}
|
||||
|
||||
|
||||
def _build_inferred_structured_value(value: str) -> StructuredOutput:
|
||||
"""Build a value."""
|
||||
return StructuredOutput(value=value, count=1)
|
||||
|
||||
|
||||
def _build_structured_values(value: str) -> StructuredOutputList:
|
||||
"""Build values."""
|
||||
return StructuredOutputList([StructuredOutput(value=value, count=1)])
|
||||
|
||||
|
||||
def _build_plain_structured_value(value: str) -> dict[str, object]:
|
||||
"""Build a value."""
|
||||
return {"value": value, "count": 1}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("func", "output_schema", "expected_raw", "expected_agent_payload"),
|
||||
[
|
||||
pytest.param(
|
||||
_build_explicit_structured_value,
|
||||
StructuredOutput,
|
||||
{"value": "crew", "count": 1},
|
||||
{"value": "crew", "count": 1},
|
||||
id="explicit-schema",
|
||||
),
|
||||
pytest.param(
|
||||
_build_inferred_structured_value,
|
||||
None,
|
||||
StructuredOutput(value="crew", count=1),
|
||||
{"value": "crew", "count": 1},
|
||||
id="inferred-base-model",
|
||||
),
|
||||
pytest.param(
|
||||
_build_structured_values,
|
||||
None,
|
||||
StructuredOutputList([StructuredOutput(value="crew", count=1)]),
|
||||
[{"value": "crew", "count": 1}],
|
||||
id="inferred-root-model",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_from_function_returns_raw_result_and_json_agent_text(
|
||||
func,
|
||||
output_schema,
|
||||
expected_raw,
|
||||
expected_agent_payload,
|
||||
):
|
||||
"""Typed structured tools return raw values and format JSON for the agent."""
|
||||
kwargs = {"output_schema": output_schema} if output_schema is not None else {}
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=func,
|
||||
name="build_value",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
raw_result = tool.invoke({"value": "crew"})
|
||||
|
||||
assert raw_result == expected_raw
|
||||
assert json.loads(tool.format_output_for_agent(raw_result)) == (
|
||||
expected_agent_payload
|
||||
)
|
||||
|
||||
|
||||
def test_from_function_does_not_infer_non_pydantic_output_schema():
|
||||
"""Non-Pydantic return annotations use the plain string formatter."""
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=_build_plain_structured_value,
|
||||
name="build_value",
|
||||
)
|
||||
|
||||
raw_result = tool.invoke({"value": "crew"})
|
||||
|
||||
assert raw_result == {"value": "crew", "count": 1}
|
||||
assert tool.format_output_for_agent(raw_result) == str(raw_result)
|
||||
|
||||
|
||||
def test_invalid_typed_output_warns_and_uses_string_agent_text():
|
||||
"""Invalid structured output leaves the raw result unchanged."""
|
||||
def build_value(value: str) -> dict[str, object]:
|
||||
"""Build a value."""
|
||||
return {"value": value, "count": "wrong"}
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=build_value,
|
||||
name="build_value",
|
||||
output_schema=StructuredOutput,
|
||||
)
|
||||
raw_result = tool.invoke({"value": "crew"})
|
||||
|
||||
with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"):
|
||||
agent_text = tool.format_output_for_agent(raw_result)
|
||||
|
||||
assert raw_result == {"value": "crew", "count": "wrong"}
|
||||
assert agent_text == str(raw_result)
|
||||
|
||||
|
||||
def test_validate_function_signature(basic_function, schema_class):
|
||||
"""Test function signature validation"""
|
||||
tool = CrewStructuredTool(
|
||||
|
||||
Reference in New Issue
Block a user