From 9db2d4476641fda3e13347662152214d528f4d79 Mon Sep 17 00:00:00 2001 From: Vinicius Brasil Date: Fri, 19 Jun 2026 14:33:51 -0700 Subject: [PATCH] Add typed output schemas for CrewAI tools (#6236) Currently, tools have a strong input contract through `args_schema`, but no output contract. This means that anything a tool outputs is converted to string. Not only the contract is weak, but the "invisible" conversion to string can have unexpected effects when the tool returns complex objects like dicts and arrays. With this PR, a tool can _optionally_ define an output contract with `output_schema`. CrewAI validates the raw result and sends the agent JSON. ```python class ProductResult(BaseModel): sku: str name: str in_stock: bool class ProductLookupTool(BaseTool): name: str = "Product Lookup" description: str = "Look up product availability by SKU." def _run(self, sku: str) -> ProductResult: return ProductResult(sku=sku, name="USB-C dock", in_stock=True) ``` If the result does not match the schema, CrewAI warns and falls back to `str(raw_result)` instead of failing the run: ```python @tool("Product Lookup", output_schema=ProductResult) def product_lookup(sku: str) -> dict[str, object]: return {"sku": sku, "name": "USB-C dock", "in_stock": True} #=> RuntimeWarning: Failed to validate or serialize output from tool 'Bad Product Lookup' using output_schema 'ProductResult'... Falling back to str(raw_result). ``` This is additive and non-breaking. Existing tools do not need to change. Tools without `output_schema` keep the old string behavior. Invalid typed outputs warn and fall back to the old formatting path. --- docs/edge/en/concepts/tools.mdx | 50 ++++ .../en/guides/tools/publish-custom-tools.mdx | 65 ++++- docs/edge/en/learn/create-custom-tools.mdx | 105 +++++++ docs/edge/en/learn/execution-hooks.mdx | 5 +- docs/edge/en/learn/tool-hooks.mdx | 5 +- .../src/crewai/agents/crew_agent_executor.py | 52 ++-- lib/crewai/src/crewai/agents/tools_handler.py | 5 +- .../src/crewai/experimental/agent_executor.py | 53 ++-- lib/crewai/src/crewai/hooks/tool_hooks.py | 5 + lib/crewai/src/crewai/tools/base_tool.py | 45 +++ .../src/crewai/tools/structured_tool.py | 62 ++++- lib/crewai/src/crewai/tools/tool_usage.py | 35 ++- .../src/crewai/utilities/agent_utils.py | 47 ++-- lib/crewai/src/crewai/utilities/tool_utils.py | 4 + .../tests/agents/test_native_tool_calling.py | 73 ++++- lib/crewai/tests/hooks/test_tool_hooks.py | 4 + lib/crewai/tests/tools/test_base_tool.py | 259 +++++++++++++++++- .../tests/tools/test_structured_tool.py | 116 +++++++- lib/crewai/tests/tools/test_tool_usage.py | 144 ++++++++++ .../tests/utilities/test_agent_utils.py | 143 ++++++++++ 20 files changed, 1200 insertions(+), 77 deletions(-) diff --git a/docs/edge/en/concepts/tools.mdx b/docs/edge/en/concepts/tools.mdx index 52e568073..da41da3b1 100644 --- a/docs/edge/en/concepts/tools.mdx +++ b/docs/edge/en/concepts/tools.mdx @@ -39,6 +39,7 @@ The Enterprise Tools Repository includes: - **Error Handling**: Incorporates robust error handling mechanisms to ensure smooth operation. - **Caching Mechanism**: Features intelligent caching to optimize performance and reduce redundant operations. - **Asynchronous Support**: Handles both synchronous and asynchronous tools, enabling non-blocking operations. +- **Typed Outputs**: Uses optional Pydantic models to give agents clear JSON fields while direct Python calls still receive the tool's normal return value. ## Using CrewAI Tools @@ -184,6 +185,55 @@ class MyCustomTool(BaseTool): return "Tool's result" ``` +### Typed Tool Outputs + +When a tool returns structured data, define a Pydantic output model. This gives the agent field names it can trust, such as `sku`, `quantity`, or `needs_reorder`. + +Direct Python calls still receive the value your tool returns. When an agent uses the tool, CrewAI sends the agent a JSON string based on the output model. + +```python Code +from crewai.tools import BaseTool +from pydantic import BaseModel + +class InventoryResult(BaseModel): + sku: str + quantity: int + needs_reorder: bool + +class InventoryTool(BaseTool): + name: str = "Inventory Check" + description: str = "Checks current stock for a product SKU." + + def _run(self, sku: str) -> InventoryResult: + quantity = {"SKU-123": 14, "SKU-456": 0}.get(sku, 0) + return InventoryResult(sku=sku, quantity=quantity, needs_reorder=quantity < 5) + +tool = InventoryTool() + +# Direct calls receive the raw Pydantic object. +result = tool.run(sku="SKU-123") +print(result.quantity) +``` + +To send Markdown or another short text format to the agent, override `format_output_for_agent`. Direct calls to `tool.run(...)` still return the normal Python value. + +```python Code +class InventoryTool(BaseTool): + name: str = "Inventory Check" + description: str = "Checks current stock for a product SKU." + + def _run(self, sku: str) -> InventoryResult: + quantity = {"SKU-123": 14, "SKU-456": 0}.get(sku, 0) + return InventoryResult(sku=sku, quantity=quantity, needs_reorder=quantity < 5) + + def format_output_for_agent(self, raw_result: object) -> str: + result = InventoryResult.model_validate(raw_result) + status = "reorder needed" if result.needs_reorder else "stock is healthy" + return f"{result.sku}: {result.quantity} units. {status}." +``` + +If you do not override `format_output_for_agent`, typed outputs are sent to the agent as JSON. Plain string results work as before. + ## Asynchronous Tool Support CrewAI supports asynchronous tools, allowing you to implement tools that perform non-blocking operations like network requests, file I/O, or other async operations without blocking the main execution thread. diff --git a/docs/edge/en/guides/tools/publish-custom-tools.mdx b/docs/edge/en/guides/tools/publish-custom-tools.mdx index 973856816..71dbfd622 100644 --- a/docs/edge/en/guides/tools/publish-custom-tools.mdx +++ b/docs/edge/en/guides/tools/publish-custom-tools.mdx @@ -65,7 +65,7 @@ Regardless of which approach you use, your tool must: - Have a **`description`** — tells the agent when and how to use the tool. This directly affects how well agents use your tool, so be clear and specific. - Implement **`_run`** (BaseTool) or provide a **function body** (@tool) — the synchronous execution logic. - Use **type annotations** on all parameters and return values. -- Return a **string** result (or something that can be meaningfully converted to one). +- Return a **string** result, or define an optional Pydantic output schema for structured results. ### Optional: Async Support @@ -104,6 +104,67 @@ class TranslateInput(BaseModel): Explicit schemas are recommended for published tools — they produce better agent behavior and clearer documentation for your users. +### Optional: Typed Outputs with `result_schema` + +If your tool returns structured data, define a Pydantic output model. This is a good default for published tools because users and agents can rely on named fields. + +Direct Python calls still receive the value your tool returns. When an agent uses the tool, CrewAI sends the agent JSON based on the output model. + +CrewAI can infer the output schema from a Pydantic return annotation: + +```python +from crewai.tools import BaseTool +from pydantic import BaseModel, Field + + +class GeolocateResult(BaseModel): + latitude: float = Field(..., description="Latitude in decimal degrees.") + longitude: float = Field(..., description="Longitude in decimal degrees.") + + +class GeolocateTool(BaseTool): + name: str = "Geolocate" + description: str = "Converts a street address into latitude/longitude coordinates." + + def _run(self, address: str) -> GeolocateResult: + if "1600 Pennsylvania" in address: + return GeolocateResult(latitude=38.8977, longitude=-77.0365) + return GeolocateResult(latitude=40.7128, longitude=-74.0060) +``` + +Set `result_schema` explicitly when your tool returns a dictionary: + +```python +class GeolocateTool(BaseTool): + name: str = "Geolocate" + description: str = "Converts a street address into latitude/longitude coordinates." + result_schema: type[BaseModel] = GeolocateResult + + def _run(self, address: str) -> dict[str, float]: + if "1600 Pennsylvania" in address: + return {"latitude": 38.8977, "longitude": -77.0365} + return {"latitude": 40.7128, "longitude": -74.0060} +``` + +If agents should receive a short text summary instead of JSON, override `format_output_for_agent` on your `BaseTool` subclass. + +```python +class GeolocateTool(BaseTool): + name: str = "Geolocate" + description: str = "Converts a street address into latitude/longitude coordinates." + + def _run(self, address: str) -> GeolocateResult: + if "1600 Pennsylvania" in address: + return GeolocateResult(latitude=38.8977, longitude=-77.0365) + return GeolocateResult(latitude=40.7128, longitude=-74.0060) + + def format_output_for_agent(self, raw_result: object) -> str: + result = GeolocateResult.model_validate(raw_result) + return f"Latitude {result.latitude}, longitude {result.longitude}" +``` + +The override only changes what the agent sees. Direct users of your package still receive the normal value from `tool.run(...)`. + ### Optional: Environment Variables If your tool requires API keys or other configuration, declare them with `env_vars` so users know what to set: @@ -241,4 +302,4 @@ agent = Agent( tools=[GeolocateTool()], # ... ) -``` \ No newline at end of file +``` diff --git a/docs/edge/en/learn/create-custom-tools.mdx b/docs/edge/en/learn/create-custom-tools.mdx index c1246f3fc..78205bf99 100644 --- a/docs/edge/en/learn/create-custom-tools.mdx +++ b/docs/edge/en/learn/create-custom-tools.mdx @@ -53,6 +53,111 @@ def my_simple_tool(question: str) -> str: return "Tool output" ``` +### Best Practice: Define Typed Outputs + +When a tool returns structured data, define a Pydantic output model. This helps the agent read the result as clear fields instead of guessing from plain text. + +Typed outputs are useful for results with stable fields, such as IDs, status values, scores, prices, or lists. Plain strings are still fine for short prose results. + +Direct Python calls still receive the value your tool returns. When an agent uses a typed tool, CrewAI sends the agent JSON based on the output model. + +#### Return a Pydantic Model + +CrewAI infers the output schema when your `BaseTool` has a Pydantic return annotation. + +```python Code +from crewai.tools import BaseTool +from pydantic import BaseModel, Field + +class InventoryResult(BaseModel): + sku: str = Field(description="The product SKU.") + quantity: int = Field(description="Units available.") + needs_reorder: bool = Field(description="Whether the item should be reordered.") + +class InventoryTool(BaseTool): + name: str = "Inventory Check" + description: str = "Check current stock for a product SKU." + + def _run(self, sku: str) -> InventoryResult: + quantity = {"SKU-123": 14, "SKU-456": 0}.get(sku, 0) + return InventoryResult(sku=sku, quantity=quantity, needs_reorder=quantity < 5) + +tool = InventoryTool() +result = tool.run(sku="SKU-123") + +# Direct Python calls receive the raw Pydantic object. +print(result.quantity) +``` + +When an agent calls `InventoryTool`, it receives JSON like this: + +```json +{"sku":"SKU-123","quantity":14,"needs_reorder":false} +``` + +#### Use `result_schema` with Dictionary Results + +If your tool returns a dictionary, set `result_schema` explicitly. You can do this on a `BaseTool` subclass or with the `@tool` decorator: + +```python Code +from crewai.tools import tool +from pydantic import BaseModel, Field + +class ProductResult(BaseModel): + sku: str = Field(description="The product SKU.") + name: str = Field(description="The product name.") + in_stock: bool = Field(description="Whether the product is available.") + +@tool("Product Lookup", result_schema=ProductResult) +def product_lookup(sku: str) -> dict[str, object]: + """Look up product availability by SKU.""" + catalog = { + "SKU-123": ("Noise-canceling headset", True), + "SKU-456": ("USB-C dock", False), + } + name, in_stock = catalog.get(sku, ("Unknown product", False)) + return { + "sku": sku, + "name": name, + "in_stock": in_stock, + } +``` + +#### Customize the Text Sent to the Agent + +By default, typed tool outputs are sent to the agent as JSON. If the agent should receive a short summary instead, subclass `BaseTool` and override `format_output_for_agent`. + +```python Code +from crewai.tools import BaseTool +from pydantic import BaseModel, Field + +class InventoryResult(BaseModel): + sku: str = Field(description="The product SKU.") + quantity: int = Field(description="Units available.") + needs_reorder: bool = Field(description="Whether the item should be reordered.") + +class InventoryTool(BaseTool): + name: str = "Inventory Check" + description: str = "Check current stock for a product SKU." + + def _run(self, sku: str) -> InventoryResult: + quantity = {"SKU-123": 14, "SKU-456": 0}.get(sku, 0) + return InventoryResult(sku=sku, quantity=quantity, needs_reorder=quantity < 5) + + def format_output_for_agent(self, raw_result: object) -> str: + result = InventoryResult.model_validate(raw_result) + status = "reorder needed" if result.needs_reorder else "stock is healthy" + return f"{result.sku}: {result.quantity} units. {status}." + +tool = InventoryTool() +result = tool.run(sku="SKU-123") + +# Direct Python calls receive the raw Pydantic object. +print(result.quantity) +``` + +The override only changes what the agent sees. Direct calls to `tool.run(...)` still return the normal Python value. + ### Defining a Cache Function for the Tool To optimize tool performance with caching, define custom caching strategies using the `cache_function` attribute. diff --git a/docs/edge/en/learn/execution-hooks.mdx b/docs/edge/en/learn/execution-hooks.mdx index 74234db97..f9f667e4f 100644 --- a/docs/edge/en/learn/execution-hooks.mdx +++ b/docs/edge/en/learn/execution-hooks.mdx @@ -195,9 +195,12 @@ class ToolCallHookContext: agent: Agent | None # Agent executing task: Task | None # Current task crew: Crew | None # Crew instance - tool_result: str | None # Tool result (after hooks) + tool_result: str | None # Agent-facing result string (after hooks) + raw_tool_result: Any | None # Raw Python result (after hooks) ``` +For typed tool outputs, `tool_result` is the string the agent sees. By default, this is JSON. If the tool uses custom formatting, it can be Markdown or another string. `raw_tool_result` is the original Python value returned by the tool. + ## Common Patterns ### Safety and Validation diff --git a/docs/edge/en/learn/tool-hooks.mdx b/docs/edge/en/learn/tool-hooks.mdx index d1d727a5c..489463e81 100644 --- a/docs/edge/en/learn/tool-hooks.mdx +++ b/docs/edge/en/learn/tool-hooks.mdx @@ -60,9 +60,12 @@ class ToolCallHookContext: agent: Agent | BaseAgent | None # Agent executing the tool task: Task | None # Current task crew: Crew | None # Crew instance - tool_result: str | None # Tool result (after hooks only) + tool_result: str | None # Agent-facing result string (after hooks only) + raw_tool_result: Any | None # Raw Python result (after hooks only) ``` +For typed tool outputs, `tool_result` is the string the agent sees. By default, this is JSON. If the tool uses custom formatting, it can be Markdown or another string. Use `raw_tool_result` when your hook needs the typed object or dictionary. + ### Modifying Tool Inputs **Important:** Always modify tool inputs in-place: diff --git a/lib/crewai/src/crewai/agents/crew_agent_executor.py b/lib/crewai/src/crewai/agents/crew_agent_executor.py index 92a1ce5fb..de2315e3a 100644 --- a/lib/crewai/src/crewai/agents/crew_agent_executor.py +++ b/lib/crewai/src/crewai/agents/crew_agent_executor.py @@ -57,6 +57,7 @@ from crewai.utilities.agent_utils import ( convert_tools_to_openai_schema, enforce_rpm_limit, format_message_for_llm, + format_native_tool_output_for_agent, get_llm_response, handle_agent_action_core, handle_context_length, @@ -907,19 +908,31 @@ class CrewAgentExecutor(BaseAgentExecutor): ): max_usage_reached = True + structured_tool: CrewStructuredTool | None = None + if original_tool is not None: + for structured in self.tools or []: + if getattr(structured, "_original_tool", None) is original_tool: + structured_tool = structured + break + if structured_tool is None: + for structured in self.tools or []: + if sanitize_tool_name(structured.name) == func_name: + structured_tool = structured + break + + output_tool = original_tool or structured_tool + from_cache = False result: str = "Tool not found" + raw_tool_result: Any = result input_str = json.dumps(args_dict) if args_dict else "" - if self.tools_handler and self.tools_handler.cache: + if self.tools_handler and self.tools_handler.cache and output_tool is not None: cached_result = self.tools_handler.cache.read( tool=func_name, input=input_str ) if cached_result is not None: - result = ( - str(cached_result) - if not isinstance(cached_result, str) - else cached_result - ) + raw_tool_result = cached_result + result = format_native_tool_output_for_agent(output_tool, cached_result) from_cache = True agent_key = getattr(self.agent, "key", "unknown") if self.agent else "unknown" @@ -938,18 +951,6 @@ class CrewAgentExecutor(BaseAgentExecutor): track_delegation_if_needed(func_name, args_dict or {}, self.task) - structured_tool: CrewStructuredTool | None = None - if original_tool is not None: - for structured in self.tools or []: - if getattr(structured, "_original_tool", None) is original_tool: - structured_tool = structured - break - if structured_tool is None: - for structured in self.tools or []: - if sanitize_tool_name(structured.name) == func_name: - structured_tool = structured - break - hook_blocked = False before_hook_context = ToolCallHookContext( tool_name=func_name, @@ -975,11 +976,18 @@ class CrewAgentExecutor(BaseAgentExecutor): if hook_blocked: result = f"Tool execution blocked by hook. Tool: {func_name}" + raw_tool_result = result elif max_usage_reached and original_tool: result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore." - elif not from_cache and func_name in available_functions: + raw_tool_result = result + elif ( + not from_cache + and func_name in available_functions + and output_tool is not None + ): try: raw_result = available_functions[func_name](**(args_dict or {})) + raw_tool_result = raw_result if self.tools_handler and self.tools_handler.cache: should_cache = True @@ -996,11 +1004,10 @@ class CrewAgentExecutor(BaseAgentExecutor): tool=func_name, input=input_str, output=raw_result ) - result = ( - str(raw_result) if not isinstance(raw_result, str) else raw_result - ) + result = format_native_tool_output_for_agent(output_tool, raw_result) except Exception as e: result = f"Error executing tool: {e}" + raw_tool_result = result if self.task: self.task.increment_tools_errors() crewai_event_bus.emit( @@ -1024,6 +1031,7 @@ class CrewAgentExecutor(BaseAgentExecutor): task=self.task, crew=self.crew, tool_result=result, + raw_tool_result=raw_tool_result, ) after_hooks = get_after_tool_call_hooks() try: diff --git a/lib/crewai/src/crewai/agents/tools_handler.py b/lib/crewai/src/crewai/agents/tools_handler.py index 8ab759b85..c56226bc1 100644 --- a/lib/crewai/src/crewai/agents/tools_handler.py +++ b/lib/crewai/src/crewai/agents/tools_handler.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +from typing import Any from pydantic import BaseModel, Field @@ -25,14 +26,14 @@ class ToolsHandler(BaseModel): def on_tool_use( self, calling: ToolCalling | InstructorToolCalling, - output: str, + output: Any, should_cache: bool = True, ) -> None: """Run when tool ends running. Args: calling: The tool calling instance. - output: The output from the tool execution. + output: The raw output from the tool execution. should_cache: Whether to cache the tool output. """ self.last_used_tool = calling diff --git a/lib/crewai/src/crewai/experimental/agent_executor.py b/lib/crewai/src/crewai/experimental/agent_executor.py index c026c7509..303330dc6 100644 --- a/lib/crewai/src/crewai/experimental/agent_executor.py +++ b/lib/crewai/src/crewai/experimental/agent_executor.py @@ -80,6 +80,7 @@ from crewai.utilities.agent_utils import ( enforce_rpm_limit, extract_tool_call_info, format_message_for_llm, + format_native_tool_output_for_agent, get_llm_response, handle_agent_action_core, handle_context_length, @@ -1905,19 +1906,32 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): ): max_usage_reached = True + structured_tool: CrewStructuredTool | None = None + if original_tool is not None: + for structured in self.tools or []: + if getattr(structured, "_original_tool", None) is original_tool: + structured_tool = structured + break + if structured_tool is None: + for structured in self.tools or []: + if sanitize_tool_name(structured.name) == func_name: + structured_tool = structured + break + + output_tool = original_tool or structured_tool + # Check cache before executing from_cache = False + result = "Tool not found" + raw_tool_result: Any = result input_str = json.dumps(args_dict) if args_dict else "" - if self.tools_handler and self.tools_handler.cache: + if self.tools_handler and self.tools_handler.cache and output_tool is not None: cached_result = self.tools_handler.cache.read( tool=func_name, input=input_str ) if cached_result is not None: - result = ( - str(cached_result) - if not isinstance(cached_result, str) - else cached_result - ) + raw_tool_result = cached_result + result = format_native_tool_output_for_agent(output_tool, cached_result) from_cache = True # Emit tool usage started event @@ -1936,18 +1950,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): track_delegation_if_needed(func_name, args_dict, self.task) - structured_tool: CrewStructuredTool | None = None - if original_tool is not None: - for structured in self.tools or []: - if getattr(structured, "_original_tool", None) is original_tool: - structured_tool = structured - break - if structured_tool is None: - for structured in self.tools or []: - if sanitize_tool_name(structured.name) == func_name: - structured_tool = structured - break - hook_blocked = False before_hook_context = ToolCallHookContext( tool_name=func_name, @@ -1973,12 +1975,13 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): if hook_blocked: result = f"Tool execution blocked by hook. Tool: {func_name}" - elif not from_cache and not max_usage_reached: - result = "Tool not found" + raw_tool_result = result + elif not from_cache and not max_usage_reached and output_tool is not None: if func_name in self._available_functions: try: tool_func = self._available_functions[func_name] raw_result = tool_func(**args_dict) + raw_tool_result = raw_result # Add to cache after successful execution (before string conversion) if self.tools_handler and self.tools_handler.cache: @@ -1992,14 +1995,12 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): tool=func_name, input=input_str, output=raw_result ) - # Convert to string for message - result = ( - str(raw_result) - if not isinstance(raw_result, str) - else raw_result + result = format_native_tool_output_for_agent( + output_tool, raw_result ) except Exception as e: result = f"Error executing tool: {e}" + raw_tool_result = result if self.task: self.task.increment_tools_errors() # Emit tool usage error event @@ -2021,6 +2022,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore." else: result = f"Tool '{func_name}' has reached its maximum usage limit and cannot be used anymore." + raw_tool_result = result # Execute after_tool_call hooks (even if blocked, to allow logging/monitoring) after_hook_context = ToolCallHookContext( @@ -2031,6 +2033,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): task=self.task, crew=self.crew, tool_result=result, + raw_tool_result=raw_tool_result, ) after_hooks = get_after_tool_call_hooks() try: diff --git a/lib/crewai/src/crewai/hooks/tool_hooks.py b/lib/crewai/src/crewai/hooks/tool_hooks.py index b54ca7b10..a860c01d4 100644 --- a/lib/crewai/src/crewai/hooks/tool_hooks.py +++ b/lib/crewai/src/crewai/hooks/tool_hooks.py @@ -40,6 +40,8 @@ class ToolCallHookContext: crew: Crew instance (may be None) tool_result: Tool execution result (only set for after_tool_call hooks). Can be modified by returning a new string from after_tool_call hook. + raw_tool_result: Raw Python tool execution result (only set for + after_tool_call hooks). This is not modified by after hooks. """ def __init__( @@ -51,6 +53,7 @@ class ToolCallHookContext: task: Task | None = None, crew: Crew | None = None, tool_result: str | None = None, + raw_tool_result: Any | None = None, ) -> None: """Initialize tool call hook context. @@ -62,6 +65,7 @@ class ToolCallHookContext: task: Optional current task crew: Optional crew instance tool_result: Optional tool result (for after hooks) + raw_tool_result: Optional raw tool result (for after hooks) """ self.tool_name = tool_name self.tool_input = tool_input @@ -70,6 +74,7 @@ class ToolCallHookContext: self.task = task self.crew = crew self.tool_result = tool_result + self.raw_tool_result = raw_tool_result def request_human_input( self, diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py index 31c5009bd..c6c3dba15 100644 --- a/lib/crewai/src/crewai/tools/base_tool.py +++ b/lib/crewai/src/crewai/tools/base_tool.py @@ -33,6 +33,8 @@ from typing_extensions import TypeIs from crewai.tools.structured_tool import ( CrewStructuredTool, _deserialize_schema, + _format_tool_output_for_agent, + _infer_result_schema_from_callable, _serialize_schema, build_schema_hint, ) @@ -149,6 +151,11 @@ class BaseTool(BaseModel, ABC): validate_default=True, description="The schema for the arguments that the tool accepts.", ) + result_schema: type[PydanticBaseModel] | None = Field( + default=None, + validate_default=True, + description="The schema for the output that the tool returns.", + ) @field_serializer("args_schema", when_used="json") def _serialize_args_schema( @@ -156,6 +163,12 @@ class BaseTool(BaseModel, ABC): ) -> dict[str, Any] | None: return _serialize_schema(schema) + @field_serializer("result_schema", when_used="json") + def _serialize_result_schema( + self, schema: type[PydanticBaseModel] | None + ) -> dict[str, Any] | None: + return _serialize_schema(schema) + description_updated: bool = Field( default=False, description="Flag to check if the description has been updated." ) @@ -233,6 +246,17 @@ class BaseTool(BaseModel, ABC): return create_model(f"{cls.__name__}Schema", **fields) + @field_validator("result_schema", mode="before") + @classmethod + def _default_result_schema( + cls, v: type[PydanticBaseModel] | dict[str, Any] | None + ) -> type[PydanticBaseModel] | None: + if isinstance(v, dict): + return _deserialize_schema(v) + if v is not None: + return v + return _infer_result_schema_from_callable(cls._run) + @field_validator("max_usage_count", mode="before") @classmethod def validate_max_usage_count(cls, v: int | None) -> int | None: @@ -340,6 +364,10 @@ class BaseTool(BaseModel, ABC): "Override _arun for async support or use run() for sync execution." ) + def format_output_for_agent(self, raw_result: Any) -> str: + """Format a raw tool result into the string representation sent to an agent.""" + return _format_tool_output_for_agent(self, raw_result) + def reset_usage_count(self) -> None: """Reset the current usage count to zero.""" self.current_usage_count = 0 @@ -369,6 +397,7 @@ class BaseTool(BaseModel, ABC): name=self.name, description=self.description, args_schema=self.args_schema, + result_schema=self.result_schema, func=self._run, result_as_answer=self.result_as_answer, max_usage_count=self.max_usage_count, @@ -390,6 +419,9 @@ class BaseTool(BaseModel, ABC): raise ValueError("The provided tool must have a callable 'func' attribute.") args_schema = getattr(tool, "args_schema", None) + result_schema = getattr(tool, "result_schema", None) + if result_schema is None: + result_schema = _infer_result_schema_from_callable(tool.func) if args_schema is None: func_signature = signature(tool.func) @@ -420,6 +452,7 @@ class BaseTool(BaseModel, ABC): description=getattr(tool, "description", ""), func=tool.func, args_schema=args_schema, + result_schema=result_schema, ) def _set_args_schema(self) -> None: @@ -568,6 +601,9 @@ class Tool(BaseTool, Generic[P, R]): raise ValueError("The provided tool must have a callable 'func' attribute.") args_schema = getattr(tool, "args_schema", None) + result_schema = getattr(tool, "result_schema", None) + if result_schema is None: + result_schema = _infer_result_schema_from_callable(tool.func) if args_schema is None: func_signature = signature(tool.func) @@ -598,6 +634,7 @@ class Tool(BaseTool, Generic[P, R]): description=getattr(tool, "description", ""), func=tool.func, args_schema=args_schema, + result_schema=result_schema, ) @@ -621,6 +658,7 @@ def tool( name: str, /, *, + result_schema: type[BaseModel] | None = ..., result_as_answer: bool = ..., max_usage_count: int | None = ..., ) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ... @@ -629,6 +667,7 @@ def tool( @overload def tool( *, + result_schema: type[BaseModel] | None = ..., result_as_answer: bool = ..., max_usage_count: int | None = ..., ) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ... @@ -636,6 +675,7 @@ def tool( def tool( *args: Callable[P2, R2] | str, + result_schema: type[BaseModel] | None = None, result_as_answer: bool = False, max_usage_count: int | None = None, ) -> Tool[P2, R2] | Callable[[Callable[P2, R2]], Tool[P2, R2]]: @@ -649,6 +689,7 @@ def tool( Args: *args: Either the function to decorate or a custom tool name. result_as_answer: If True, the tool result becomes the final agent answer. + result_schema: Optional schema for the output that the tool returns. max_usage_count: Maximum times this tool can be used. None means unlimited. Returns: @@ -690,12 +731,16 @@ def tool( class_name = "".join(tool_name.split()).title() args_schema = create_model(class_name, **fields) + resolved_result_schema = ( + result_schema or _infer_result_schema_from_callable(f) + ) return Tool( name=tool_name, description=f.__doc__, func=f, args_schema=args_schema, + result_schema=resolved_result_schema, result_as_answer=result_as_answer, max_usage_count=max_usage_count, current_usage_count=0, diff --git a/lib/crewai/src/crewai/tools/structured_tool.py b/lib/crewai/src/crewai/tools/structured_tool.py index 6c24f52dc..8ecba8549 100644 --- a/lib/crewai/src/crewai/tools/structured_tool.py +++ b/lib/crewai/src/crewai/tools/structured_tool.py @@ -5,7 +5,8 @@ from collections.abc import Callable import inspect import json import textwrap -from typing import TYPE_CHECKING, Annotated, Any, get_type_hints +from typing import TYPE_CHECKING, Annotated, Any, cast, get_type_hints +import warnings from pydantic import ( BaseModel, @@ -36,6 +37,52 @@ def _deserialize_schema(v: Any) -> type[BaseModel] | None: return None +def _infer_result_schema_from_callable( + func: Callable[..., Any], +) -> type[BaseModel] | None: + try: + return_annotation = get_type_hints(func).get("return", inspect.Signature.empty) + except Exception: + return_annotation = inspect.signature(func).return_annotation + + if isinstance(return_annotation, type) and issubclass(return_annotation, BaseModel): + return return_annotation + + return None + + +def _format_tool_output_for_agent(tool: Any, raw_result: Any) -> str: + original_tool = getattr(tool, "_original_tool", None) + if original_tool is not None: + return cast(str, original_tool.format_output_for_agent(raw_result)) + + result_schema = getattr(tool, "result_schema", None) + if not (isinstance(result_schema, type) and issubclass(result_schema, BaseModel)): + return str(raw_result) + + try: + validation_input = raw_result + if isinstance(raw_result, BaseModel) and not isinstance( + raw_result, result_schema + ): + validation_input = raw_result.model_dump() + + validated = result_schema.model_validate(validation_input) + return validated.model_dump_json() + except Exception as exc: + warnings.warn( + ( + f"Failed to validate or serialize output from tool " + f"'{getattr(tool, 'name', '')}' using result_schema " + f"'{result_schema.__name__}': {exc.__class__.__name__}. " + "Falling back to str(raw_result)." + ), + RuntimeWarning, + stacklevel=2, + ) + return str(raw_result) + + if TYPE_CHECKING: pass @@ -81,6 +128,11 @@ class CrewStructuredTool(BaseModel): BeforeValidator(_deserialize_schema), PlainSerializer(_serialize_schema), ] = Field(default=None) + result_schema: Annotated[ + type[BaseModel] | None, + BeforeValidator(_deserialize_schema), + PlainSerializer(_serialize_schema), + ] = Field(default=None) func: Any = Field(default=None, exclude=True) result_as_answer: bool = Field(default=False) max_usage_count: int | None = Field(default=None) @@ -103,6 +155,7 @@ class CrewStructuredTool(BaseModel): description: str | None = None, return_direct: bool = False, args_schema: type[BaseModel] | None = None, + result_schema: type[BaseModel] | None = None, infer_schema: bool = True, **kwargs: Any, ) -> CrewStructuredTool: @@ -114,6 +167,7 @@ class CrewStructuredTool(BaseModel): description: The description of the tool. Defaults to the function docstring return_direct: Whether to return the output directly args_schema: Optional schema for the function arguments + result_schema: Optional schema for the function output infer_schema: Whether to infer the schema from the function signature **kwargs: Additional arguments to pass to the tool @@ -149,10 +203,16 @@ class CrewStructuredTool(BaseModel): name=name, description=description, args_schema=schema, + result_schema=result_schema or _infer_result_schema_from_callable(func), func=func, result_as_answer=return_direct, + **kwargs, ) + def format_output_for_agent(self, raw_result: Any) -> str: + """Format a raw tool result into the string representation sent to an agent.""" + return _format_tool_output_for_agent(self, raw_result) + @staticmethod def _create_schema_from_function( name: str, diff --git a/lib/crewai/src/crewai/tools/tool_usage.py b/lib/crewai/src/crewai/tools/tool_usage.py index b34921839..e92ba03ee 100644 --- a/lib/crewai/src/crewai/tools/tool_usage.py +++ b/lib/crewai/src/crewai/tools/tool_usage.py @@ -62,6 +62,9 @@ OPENAI_BIGGER_MODELS: list[ ] +_RAW_RESULT_UNSET = object() + + class ToolUsageError(Exception): """Exception raised for errors in the tool usage.""" @@ -106,6 +109,7 @@ class ToolUsage: self.action = action self.function_calling_llm = function_calling_llm self.fingerprint_context = fingerprint_context or {} + self.last_raw_result: Any = _RAW_RESULT_UNSET if ( self.function_calling_llm @@ -120,6 +124,11 @@ class ToolUsage: """Parse the tool string and return the tool calling.""" return self._tool_calling(tool_string) + def get_last_raw_result(self, fallback: Any) -> Any: + if self.last_raw_result is _RAW_RESULT_UNSET: + return fallback + return self.last_raw_result + def use( self, calling: ToolCalling | InstructorToolCalling, tool_string: str ) -> str: @@ -231,6 +240,7 @@ class ToolUsage: result = I18N_DEFAULT.errors("task_repeated_usage").format( tool_names=self.tools_names ) + self.last_raw_result = result self._telemetry.tool_repeated_usage( llm=self.function_calling_llm, tool_name=sanitize_tool_name(tool.name), @@ -298,6 +308,7 @@ class ToolUsage: ) if usage_limit_error: result = usage_limit_error + self.last_raw_result = result self._telemetry.tool_usage_error(llm=self.function_calling_llm) result = self._format_result(result=result) elif result is None: @@ -359,7 +370,10 @@ class ToolUsage: tool_name=sanitize_tool_name(tool.name), attempts=self._run_attempts, ) - result = self._format_result(result=result) + self.last_raw_result = result + result = self._format_result( + result=tool.format_output_for_agent(result) + ) data = { "result": result, "tool_name": sanitize_tool_name(tool.name), @@ -421,6 +435,7 @@ class ToolUsage: result = ToolUsageError( f"\n{error_message}.\nMoving on then. {I18N_DEFAULT.slice('format').format(tool_names=self.tools_names)}" ).message + self.last_raw_result = result if self.task: self.task.increment_tools_errors() if self.agent and self.agent.verbose: @@ -430,7 +445,10 @@ class ToolUsage: self.task.increment_tools_errors() should_retry = True else: - result = self._format_result(result=result) + self.last_raw_result = result + result = self._format_result( + result=tool.format_output_for_agent(result) + ) finally: if started_event_emitted and not error_event_emitted: @@ -460,6 +478,7 @@ class ToolUsage: result = I18N_DEFAULT.errors("task_repeated_usage").format( tool_names=self.tools_names ) + self.last_raw_result = result self._telemetry.tool_repeated_usage( llm=self.function_calling_llm, tool_name=sanitize_tool_name(tool.name), @@ -529,6 +548,7 @@ class ToolUsage: ) if usage_limit_error: result = usage_limit_error + self.last_raw_result = result self._telemetry.tool_usage_error(llm=self.function_calling_llm) result = self._format_result(result=result) elif result is None: @@ -590,7 +610,10 @@ class ToolUsage: tool_name=sanitize_tool_name(tool.name), attempts=self._run_attempts, ) - result = self._format_result(result=result) + self.last_raw_result = result + result = self._format_result( + result=tool.format_output_for_agent(result) + ) data = { "result": result, "tool_name": sanitize_tool_name(tool.name), @@ -652,6 +675,7 @@ class ToolUsage: result = ToolUsageError( f"\n{error_message}.\nMoving on then. {I18N_DEFAULT.slice('format').format(tool_names=self.tools_names)}" ).message + self.last_raw_result = result if self.task: self.task.increment_tools_errors() if self.agent and self.agent.verbose: @@ -661,7 +685,10 @@ class ToolUsage: self.task.increment_tools_errors() should_retry = True else: - result = self._format_result(result=result) + self.last_raw_result = result + result = self._format_result( + result=tool.format_output_for_agent(result) + ) finally: if started_event_emitted and not error_event_emitted: diff --git a/lib/crewai/src/crewai/utilities/agent_utils.py b/lib/crewai/src/crewai/utilities/agent_utils.py index 80f8ab242..e933a38a8 100644 --- a/lib/crewai/src/crewai/utilities/agent_utils.py +++ b/lib/crewai/src/crewai/utilities/agent_utils.py @@ -1383,6 +1383,19 @@ class NativeToolCallResult: tool_message: LLMMessage = field(default_factory=dict) # type: ignore[assignment] +def format_native_tool_output_for_agent(tool: Any, raw_result: Any) -> str: + """Format native tool output when a tool explicitly defines a formatter.""" + formatter = inspect.getattr_static(tool, "format_output_for_agent", None) + if formatter is None: + return str(raw_result) + + runtime_formatter = getattr(tool, "format_output_for_agent", None) + if not callable(runtime_formatter): + return str(raw_result) + + return str(runtime_formatter(raw_result)) + + def execute_single_native_tool_call( tool_call: Any, *, @@ -1456,18 +1469,24 @@ def execute_single_native_tool_call( original_tool = tool break + structured_tool: CrewStructuredTool | None = None + for structured in structured_tools or []: + if sanitize_tool_name(structured.name) == func_name: + structured_tool = structured + break + + output_tool = original_tool or structured_tool + from_cache = False input_str = json.dumps(args_dict) if args_dict else "" result = "Tool not found" + raw_tool_result: Any = result - if tools_handler and tools_handler.cache: + if tools_handler and tools_handler.cache and output_tool is not None: cached_result = tools_handler.cache.read(tool=func_name, input=input_str) if cached_result is not None: - result = ( - str(cached_result) - if not isinstance(cached_result, str) - else cached_result - ) + raw_tool_result = cached_result + result = format_native_tool_output_for_agent(output_tool, cached_result) from_cache = True started_at = datetime.now() @@ -1486,12 +1505,6 @@ def execute_single_native_tool_call( track_delegation_if_needed(func_name, args_dict, task) - structured_tool: CrewStructuredTool | None = None - for structured in structured_tools or []: - if sanitize_tool_name(structured.name) == func_name: - structured_tool = structured - break - hook_blocked = False before_hook_context = ToolCallHookContext( tool_name=func_name, @@ -1512,11 +1525,13 @@ def execute_single_native_tool_call( error_event_emitted = False if hook_blocked: result = f"Tool execution blocked by hook. Tool: {func_name}" + raw_tool_result = result elif not from_cache: - if func_name in available_functions: + if func_name in available_functions and output_tool is not None: try: tool_func = available_functions[func_name] raw_result = tool_func(**args_dict) + raw_tool_result = raw_result if tools_handler and tools_handler.cache: should_cache = True @@ -1529,11 +1544,10 @@ def execute_single_native_tool_call( tool=func_name, input=input_str, output=raw_result ) - result = ( - str(raw_result) if not isinstance(raw_result, str) else raw_result - ) + result = format_native_tool_output_for_agent(output_tool, raw_result) except Exception as e: result = f"Error executing tool: {e}" + raw_tool_result = result if task: task.increment_tools_errors() crewai_event_bus.emit( @@ -1559,6 +1573,7 @@ def execute_single_native_tool_call( task=task, crew=crew, tool_result=result, + raw_tool_result=raw_tool_result, ) try: for after_hook in get_after_tool_call_hooks(): diff --git a/lib/crewai/src/crewai/utilities/tool_utils.py b/lib/crewai/src/crewai/utilities/tool_utils.py index e19c3c81a..dcb25594c 100644 --- a/lib/crewai/src/crewai/utilities/tool_utils.py +++ b/lib/crewai/src/crewai/utilities/tool_utils.py @@ -116,6 +116,7 @@ async def aexecute_tool_and_check_finality( logger.log("error", f"Error in before_tool_call hook: {e}") tool_result = await tool_usage.ause(tool_calling, agent_action.text) + raw_tool_result = tool_usage.get_last_raw_result(tool_result) after_hook_context = ToolCallHookContext( tool_name=sanitized_tool_name, @@ -125,6 +126,7 @@ async def aexecute_tool_and_check_finality( task=task, crew=crew, tool_result=tool_result, + raw_tool_result=raw_tool_result, ) after_hooks = get_after_tool_call_hooks() @@ -234,6 +236,7 @@ def execute_tool_and_check_finality( logger.log("error", f"Error in before_tool_call hook: {e}") tool_result = tool_usage.use(tool_calling, agent_action.text) + raw_tool_result = tool_usage.get_last_raw_result(tool_result) after_hook_context = ToolCallHookContext( tool_name=sanitized_tool_name, @@ -243,6 +246,7 @@ def execute_tool_and_check_finality( task=task, crew=crew, tool_result=tool_result, + raw_tool_result=raw_tool_result, ) after_hooks = get_after_tool_call_hooks() diff --git a/lib/crewai/tests/agents/test_native_tool_calling.py b/lib/crewai/tests/agents/test_native_tool_calling.py index b7e0df199..894c0bd45 100644 --- a/lib/crewai/tests/agents/test_native_tool_calling.py +++ b/lib/crewai/tests/agents/test_native_tool_calling.py @@ -7,6 +7,7 @@ when the LLM supports it, across multiple providers. from __future__ import annotations from collections.abc import Generator +import json import os import threading import time @@ -20,7 +21,7 @@ from crewai import Agent, Crew, Task from crewai.agents.parser import AgentFinish from crewai.events import crewai_event_bus from crewai.hooks import register_after_tool_call_hook, register_before_tool_call_hook -from crewai.hooks.tool_hooks import ToolCallHookContext +from crewai.hooks.tool_hooks import ToolCallHookContext, clear_after_tool_call_hooks from crewai.llm import LLM from crewai.tools.base_tool import BaseTool @@ -1197,6 +1198,76 @@ class TestNativeToolCallingJsonParseError: assert result["result"] == "ran: print(1)" + def test_typed_output_is_json_agent_text(self) -> None: + class SearchOutput(BaseModel): + query: str + score: float + + class TypedSearchTool(BaseTool): + name: str = "typed_search" + description: str = "Search for information" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.8) + + tool = TypedSearchTool() + executor = self._make_executor([tool]) + + from crewai.utilities.agent_utils import convert_tools_to_openai_schema + + _, available_functions, _ = convert_tools_to_openai_schema([tool]) + + result = executor._execute_single_native_tool_call( + call_id="call_typed", + func_name="typed_search", + func_args='{"query": "crew"}', + available_functions=available_functions, + ) + + assert json.loads(result["result"]) == {"query": "crew", "score": 0.8} + + def test_typed_output_after_hook_includes_raw_tool_result(self) -> None: + from crewai.utilities.agent_utils import convert_tools_to_openai_schema + + class SearchOutput(BaseModel): + query: str + score: float + + class TypedSearchTool(BaseTool): + name: str = "typed_search" + description: str = "Search for information" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.8) + + seen_results: list[tuple[str | None, object]] = [] + + def after_hook(context: ToolCallHookContext) -> None: + seen_results.append((context.tool_result, context.raw_tool_result)) + + tool = TypedSearchTool() + executor = self._make_executor([tool]) + _, available_functions, _ = convert_tools_to_openai_schema([tool]) + + clear_after_tool_call_hooks() + register_after_tool_call_hook(after_hook) + try: + result = executor._execute_single_native_tool_call( + call_id="call_typed", + func_name="typed_search", + func_args='{"query": "crew"}', + available_functions=available_functions, + ) + finally: + clear_after_tool_call_hooks() + + assert json.loads(result["result"]) == {"query": "crew", "score": 0.8} + assert seen_results == [ + ('{"query":"crew","score":0.8}', SearchOutput(query="crew", score=0.8)) + ] + def test_native_tool_loop_falls_back_when_provider_rejects_tools(self) -> None: """Unsupported native tools errors should continue through ReAct.""" diff --git a/lib/crewai/tests/hooks/test_tool_hooks.py b/lib/crewai/tests/hooks/test_tool_hooks.py index 347eb56e5..14c6848a8 100644 --- a/lib/crewai/tests/hooks/test_tool_hooks.py +++ b/lib/crewai/tests/hooks/test_tool_hooks.py @@ -91,20 +91,24 @@ class TestToolCallHookContext: assert context.task == mock_task assert context.crew == mock_crew assert context.tool_result is None + assert context.raw_tool_result is None def test_context_with_result(self, mock_tool): """Test that context includes result when provided.""" tool_input = {"arg1": "value1"} tool_result = "Test tool result" + raw_tool_result = {"value": 42} context = ToolCallHookContext( tool_name="test_tool", tool_input=tool_input, tool=mock_tool, tool_result=tool_result, + raw_tool_result=raw_tool_result, ) assert context.tool_result == tool_result + assert context.raw_tool_result == raw_tool_result def test_tool_input_is_mutable_reference(self, mock_tool): """Test that modifying context.tool_input modifies the original dict.""" diff --git a/lib/crewai/tests/tools/test_base_tool.py b/lib/crewai/tests/tools/test_base_tool.py index 7648ad73b..52661fffc 100644 --- a/lib/crewai/tests/tools/test_base_tool.py +++ b/lib/crewai/tests/tools/test_base_tool.py @@ -1,12 +1,13 @@ import asyncio from collections.abc import Callable +import json from unittest.mock import patch from crewai.agent import Agent from crewai.crew import Crew from crewai.task import Task from crewai.tools import BaseTool, tool -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, RootModel import pytest @@ -351,6 +352,262 @@ class TestToolDecoratorRunValidation: assert result == "Hello, World!" +class SearchOutput(BaseModel): + query: str + score: float + + +class SearchResults(RootModel[list[SearchOutput]]): + pass + + +class ExplicitSearchTool(BaseTool): + name: str = "search" + description: str = "Search for a query" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> dict[str, object]: + return {"query": query, "score": 0.8} + + +class InferredSearchTool(BaseTool): + name: str = "search" + description: str = "Search for a query" + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.7) + + +class RootSearchTool(BaseTool): + name: str = "search" + description: str = "Search for a query" + + def _run(self, query: str) -> SearchResults: + return SearchResults([SearchOutput(query=query, score=1.0)]) + + +class DictAnnotatedSearchTool(BaseTool): + name: str = "search" + description: str = "Search for a query" + + def _run(self, query: str) -> dict[str, object]: + return {"query": query, "score": 0.5} + + +def _make_explicit_decorator_tool() -> BaseTool: + @tool("search", result_schema=SearchOutput) + def search(query: str) -> dict[str, object]: + """Search for a query.""" + return {"query": query, "score": 0.8} + + return search + + +def _make_inferred_decorator_tool() -> BaseTool: + @tool("search") + def search(query: str) -> SearchOutput: + """Search for a query.""" + return SearchOutput(query=query, score=0.6) + + return search + + +def _make_root_decorator_tool() -> BaseTool: + @tool("search") + def search(query: str) -> SearchResults: + """Search for a query.""" + return SearchResults([SearchOutput(query=query, score=1.0)]) + + return search + + +class TestToolOutputSchema: + @pytest.mark.parametrize( + ("tool_cls", "expected_raw", "expected_agent_payload"), + [ + pytest.param( + ExplicitSearchTool, + {"query": "crew", "score": 0.8}, + {"query": "crew", "score": 0.8}, + id="explicit-schema", + ), + pytest.param( + InferredSearchTool, + SearchOutput(query="crew", score=0.7), + {"query": "crew", "score": 0.7}, + id="inferred-base-model", + ), + pytest.param( + RootSearchTool, + SearchResults([SearchOutput(query="crew", score=1.0)]), + [{"query": "crew", "score": 1.0}], + id="inferred-root-model", + ), + ], + ) + def test_base_tools_return_raw_result_and_json_agent_text( + self, + tool_cls: type[BaseTool], + expected_raw: object, + expected_agent_payload: object, + ) -> None: + t = tool_cls() + + raw_result = t.run(query="crew") + + assert raw_result == expected_raw + assert json.loads(t.format_output_for_agent(raw_result)) == ( + expected_agent_payload + ) + + def test_base_tool_does_not_infer_non_pydantic_return_annotation(self) -> None: + t = DictAnnotatedSearchTool() + + raw_result = t.run(query="crew") + + assert raw_result == {"query": "crew", "score": 0.5} + assert t.format_output_for_agent(raw_result) == str(raw_result) + + @pytest.mark.parametrize( + ("make_tool", "expected_raw", "expected_agent_payload"), + [ + pytest.param( + _make_explicit_decorator_tool, + {"query": "crew", "score": 0.8}, + {"query": "crew", "score": 0.8}, + id="explicit-schema", + ), + pytest.param( + _make_inferred_decorator_tool, + SearchOutput(query="crew", score=0.6), + {"query": "crew", "score": 0.6}, + id="inferred-base-model", + ), + pytest.param( + _make_root_decorator_tool, + SearchResults([SearchOutput(query="crew", score=1.0)]), + [{"query": "crew", "score": 1.0}], + id="inferred-root-model", + ), + ], + ) + def test_decorator_tools_return_raw_result_and_json_agent_text( + self, + make_tool: Callable[[], BaseTool], + expected_raw: object, + expected_agent_payload: object, + ) -> None: + search = make_tool() + + raw_result = search.run(query="crew") + + assert raw_result == expected_raw + assert json.loads(search.format_output_for_agent(raw_result)) == ( + expected_agent_payload + ) + + def test_decorator_tool_does_not_infer_non_pydantic_return_annotation( + self, + ) -> None: + @tool("search") + def search(query: str) -> dict[str, object]: + """Search for a query.""" + return {"query": query, "score": 0.5} + + raw_result = search.run(query="crew") + + assert raw_result == {"query": "crew", "score": 0.5} + assert search.format_output_for_agent(raw_result) == str(raw_result) + + def test_explicit_result_schema_wins_over_return_annotation(self) -> None: + class AlternateOutput(BaseModel): + value: str + + @tool("search", result_schema=AlternateOutput) + def search(query: str) -> SearchOutput: + """Search for a query.""" + return SearchOutput(query=query, score=0.6) + + raw_result = search.run(query="crew") + + with pytest.warns(RuntimeWarning, match="AlternateOutput"): + agent_text = search.format_output_for_agent(raw_result) + + assert raw_result == SearchOutput(query="crew", score=0.6) + assert agent_text == str(raw_result) + + def test_invalid_typed_output_warns_and_uses_string_agent_text( + self, + ) -> None: + @tool("search", result_schema=SearchOutput) + def search(query: str) -> dict[str, object]: + """Search for a query.""" + return {"query": query, "score": "not-a-float"} + + raw_result = search.run(query="crew") + + with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"): + agent_text = search.format_output_for_agent(raw_result) + + assert raw_result == {"query": "crew", "score": "not-a-float"} + assert agent_text == str(raw_result) + + def test_unserializable_typed_output_warns_and_uses_string_agent_text( + self, + ) -> None: + class OpaqueOutput(BaseModel): + value: object + + raw_result = OpaqueOutput(value=object()) + + @tool("opaque", result_schema=OpaqueOutput) + def opaque() -> OpaqueOutput: + """Return an opaque object.""" + return raw_result + + result = opaque.run() + + with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"): + agent_text = opaque.format_output_for_agent(result) + + assert result is raw_result + assert agent_text == str(raw_result) + + def test_result_schema_behavior_carries_over_to_structured_tool(self) -> None: + structured = ExplicitSearchTool().to_structured_tool() + + raw_result = structured.invoke({"query": "crew"}) + + assert raw_result == {"query": "crew", "score": 0.8} + assert json.loads(structured.format_output_for_agent(raw_result)) == { + "query": "crew", + "score": 0.8, + } + + def test_custom_agent_output_formatter_carries_over_to_structured_tool( + self, + ) -> None: + class MarkdownSearchTool(BaseTool): + name: str = "markdown_search" + description: str = "Search for information" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.8) + + def format_output_for_agent(self, raw_result: object) -> str: + result = self.result_schema.model_validate(raw_result) + return f"### Search result\n\n- Query: `{result.query}`\n- Score: {result.score}" + + structured = MarkdownSearchTool().to_structured_tool() + + raw_result = structured.invoke({"query": "crew"}) + + assert raw_result == SearchOutput(query="crew", score=0.8) + assert structured.format_output_for_agent(raw_result) == ( + "### Search result\n\n- Query: `crew`\n- Score: 0.8" + ) + # Async arun() Schema Validation Tests diff --git a/lib/crewai/tests/tools/test_structured_tool.py b/lib/crewai/tests/tools/test_structured_tool.py index 27c463d47..0241abbcf 100644 --- a/lib/crewai/tests/tools/test_structured_tool.py +++ b/lib/crewai/tests/tools/test_structured_tool.py @@ -1,5 +1,7 @@ +import json + from crewai.tools.structured_tool import CrewStructuredTool -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, RootModel import pytest @@ -86,6 +88,118 @@ def test_from_function(basic_function): assert isinstance(tool.args_schema, type(BaseModel)) +class StructuredOutput(BaseModel): + value: str + count: int + + +class StructuredOutputList(RootModel[list[StructuredOutput]]): + pass + + +def _build_explicit_structured_value(value: str) -> dict[str, object]: + """Build a value.""" + return {"value": value, "count": 1} + + +def _build_inferred_structured_value(value: str) -> StructuredOutput: + """Build a value.""" + return StructuredOutput(value=value, count=1) + + +def _build_structured_values(value: str) -> StructuredOutputList: + """Build values.""" + return StructuredOutputList([StructuredOutput(value=value, count=1)]) + + +def _build_plain_structured_value(value: str) -> dict[str, object]: + """Build a value.""" + return {"value": value, "count": 1} + + +@pytest.mark.parametrize( + ("func", "result_schema", "expected_raw", "expected_agent_payload"), + [ + pytest.param( + _build_explicit_structured_value, + StructuredOutput, + {"value": "crew", "count": 1}, + {"value": "crew", "count": 1}, + id="explicit-schema", + ), + pytest.param( + _build_inferred_structured_value, + None, + StructuredOutput(value="crew", count=1), + {"value": "crew", "count": 1}, + id="inferred-base-model", + ), + pytest.param( + _build_structured_values, + None, + StructuredOutputList([StructuredOutput(value="crew", count=1)]), + [{"value": "crew", "count": 1}], + id="inferred-root-model", + ), + ], +) +def test_from_function_returns_raw_result_and_json_agent_text( + func, + result_schema, + expected_raw, + expected_agent_payload, +): + kwargs = {"result_schema": result_schema} if result_schema is not None else {} + tool = CrewStructuredTool.from_function( + func=func, + name="build_value", + **kwargs, + ) + + raw_result = tool.invoke({"value": "crew"}) + + assert raw_result == expected_raw + assert json.loads(tool.format_output_for_agent(raw_result)) == ( + expected_agent_payload + ) + + +def test_from_function_does_not_infer_non_pydantic_result_schema(): + tool = CrewStructuredTool.from_function( + func=_build_plain_structured_value, + name="build_value", + ) + + raw_result = tool.invoke({"value": "crew"}) + + assert raw_result == {"value": "crew", "count": 1} + assert tool.format_output_for_agent(raw_result) == str(raw_result) + + +def test_invalid_typed_output_warns_and_uses_string_agent_text(): + def build_value(value: str) -> dict[str, object]: + """Build a value.""" + return {"value": value, "count": "wrong"} + + tool = CrewStructuredTool.from_function( + func=build_value, + name="build_value", + result_schema=StructuredOutput, + ) + raw_result = tool.invoke({"value": "crew"}) + + with pytest.warns( + RuntimeWarning, match="Failed to validate or serialize" + ) as warnings: + agent_text = tool.format_output_for_agent(raw_result) + + assert raw_result == {"value": "crew", "count": "wrong"} + assert agent_text == str(raw_result) + warning_message = str(warnings[0].message) + assert "ValidationError" in warning_message + assert "wrong" not in warning_message + + def test_validate_function_signature(basic_function, schema_class): """Test function signature validation""" tool = CrewStructuredTool( diff --git a/lib/crewai/tests/tools/test_tool_usage.py b/lib/crewai/tests/tools/test_tool_usage.py index ba4fe72dd..9d61c93a9 100644 --- a/lib/crewai/tests/tools/test_tool_usage.py +++ b/lib/crewai/tests/tools/test_tool_usage.py @@ -1,4 +1,5 @@ import datetime +from collections.abc import Callable import json import random import threading @@ -6,6 +7,9 @@ import time from unittest.mock import MagicMock, patch from crewai import Agent, Task +from crewai.agents.cache.cache_handler import CacheHandler +from crewai.agents.parser import AgentAction +from crewai.agents.tools_handler import ToolsHandler from crewai.events.event_bus import crewai_event_bus from crewai.events.types.tool_usage_events import ( ToolSelectionErrorEvent, @@ -14,8 +18,15 @@ from crewai.events.types.tool_usage_events import ( ToolUsageStartedEvent, ToolValidateInputErrorEvent, ) +from crewai.hooks.tool_hooks import ( + ToolCallHookContext, + clear_after_tool_call_hooks, + register_after_tool_call_hook, +) from crewai.tools import BaseTool +from crewai.tools.tool_calling import ToolCalling from crewai.tools.tool_usage import ToolUsage +from crewai.utilities.tool_utils import execute_tool_and_check_finality from pydantic import BaseModel, Field import pytest @@ -38,6 +49,19 @@ class RandomNumberTool(BaseTool): return random.randint(min_value, max_value) # noqa: S311 +class SearchOutput(BaseModel): + query: str + score: float + + +class TypedSearchTool(BaseTool): + name: str = "typed_search" + description: str = "Search for a query" + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.7) + + # Example agent and task example_agent = Agent( role="Number Generator", @@ -117,6 +141,126 @@ def test_tool_usage_render(): assert '"description": "The maximum value of the range (inclusive)"' in rendered +def test_tool_usage_returns_json_agent_text_for_typed_output(): + tool = TypedSearchTool().to_structured_tool() + tool_usage = ToolUsage( + tools_handler=None, + tools=[tool], + task=None, + function_calling_llm=MagicMock(), + agent=None, + action=MagicMock(), + ) + + result = tool_usage.use( + calling=ToolCalling( + tool_name="typed_search", + arguments={"query": "crew"}, + ), + tool_string='Action: typed_search\nAction Input: {"query": "crew"}', + ) + + assert json.loads(result) == {"query": "crew", "score": 0.7} + + +def test_tool_usage_cache_callback_receives_raw_typed_output(): + raw_results: list[object] = [] + + def cache_result(_args: object, result: object) -> bool: + raw_results.append(result) + return True + + class CacheAwareTypedSearchTool(TypedSearchTool): + cache_function: Callable = cache_result + + tools_handler = MagicMock() + tools_handler.cache = None + tools_handler.last_used_tool = None + tool = CacheAwareTypedSearchTool().to_structured_tool() + tool_usage = ToolUsage( + tools_handler=tools_handler, + tools=[tool], + task=None, + function_calling_llm=MagicMock(), + agent=None, + action=MagicMock(), + ) + + result = tool_usage.use( + calling=ToolCalling( + tool_name="typed_search", + arguments={"query": "crew"}, + ), + tool_string='Action: typed_search\nAction Input: {"query": "crew"}', + ) + + assert json.loads(result) == {"query": "crew", "score": 0.7} + assert raw_results == [SearchOutput(query="crew", score=0.7)] + tools_handler.on_tool_use.assert_called_once() + assert tools_handler.on_tool_use.call_args.kwargs["output"] == SearchOutput( + query="crew", + score=0.7, + ) + + +def test_react_tool_hooks_receive_agent_text_and_raw_cached_typed_output(): + structured_tool = TypedSearchTool().to_structured_tool() + tools_handler = ToolsHandler(cache=CacheHandler()) + seen_results: list[tuple[str | None, object]] = [] + + def after_hook(context: ToolCallHookContext) -> None: + seen_results.append((context.tool_result, context.raw_tool_result)) + + clear_after_tool_call_hooks() + register_after_tool_call_hook(after_hook) + + action = AgentAction( + thought="", + tool="typed_search", + tool_input='{"query": "crew"}', + text='Action: typed_search\nAction Input: {"query": "crew"}', + ) + + try: + first = execute_tool_and_check_finality( + agent_action=action, + tools=[structured_tool], + tools_handler=tools_handler, + ) + tools_handler.last_used_tool = None + second = execute_tool_and_check_finality( + agent_action=action, + tools=[structured_tool], + tools_handler=tools_handler, + ) + finally: + clear_after_tool_call_hooks() + + assert json.loads(first.result) == {"query": "crew", "score": 0.7} + assert json.loads(second.result) == {"query": "crew", "score": 0.7} + assert seen_results == [ + ('{"query":"crew","score":0.7}', SearchOutput(query="crew", score=0.7)), + ('{"query":"crew","score":0.7}', SearchOutput(query="crew", score=0.7)), + ] + + +def test_last_raw_result_falls_back_only_until_recorded(): + tool_usage = ToolUsage( + tools_handler=None, + tools=[], + task=None, + function_calling_llm=MagicMock(), + agent=None, + action=MagicMock(), + ) + + assert tool_usage.get_last_raw_result("formatted result") == "formatted result" + + tool_usage.last_raw_result = None + + assert tool_usage.get_last_raw_result("formatted result") is None + + def test_validate_tool_input_booleans_and_none(): tool_usage = ToolUsage( tools_handler=MagicMock(), diff --git a/lib/crewai/tests/utilities/test_agent_utils.py b/lib/crewai/tests/utilities/test_agent_utils.py index de3ed411b..9cf4a2d2a 100644 --- a/lib/crewai/tests/utilities/test_agent_utils.py +++ b/lib/crewai/tests/utilities/test_agent_utils.py @@ -3,12 +3,19 @@ from __future__ import annotations import asyncio +import json from typing import Any, Literal, Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest from pydantic import BaseModel, Field +from crewai.hooks.tool_hooks import ( + ToolCallHookContext, + clear_after_tool_call_hooks, + clear_before_tool_call_hooks, + register_after_tool_call_hook, +) from crewai.tools.base_tool import BaseTool from crewai.utilities.agent_utils import ( _asummarize_chunks, @@ -1030,6 +1037,142 @@ class TestParseToolCallArgs: class TestExecuteSingleNativeToolCall: """Tests for execute_single_native_tool_call.""" + def test_typed_tool_output_is_json_agent_text(self) -> None: + clear_before_tool_call_hooks() + clear_after_tool_call_hooks() + + class SearchOutput(BaseModel): + query: str + score: float + + class TypedSearchTool(BaseTool): + name: str = "typed_search" + description: str = "Search for a query" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.9) + + tool = TypedSearchTool() + tool_call = MagicMock() + tool_call.id = "call_1" + tool_call.function.name = "typed_search" + tool_call.function.arguments = '{"query": "crew"}' + + result = execute_single_native_tool_call( + tool_call, + available_functions={"typed_search": tool._run}, + original_tools=[tool], + structured_tools=[tool.to_structured_tool()], + tools_handler=None, + agent=None, + task=None, + crew=None, + event_source=MagicMock(), + printer=None, + verbose=False, + ) + + assert json.loads(result.result) == {"query": "crew", "score": 0.9} + assert json.loads(result.tool_message["content"]) == { + "query": "crew", + "score": 0.9, + } + + def test_custom_agent_output_formatter_is_used_from_structured_tool( + self, + ) -> None: + clear_before_tool_call_hooks() + clear_after_tool_call_hooks() + + class SearchOutput(BaseModel): + query: str + score: float + + class MarkdownSearchTool(BaseTool): + name: str = "markdown_search" + description: str = "Search for a query" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.9) + + def format_output_for_agent(self, raw_result: Any) -> str: + result = self.result_schema.model_validate(raw_result) + return f"### {result.query}\n\nScore: **{result.score}**" + + tool = MarkdownSearchTool() + tool_call = MagicMock() + tool_call.id = "call_1" + tool_call.function.name = "markdown_search" + tool_call.function.arguments = '{"query": "crew"}' + + result = execute_single_native_tool_call( + tool_call, + available_functions={"markdown_search": tool._run}, + original_tools=[], + structured_tools=[tool.to_structured_tool()], + tools_handler=None, + agent=None, + task=None, + crew=None, + event_source=MagicMock(), + printer=None, + verbose=False, + ) + + assert result.result == "### crew\n\nScore: **0.9**" + assert result.tool_message["content"] == "### crew\n\nScore: **0.9**" + + def test_after_hook_includes_raw_tool_result_for_typed_output(self) -> None: + clear_after_tool_call_hooks() + + class SearchOutput(BaseModel): + query: str + score: float + + class TypedSearchTool(BaseTool): + name: str = "typed_search" + description: str = "Search for a query" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.9) + + seen_results: list[tuple[str | None, object]] = [] + + def after_hook(context: ToolCallHookContext) -> None: + seen_results.append((context.tool_result, context.raw_tool_result)) + + tool = TypedSearchTool() + tool_call = MagicMock() + tool_call.id = "call_1" + tool_call.function.name = "typed_search" + tool_call.function.arguments = '{"query": "crew"}' + + register_after_tool_call_hook(after_hook) + try: + result = execute_single_native_tool_call( + tool_call, + available_functions={"typed_search": tool._run}, + original_tools=[tool], + structured_tools=[tool.to_structured_tool()], + tools_handler=None, + agent=None, + task=None, + crew=None, + event_source=MagicMock(), + printer=None, + verbose=False, + ) + finally: + clear_after_tool_call_hooks() + + assert json.loads(result.result) == {"query": "crew", "score": 0.9} + assert seen_results == [ + ('{"query":"crew","score":0.9}', SearchOutput(query="crew", score=0.9)) + ] + def test_result_as_answer_false_on_tool_error(self) -> None: """When a tool with result_as_answer=True raises, result_as_answer must be False.