diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py index 31c5009bd..fde05ab3e 100644 --- a/lib/crewai/src/crewai/tools/base_tool.py +++ b/lib/crewai/src/crewai/tools/base_tool.py @@ -33,6 +33,8 @@ from typing_extensions import TypeIs from crewai.tools.structured_tool import ( CrewStructuredTool, _deserialize_schema, + _format_tool_output_for_agent, + _infer_output_schema_from_callable, _serialize_schema, build_schema_hint, ) @@ -149,6 +151,11 @@ class BaseTool(BaseModel, ABC): validate_default=True, description="The schema for the arguments that the tool accepts.", ) + output_schema: type[PydanticBaseModel] | None = Field( + default=None, + validate_default=True, + description="The schema for the output that the tool returns.", + ) @field_serializer("args_schema", when_used="json") def _serialize_args_schema( @@ -156,6 +163,12 @@ class BaseTool(BaseModel, ABC): ) -> dict[str, Any] | None: return _serialize_schema(schema) + @field_serializer("output_schema", when_used="json") + def _serialize_output_schema( + self, schema: type[PydanticBaseModel] | None + ) -> dict[str, Any] | None: + return _serialize_schema(schema) + description_updated: bool = Field( default=False, description="Flag to check if the description has been updated." ) @@ -233,6 +246,17 @@ class BaseTool(BaseModel, ABC): return create_model(f"{cls.__name__}Schema", **fields) + @field_validator("output_schema", mode="before") + @classmethod + def _default_output_schema( + cls, v: type[PydanticBaseModel] | dict[str, Any] | None + ) -> type[PydanticBaseModel] | None: + if isinstance(v, dict): + return _deserialize_schema(v) + if v is not None: + return v + return _infer_output_schema_from_callable(cls._run) + @field_validator("max_usage_count", mode="before") @classmethod def validate_max_usage_count(cls, v: int | None) -> int | None: @@ -340,6 +364,10 @@ class BaseTool(BaseModel, ABC): "Override _arun for async support or use run() for sync execution." ) + def format_output_for_agent(self, raw_result: Any) -> str: + """Format a raw tool result into the string representation sent to an agent.""" + return _format_tool_output_for_agent(self, raw_result) + def reset_usage_count(self) -> None: """Reset the current usage count to zero.""" self.current_usage_count = 0 @@ -369,6 +397,7 @@ class BaseTool(BaseModel, ABC): name=self.name, description=self.description, args_schema=self.args_schema, + output_schema=self.output_schema, func=self._run, result_as_answer=self.result_as_answer, max_usage_count=self.max_usage_count, @@ -390,6 +419,9 @@ class BaseTool(BaseModel, ABC): raise ValueError("The provided tool must have a callable 'func' attribute.") args_schema = getattr(tool, "args_schema", None) + output_schema = getattr(tool, "output_schema", None) + if output_schema is None: + output_schema = _infer_output_schema_from_callable(tool.func) if args_schema is None: func_signature = signature(tool.func) @@ -420,6 +452,7 @@ class BaseTool(BaseModel, ABC): description=getattr(tool, "description", ""), func=tool.func, args_schema=args_schema, + output_schema=output_schema, ) def _set_args_schema(self) -> None: @@ -568,6 +601,9 @@ class Tool(BaseTool, Generic[P, R]): raise ValueError("The provided tool must have a callable 'func' attribute.") args_schema = getattr(tool, "args_schema", None) + output_schema = getattr(tool, "output_schema", None) + if output_schema is None: + output_schema = _infer_output_schema_from_callable(tool.func) if args_schema is None: func_signature = signature(tool.func) @@ -598,6 +634,7 @@ class Tool(BaseTool, Generic[P, R]): description=getattr(tool, "description", ""), func=tool.func, args_schema=args_schema, + output_schema=output_schema, ) @@ -621,6 +658,7 @@ def tool( name: str, /, *, + output_schema: type[BaseModel] | None = ..., result_as_answer: bool = ..., max_usage_count: int | None = ..., ) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ... @@ -629,6 +667,7 @@ def tool( @overload def tool( *, + output_schema: type[BaseModel] | None = ..., result_as_answer: bool = ..., max_usage_count: int | None = ..., ) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ... @@ -636,6 +675,7 @@ def tool( def tool( *args: Callable[P2, R2] | str, + output_schema: type[BaseModel] | None = None, result_as_answer: bool = False, max_usage_count: int | None = None, ) -> Tool[P2, R2] | Callable[[Callable[P2, R2]], Tool[P2, R2]]: @@ -649,6 +689,7 @@ def tool( Args: *args: Either the function to decorate or a custom tool name. result_as_answer: If True, the tool result becomes the final agent answer. + output_schema: Optional schema for the output that the tool returns. max_usage_count: Maximum times this tool can be used. None means unlimited. Returns: @@ -690,12 +731,16 @@ def tool( class_name = "".join(tool_name.split()).title() args_schema = create_model(class_name, **fields) + resolved_output_schema = ( + output_schema or _infer_output_schema_from_callable(f) + ) return Tool( name=tool_name, description=f.__doc__, func=f, args_schema=args_schema, + output_schema=resolved_output_schema, result_as_answer=result_as_answer, max_usage_count=max_usage_count, current_usage_count=0, diff --git a/lib/crewai/src/crewai/tools/structured_tool.py b/lib/crewai/src/crewai/tools/structured_tool.py index 6c24f52dc..1151a749a 100644 --- a/lib/crewai/src/crewai/tools/structured_tool.py +++ b/lib/crewai/src/crewai/tools/structured_tool.py @@ -5,7 +5,8 @@ from collections.abc import Callable import inspect import json import textwrap -from typing import TYPE_CHECKING, Annotated, Any, get_type_hints +from typing import TYPE_CHECKING, Annotated, Any, cast, get_type_hints +import warnings from pydantic import ( BaseModel, @@ -36,6 +37,47 @@ def _deserialize_schema(v: Any) -> type[BaseModel] | None: return None +def _infer_output_schema_from_callable( + func: Callable[..., Any], +) -> type[BaseModel] | None: + try: + return_annotation = get_type_hints(func).get("return", inspect.Signature.empty) + except Exception: + return_annotation = inspect.signature(func).return_annotation + + if isinstance(return_annotation, type) and issubclass(return_annotation, BaseModel): + return return_annotation + + return None + + +def _format_tool_output_for_agent(tool: Any, raw_result: Any) -> str: + output_schema = getattr(tool, "output_schema", None) + if output_schema is None: + return str(raw_result) + + try: + validation_input = raw_result + if isinstance(raw_result, BaseModel) and not isinstance( + raw_result, output_schema + ): + validation_input = raw_result.model_dump() + + validated = output_schema.model_validate(validation_input) + return cast(str, validated.model_dump_json()) + except Exception as exc: + warnings.warn( + ( + f"Failed to validate or serialize output from tool " + f"'{getattr(tool, 'name', '')}' using output_schema " + f"'{output_schema.__name__}': {exc}. Falling back to str(raw_result)." + ), + RuntimeWarning, + stacklevel=2, + ) + return str(raw_result) + + if TYPE_CHECKING: pass @@ -81,6 +123,11 @@ class CrewStructuredTool(BaseModel): BeforeValidator(_deserialize_schema), PlainSerializer(_serialize_schema), ] = Field(default=None) + output_schema: Annotated[ + type[BaseModel] | None, + BeforeValidator(_deserialize_schema), + PlainSerializer(_serialize_schema), + ] = Field(default=None) func: Any = Field(default=None, exclude=True) result_as_answer: bool = Field(default=False) max_usage_count: int | None = Field(default=None) @@ -103,6 +150,7 @@ class CrewStructuredTool(BaseModel): description: str | None = None, return_direct: bool = False, args_schema: type[BaseModel] | None = None, + output_schema: type[BaseModel] | None = None, infer_schema: bool = True, **kwargs: Any, ) -> CrewStructuredTool: @@ -114,6 +162,7 @@ class CrewStructuredTool(BaseModel): description: The description of the tool. Defaults to the function docstring return_direct: Whether to return the output directly args_schema: Optional schema for the function arguments + output_schema: Optional schema for the function output infer_schema: Whether to infer the schema from the function signature **kwargs: Additional arguments to pass to the tool @@ -149,10 +198,16 @@ class CrewStructuredTool(BaseModel): name=name, description=description, args_schema=schema, + output_schema=output_schema or _infer_output_schema_from_callable(func), func=func, result_as_answer=return_direct, + **kwargs, ) + def format_output_for_agent(self, raw_result: Any) -> str: + """Format a raw tool result into the string representation sent to an agent.""" + return _format_tool_output_for_agent(self, raw_result) + @staticmethod def _create_schema_from_function( name: str, diff --git a/lib/crewai/tests/tools/test_base_tool.py b/lib/crewai/tests/tools/test_base_tool.py index 7648ad73b..dcf9188c1 100644 --- a/lib/crewai/tests/tools/test_base_tool.py +++ b/lib/crewai/tests/tools/test_base_tool.py @@ -1,12 +1,13 @@ import asyncio from collections.abc import Callable +import json from unittest.mock import patch from crewai.agent import Agent from crewai.crew import Crew from crewai.task import Task from crewai.tools import BaseTool, tool -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, RootModel import pytest @@ -351,6 +352,240 @@ class TestToolDecoratorRunValidation: assert result == "Hello, World!" +class SearchOutput(BaseModel): + query: str + score: float + + +class SearchResults(RootModel[list[SearchOutput]]): + pass + + +class ExplicitSearchTool(BaseTool): + name: str = "search" + description: str = "Search for a query" + output_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> dict[str, object]: + return {"query": query, "score": 0.8} + + +class InferredSearchTool(BaseTool): + name: str = "search" + description: str = "Search for a query" + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.7) + + +class RootSearchTool(BaseTool): + name: str = "search" + description: str = "Search for a query" + + def _run(self, query: str) -> SearchResults: + return SearchResults([SearchOutput(query=query, score=1.0)]) + + +class DictAnnotatedSearchTool(BaseTool): + name: str = "search" + description: str = "Search for a query" + + def _run(self, query: str) -> dict[str, object]: + return {"query": query, "score": 0.5} + + +def _make_explicit_decorator_tool() -> BaseTool: + @tool("search", output_schema=SearchOutput) + def search(query: str) -> dict[str, object]: + """Search for a query.""" + return {"query": query, "score": 0.8} + + return search + + +def _make_inferred_decorator_tool() -> BaseTool: + @tool("search") + def search(query: str) -> SearchOutput: + """Search for a query.""" + return SearchOutput(query=query, score=0.6) + + return search + + +def _make_root_decorator_tool() -> BaseTool: + @tool("search") + def search(query: str) -> SearchResults: + """Search for a query.""" + return SearchResults([SearchOutput(query=query, score=1.0)]) + + return search + + +class TestToolOutputSchema: + """Tests for typed tool output behavior.""" + + @pytest.mark.parametrize( + ("tool_cls", "expected_raw", "expected_agent_payload"), + [ + pytest.param( + ExplicitSearchTool, + {"query": "crew", "score": 0.8}, + {"query": "crew", "score": 0.8}, + id="explicit-schema", + ), + pytest.param( + InferredSearchTool, + SearchOutput(query="crew", score=0.7), + {"query": "crew", "score": 0.7}, + id="inferred-base-model", + ), + pytest.param( + RootSearchTool, + SearchResults([SearchOutput(query="crew", score=1.0)]), + [{"query": "crew", "score": 1.0}], + id="inferred-root-model", + ), + ], + ) + def test_base_tools_return_raw_result_and_json_agent_text( + self, + tool_cls: type[BaseTool], + expected_raw: object, + expected_agent_payload: object, + ) -> None: + t = tool_cls() + + raw_result = t.run(query="crew") + + assert raw_result == expected_raw + assert json.loads(t.format_output_for_agent(raw_result)) == ( + expected_agent_payload + ) + + def test_base_tool_does_not_infer_non_pydantic_return_annotation(self) -> None: + t = DictAnnotatedSearchTool() + + raw_result = t.run(query="crew") + + assert raw_result == {"query": "crew", "score": 0.5} + assert t.format_output_for_agent(raw_result) == str(raw_result) + + @pytest.mark.parametrize( + ("make_tool", "expected_raw", "expected_agent_payload"), + [ + pytest.param( + _make_explicit_decorator_tool, + {"query": "crew", "score": 0.8}, + {"query": "crew", "score": 0.8}, + id="explicit-schema", + ), + pytest.param( + _make_inferred_decorator_tool, + SearchOutput(query="crew", score=0.6), + {"query": "crew", "score": 0.6}, + id="inferred-base-model", + ), + pytest.param( + _make_root_decorator_tool, + SearchResults([SearchOutput(query="crew", score=1.0)]), + [{"query": "crew", "score": 1.0}], + id="inferred-root-model", + ), + ], + ) + def test_decorator_tools_return_raw_result_and_json_agent_text( + self, + make_tool: Callable[[], BaseTool], + expected_raw: object, + expected_agent_payload: object, + ) -> None: + search = make_tool() + + raw_result = search.run(query="crew") + + assert raw_result == expected_raw + assert json.loads(search.format_output_for_agent(raw_result)) == ( + expected_agent_payload + ) + + def test_decorator_tool_does_not_infer_non_pydantic_return_annotation( + self, + ) -> None: + @tool("search") + def search(query: str) -> dict[str, object]: + """Search for a query.""" + return {"query": query, "score": 0.5} + + raw_result = search.run(query="crew") + + assert raw_result == {"query": "crew", "score": 0.5} + assert search.format_output_for_agent(raw_result) == str(raw_result) + + def test_explicit_output_schema_wins_over_return_annotation(self) -> None: + class AlternateOutput(BaseModel): + value: str + + @tool("search", output_schema=AlternateOutput) + def search(query: str) -> SearchOutput: + """Search for a query.""" + return SearchOutput(query=query, score=0.6) + + raw_result = search.run(query="crew") + + with pytest.warns(RuntimeWarning, match="AlternateOutput"): + agent_text = search.format_output_for_agent(raw_result) + + assert raw_result == SearchOutput(query="crew", score=0.6) + assert agent_text == str(raw_result) + + def test_invalid_typed_output_warns_and_uses_string_agent_text( + self, + ) -> None: + @tool("search", output_schema=SearchOutput) + def search(query: str) -> dict[str, object]: + """Search for a query.""" + return {"query": query, "score": "not-a-float"} + + raw_result = search.run(query="crew") + + with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"): + agent_text = search.format_output_for_agent(raw_result) + + assert raw_result == {"query": "crew", "score": "not-a-float"} + assert agent_text == str(raw_result) + + def test_unserializable_typed_output_warns_and_uses_string_agent_text( + self, + ) -> None: + class OpaqueOutput(BaseModel): + value: object + + raw_result = OpaqueOutput(value=object()) + + @tool("opaque", output_schema=OpaqueOutput) + def opaque() -> OpaqueOutput: + """Return an opaque object.""" + return raw_result + + result = opaque.run() + + with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"): + agent_text = opaque.format_output_for_agent(result) + + assert result is raw_result + assert agent_text == str(raw_result) + + def test_output_schema_behavior_carries_over_to_structured_tool(self) -> None: + structured = ExplicitSearchTool().to_structured_tool() + + raw_result = structured.invoke({"query": "crew"}) + + assert raw_result == {"query": "crew", "score": 0.8} + assert json.loads(structured.format_output_for_agent(raw_result)) == { + "query": "crew", + "score": 0.8, + } + # Async arun() Schema Validation Tests diff --git a/lib/crewai/tests/tools/test_structured_tool.py b/lib/crewai/tests/tools/test_structured_tool.py index 27c463d47..2a81911ae 100644 --- a/lib/crewai/tests/tools/test_structured_tool.py +++ b/lib/crewai/tests/tools/test_structured_tool.py @@ -1,5 +1,7 @@ +import json + from crewai.tools.structured_tool import CrewStructuredTool -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, RootModel import pytest @@ -86,6 +88,116 @@ def test_from_function(basic_function): assert isinstance(tool.args_schema, type(BaseModel)) +class StructuredOutput(BaseModel): + value: str + count: int + + +class StructuredOutputList(RootModel[list[StructuredOutput]]): + pass + + +def _build_explicit_structured_value(value: str) -> dict[str, object]: + """Build a value.""" + return {"value": value, "count": 1} + + +def _build_inferred_structured_value(value: str) -> StructuredOutput: + """Build a value.""" + return StructuredOutput(value=value, count=1) + + +def _build_structured_values(value: str) -> StructuredOutputList: + """Build values.""" + return StructuredOutputList([StructuredOutput(value=value, count=1)]) + + +def _build_plain_structured_value(value: str) -> dict[str, object]: + """Build a value.""" + return {"value": value, "count": 1} + + +@pytest.mark.parametrize( + ("func", "output_schema", "expected_raw", "expected_agent_payload"), + [ + pytest.param( + _build_explicit_structured_value, + StructuredOutput, + {"value": "crew", "count": 1}, + {"value": "crew", "count": 1}, + id="explicit-schema", + ), + pytest.param( + _build_inferred_structured_value, + None, + StructuredOutput(value="crew", count=1), + {"value": "crew", "count": 1}, + id="inferred-base-model", + ), + pytest.param( + _build_structured_values, + None, + StructuredOutputList([StructuredOutput(value="crew", count=1)]), + [{"value": "crew", "count": 1}], + id="inferred-root-model", + ), + ], +) +def test_from_function_returns_raw_result_and_json_agent_text( + func, + output_schema, + expected_raw, + expected_agent_payload, +): + """Typed structured tools return raw values and format JSON for the agent.""" + kwargs = {"output_schema": output_schema} if output_schema is not None else {} + tool = CrewStructuredTool.from_function( + func=func, + name="build_value", + **kwargs, + ) + + raw_result = tool.invoke({"value": "crew"}) + + assert raw_result == expected_raw + assert json.loads(tool.format_output_for_agent(raw_result)) == ( + expected_agent_payload + ) + + +def test_from_function_does_not_infer_non_pydantic_output_schema(): + """Non-Pydantic return annotations use the plain string formatter.""" + tool = CrewStructuredTool.from_function( + func=_build_plain_structured_value, + name="build_value", + ) + + raw_result = tool.invoke({"value": "crew"}) + + assert raw_result == {"value": "crew", "count": 1} + assert tool.format_output_for_agent(raw_result) == str(raw_result) + + +def test_invalid_typed_output_warns_and_uses_string_agent_text(): + """Invalid structured output leaves the raw result unchanged.""" + def build_value(value: str) -> dict[str, object]: + """Build a value.""" + return {"value": value, "count": "wrong"} + + tool = CrewStructuredTool.from_function( + func=build_value, + name="build_value", + output_schema=StructuredOutput, + ) + raw_result = tool.invoke({"value": "crew"}) + + with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"): + agent_text = tool.format_output_for_agent(raw_result) + + assert raw_result == {"value": "crew", "count": "wrong"} + assert agent_text == str(raw_result) + + def test_validate_function_signature(basic_function, schema_class): """Test function signature validation""" tool = CrewStructuredTool(