diff --git a/docs/edge/en/concepts/tools.mdx b/docs/edge/en/concepts/tools.mdx index 52e568073..da41da3b1 100644 --- a/docs/edge/en/concepts/tools.mdx +++ b/docs/edge/en/concepts/tools.mdx @@ -39,6 +39,7 @@ The Enterprise Tools Repository includes: - **Error Handling**: Incorporates robust error handling mechanisms to ensure smooth operation. - **Caching Mechanism**: Features intelligent caching to optimize performance and reduce redundant operations. - **Asynchronous Support**: Handles both synchronous and asynchronous tools, enabling non-blocking operations. +- **Typed Outputs**: Uses optional Pydantic models to give agents clear JSON fields while direct Python calls still receive the tool's normal return value. ## Using CrewAI Tools @@ -184,6 +185,55 @@ class MyCustomTool(BaseTool): return "Tool's result" ``` +### Typed Tool Outputs + +When a tool returns structured data, define a Pydantic output model. This gives the agent field names it can trust, such as `sku`, `quantity`, or `needs_reorder`. + +Direct Python calls still receive the value your tool returns. When an agent uses the tool, CrewAI sends the agent a JSON string based on the output model. + +```python Code +from crewai.tools import BaseTool +from pydantic import BaseModel + +class InventoryResult(BaseModel): + sku: str + quantity: int + needs_reorder: bool + +class InventoryTool(BaseTool): + name: str = "Inventory Check" + description: str = "Checks current stock for a product SKU." + + def _run(self, sku: str) -> InventoryResult: + quantity = {"SKU-123": 14, "SKU-456": 0}.get(sku, 0) + return InventoryResult(sku=sku, quantity=quantity, needs_reorder=quantity < 5) + +tool = InventoryTool() + +# Direct calls receive the raw Pydantic object. +result = tool.run(sku="SKU-123") +print(result.quantity) +``` + +To send Markdown or another short text format to the agent, override `format_output_for_agent`. Direct calls to `tool.run(...)` still return the normal Python value. + +```python Code +class InventoryTool(BaseTool): + name: str = "Inventory Check" + description: str = "Checks current stock for a product SKU." + + def _run(self, sku: str) -> InventoryResult: + quantity = {"SKU-123": 14, "SKU-456": 0}.get(sku, 0) + return InventoryResult(sku=sku, quantity=quantity, needs_reorder=quantity < 5) + + def format_output_for_agent(self, raw_result: object) -> str: + result = InventoryResult.model_validate(raw_result) + status = "reorder needed" if result.needs_reorder else "stock is healthy" + return f"{result.sku}: {result.quantity} units. {status}." +``` + +If you do not override `format_output_for_agent`, typed outputs are sent to the agent as JSON. Plain string results work as before. + ## Asynchronous Tool Support CrewAI supports asynchronous tools, allowing you to implement tools that perform non-blocking operations like network requests, file I/O, or other async operations without blocking the main execution thread. diff --git a/docs/edge/en/guides/tools/publish-custom-tools.mdx b/docs/edge/en/guides/tools/publish-custom-tools.mdx index 973856816..71dbfd622 100644 --- a/docs/edge/en/guides/tools/publish-custom-tools.mdx +++ b/docs/edge/en/guides/tools/publish-custom-tools.mdx @@ -65,7 +65,7 @@ Regardless of which approach you use, your tool must: - Have a **`description`** — tells the agent when and how to use the tool. This directly affects how well agents use your tool, so be clear and specific. - Implement **`_run`** (BaseTool) or provide a **function body** (@tool) — the synchronous execution logic. - Use **type annotations** on all parameters and return values. -- Return a **string** result (or something that can be meaningfully converted to one). +- Return a **string** result, or define an optional Pydantic output schema for structured results. ### Optional: Async Support @@ -104,6 +104,67 @@ class TranslateInput(BaseModel): Explicit schemas are recommended for published tools — they produce better agent behavior and clearer documentation for your users. +### Optional: Typed Outputs with `result_schema` + +If your tool returns structured data, define a Pydantic output model. This is a good default for published tools because users and agents can rely on named fields. + +Direct Python calls still receive the value your tool returns. When an agent uses the tool, CrewAI sends the agent JSON based on the output model. + +CrewAI can infer the output schema from a Pydantic return annotation: + +```python +from crewai.tools import BaseTool +from pydantic import BaseModel, Field + + +class GeolocateResult(BaseModel): + latitude: float = Field(..., description="Latitude in decimal degrees.") + longitude: float = Field(..., description="Longitude in decimal degrees.") + + +class GeolocateTool(BaseTool): + name: str = "Geolocate" + description: str = "Converts a street address into latitude/longitude coordinates." + + def _run(self, address: str) -> GeolocateResult: + if "1600 Pennsylvania" in address: + return GeolocateResult(latitude=38.8977, longitude=-77.0365) + return GeolocateResult(latitude=40.7128, longitude=-74.0060) +``` + +Set `result_schema` explicitly when your tool returns a dictionary: + +```python +class GeolocateTool(BaseTool): + name: str = "Geolocate" + description: str = "Converts a street address into latitude/longitude coordinates." + result_schema: type[BaseModel] = GeolocateResult + + def _run(self, address: str) -> dict[str, float]: + if "1600 Pennsylvania" in address: + return {"latitude": 38.8977, "longitude": -77.0365} + return {"latitude": 40.7128, "longitude": -74.0060} +``` + +If agents should receive a short text summary instead of JSON, override `format_output_for_agent` on your `BaseTool` subclass. + +```python +class GeolocateTool(BaseTool): + name: str = "Geolocate" + description: str = "Converts a street address into latitude/longitude coordinates." + + def _run(self, address: str) -> GeolocateResult: + if "1600 Pennsylvania" in address: + return GeolocateResult(latitude=38.8977, longitude=-77.0365) + return GeolocateResult(latitude=40.7128, longitude=-74.0060) + + def format_output_for_agent(self, raw_result: object) -> str: + result = GeolocateResult.model_validate(raw_result) + return f"Latitude {result.latitude}, longitude {result.longitude}" +``` + +The override only changes what the agent sees. Direct users of your package still receive the normal value from `tool.run(...)`. + ### Optional: Environment Variables If your tool requires API keys or other configuration, declare them with `env_vars` so users know what to set: @@ -241,4 +302,4 @@ agent = Agent( tools=[GeolocateTool()], # ... ) -``` \ No newline at end of file +``` diff --git a/docs/edge/en/learn/create-custom-tools.mdx b/docs/edge/en/learn/create-custom-tools.mdx index c1246f3fc..78205bf99 100644 --- a/docs/edge/en/learn/create-custom-tools.mdx +++ b/docs/edge/en/learn/create-custom-tools.mdx @@ -53,6 +53,111 @@ def my_simple_tool(question: str) -> str: return "Tool output" ``` +### Best Practice: Define Typed Outputs + +When a tool returns structured data, define a Pydantic output model. This helps the agent read the result as clear fields instead of guessing from plain text. + +Typed outputs are useful for results with stable fields, such as IDs, status values, scores, prices, or lists. Plain strings are still fine for short prose results. + +Direct Python calls still receive the value your tool returns. When an agent uses a typed tool, CrewAI sends the agent JSON based on the output model. + +#### Return a Pydantic Model + +CrewAI infers the output schema when your `BaseTool` has a Pydantic return annotation. + +```python Code +from crewai.tools import BaseTool +from pydantic import BaseModel, Field + +class InventoryResult(BaseModel): + sku: str = Field(description="The product SKU.") + quantity: int = Field(description="Units available.") + needs_reorder: bool = Field(description="Whether the item should be reordered.") + +class InventoryTool(BaseTool): + name: str = "Inventory Check" + description: str = "Check current stock for a product SKU." + + def _run(self, sku: str) -> InventoryResult: + quantity = {"SKU-123": 14, "SKU-456": 0}.get(sku, 0) + return InventoryResult(sku=sku, quantity=quantity, needs_reorder=quantity < 5) + +tool = InventoryTool() +result = tool.run(sku="SKU-123") + +# Direct Python calls receive the raw Pydantic object. +print(result.quantity) +``` + +When an agent calls `InventoryTool`, it receives JSON like this: + +```json +{"sku":"SKU-123","quantity":14,"needs_reorder":false} +``` + +#### Use `result_schema` with Dictionary Results + +If your tool returns a dictionary, set `result_schema` explicitly. You can do this on a `BaseTool` subclass or with the `@tool` decorator: + +```python Code +from crewai.tools import tool +from pydantic import BaseModel, Field + +class ProductResult(BaseModel): + sku: str = Field(description="The product SKU.") + name: str = Field(description="The product name.") + in_stock: bool = Field(description="Whether the product is available.") + +@tool("Product Lookup", result_schema=ProductResult) +def product_lookup(sku: str) -> dict[str, object]: + """Look up product availability by SKU.""" + catalog = { + "SKU-123": ("Noise-canceling headset", True), + "SKU-456": ("USB-C dock", False), + } + name, in_stock = catalog.get(sku, ("Unknown product", False)) + return { + "sku": sku, + "name": name, + "in_stock": in_stock, + } +``` + +#### Customize the Text Sent to the Agent + +By default, typed tool outputs are sent to the agent as JSON. If the agent should receive a short summary instead, subclass `BaseTool` and override `format_output_for_agent`. + +```python Code +from crewai.tools import BaseTool +from pydantic import BaseModel, Field + +class InventoryResult(BaseModel): + sku: str = Field(description="The product SKU.") + quantity: int = Field(description="Units available.") + needs_reorder: bool = Field(description="Whether the item should be reordered.") + +class InventoryTool(BaseTool): + name: str = "Inventory Check" + description: str = "Check current stock for a product SKU." + + def _run(self, sku: str) -> InventoryResult: + quantity = {"SKU-123": 14, "SKU-456": 0}.get(sku, 0) + return InventoryResult(sku=sku, quantity=quantity, needs_reorder=quantity < 5) + + def format_output_for_agent(self, raw_result: object) -> str: + result = InventoryResult.model_validate(raw_result) + status = "reorder needed" if result.needs_reorder else "stock is healthy" + return f"{result.sku}: {result.quantity} units. {status}." + +tool = InventoryTool() +result = tool.run(sku="SKU-123") + +# Direct Python calls receive the raw Pydantic object. +print(result.quantity) +``` + +The override only changes what the agent sees. Direct calls to `tool.run(...)` still return the normal Python value. + ### Defining a Cache Function for the Tool To optimize tool performance with caching, define custom caching strategies using the `cache_function` attribute. diff --git a/docs/edge/en/learn/execution-hooks.mdx b/docs/edge/en/learn/execution-hooks.mdx index 74234db97..f9f667e4f 100644 --- a/docs/edge/en/learn/execution-hooks.mdx +++ b/docs/edge/en/learn/execution-hooks.mdx @@ -195,9 +195,12 @@ class ToolCallHookContext: agent: Agent | None # Agent executing task: Task | None # Current task crew: Crew | None # Crew instance - tool_result: str | None # Tool result (after hooks) + tool_result: str | None # Agent-facing result string (after hooks) + raw_tool_result: Any | None # Raw Python result (after hooks) ``` +For typed tool outputs, `tool_result` is the string the agent sees. By default, this is JSON. If the tool uses custom formatting, it can be Markdown or another string. `raw_tool_result` is the original Python value returned by the tool. + ## Common Patterns ### Safety and Validation diff --git a/docs/edge/en/learn/tool-hooks.mdx b/docs/edge/en/learn/tool-hooks.mdx index d1d727a5c..489463e81 100644 --- a/docs/edge/en/learn/tool-hooks.mdx +++ b/docs/edge/en/learn/tool-hooks.mdx @@ -60,9 +60,12 @@ class ToolCallHookContext: agent: Agent | BaseAgent | None # Agent executing the tool task: Task | None # Current task crew: Crew | None # Crew instance - tool_result: str | None # Tool result (after hooks only) + tool_result: str | None # Agent-facing result string (after hooks only) + raw_tool_result: Any | None # Raw Python result (after hooks only) ``` +For typed tool outputs, `tool_result` is the string the agent sees. By default, this is JSON. If the tool uses custom formatting, it can be Markdown or another string. Use `raw_tool_result` when your hook needs the typed object or dictionary. + ### Modifying Tool Inputs **Important:** Always modify tool inputs in-place: diff --git a/lib/crewai/src/crewai/agents/crew_agent_executor.py b/lib/crewai/src/crewai/agents/crew_agent_executor.py index 92a1ce5fb..de2315e3a 100644 --- a/lib/crewai/src/crewai/agents/crew_agent_executor.py +++ b/lib/crewai/src/crewai/agents/crew_agent_executor.py @@ -57,6 +57,7 @@ from crewai.utilities.agent_utils import ( convert_tools_to_openai_schema, enforce_rpm_limit, format_message_for_llm, + format_native_tool_output_for_agent, get_llm_response, handle_agent_action_core, handle_context_length, @@ -907,19 +908,31 @@ class CrewAgentExecutor(BaseAgentExecutor): ): max_usage_reached = True + structured_tool: CrewStructuredTool | None = None + if original_tool is not None: + for structured in self.tools or []: + if getattr(structured, "_original_tool", None) is original_tool: + structured_tool = structured + break + if structured_tool is None: + for structured in self.tools or []: + if sanitize_tool_name(structured.name) == func_name: + structured_tool = structured + break + + output_tool = original_tool or structured_tool + from_cache = False result: str = "Tool not found" + raw_tool_result: Any = result input_str = json.dumps(args_dict) if args_dict else "" - if self.tools_handler and self.tools_handler.cache: + if self.tools_handler and self.tools_handler.cache and output_tool is not None: cached_result = self.tools_handler.cache.read( tool=func_name, input=input_str ) if cached_result is not None: - result = ( - str(cached_result) - if not isinstance(cached_result, str) - else cached_result - ) + raw_tool_result = cached_result + result = format_native_tool_output_for_agent(output_tool, cached_result) from_cache = True agent_key = getattr(self.agent, "key", "unknown") if self.agent else "unknown" @@ -938,18 +951,6 @@ class CrewAgentExecutor(BaseAgentExecutor): track_delegation_if_needed(func_name, args_dict or {}, self.task) - structured_tool: CrewStructuredTool | None = None - if original_tool is not None: - for structured in self.tools or []: - if getattr(structured, "_original_tool", None) is original_tool: - structured_tool = structured - break - if structured_tool is None: - for structured in self.tools or []: - if sanitize_tool_name(structured.name) == func_name: - structured_tool = structured - break - hook_blocked = False before_hook_context = ToolCallHookContext( tool_name=func_name, @@ -975,11 +976,18 @@ class CrewAgentExecutor(BaseAgentExecutor): if hook_blocked: result = f"Tool execution blocked by hook. Tool: {func_name}" + raw_tool_result = result elif max_usage_reached and original_tool: result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore." - elif not from_cache and func_name in available_functions: + raw_tool_result = result + elif ( + not from_cache + and func_name in available_functions + and output_tool is not None + ): try: raw_result = available_functions[func_name](**(args_dict or {})) + raw_tool_result = raw_result if self.tools_handler and self.tools_handler.cache: should_cache = True @@ -996,11 +1004,10 @@ class CrewAgentExecutor(BaseAgentExecutor): tool=func_name, input=input_str, output=raw_result ) - result = ( - str(raw_result) if not isinstance(raw_result, str) else raw_result - ) + result = format_native_tool_output_for_agent(output_tool, raw_result) except Exception as e: result = f"Error executing tool: {e}" + raw_tool_result = result if self.task: self.task.increment_tools_errors() crewai_event_bus.emit( @@ -1024,6 +1031,7 @@ class CrewAgentExecutor(BaseAgentExecutor): task=self.task, crew=self.crew, tool_result=result, + raw_tool_result=raw_tool_result, ) after_hooks = get_after_tool_call_hooks() try: diff --git a/lib/crewai/src/crewai/agents/tools_handler.py b/lib/crewai/src/crewai/agents/tools_handler.py index 8ab759b85..c56226bc1 100644 --- a/lib/crewai/src/crewai/agents/tools_handler.py +++ b/lib/crewai/src/crewai/agents/tools_handler.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +from typing import Any from pydantic import BaseModel, Field @@ -25,14 +26,14 @@ class ToolsHandler(BaseModel): def on_tool_use( self, calling: ToolCalling | InstructorToolCalling, - output: str, + output: Any, should_cache: bool = True, ) -> None: """Run when tool ends running. Args: calling: The tool calling instance. - output: The output from the tool execution. + output: The raw output from the tool execution. should_cache: Whether to cache the tool output. """ self.last_used_tool = calling diff --git a/lib/crewai/src/crewai/experimental/agent_executor.py b/lib/crewai/src/crewai/experimental/agent_executor.py index c026c7509..303330dc6 100644 --- a/lib/crewai/src/crewai/experimental/agent_executor.py +++ b/lib/crewai/src/crewai/experimental/agent_executor.py @@ -80,6 +80,7 @@ from crewai.utilities.agent_utils import ( enforce_rpm_limit, extract_tool_call_info, format_message_for_llm, + format_native_tool_output_for_agent, get_llm_response, handle_agent_action_core, handle_context_length, @@ -1905,19 +1906,32 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): ): max_usage_reached = True + structured_tool: CrewStructuredTool | None = None + if original_tool is not None: + for structured in self.tools or []: + if getattr(structured, "_original_tool", None) is original_tool: + structured_tool = structured + break + if structured_tool is None: + for structured in self.tools or []: + if sanitize_tool_name(structured.name) == func_name: + structured_tool = structured + break + + output_tool = original_tool or structured_tool + # Check cache before executing from_cache = False + result = "Tool not found" + raw_tool_result: Any = result input_str = json.dumps(args_dict) if args_dict else "" - if self.tools_handler and self.tools_handler.cache: + if self.tools_handler and self.tools_handler.cache and output_tool is not None: cached_result = self.tools_handler.cache.read( tool=func_name, input=input_str ) if cached_result is not None: - result = ( - str(cached_result) - if not isinstance(cached_result, str) - else cached_result - ) + raw_tool_result = cached_result + result = format_native_tool_output_for_agent(output_tool, cached_result) from_cache = True # Emit tool usage started event @@ -1936,18 +1950,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): track_delegation_if_needed(func_name, args_dict, self.task) - structured_tool: CrewStructuredTool | None = None - if original_tool is not None: - for structured in self.tools or []: - if getattr(structured, "_original_tool", None) is original_tool: - structured_tool = structured - break - if structured_tool is None: - for structured in self.tools or []: - if sanitize_tool_name(structured.name) == func_name: - structured_tool = structured - break - hook_blocked = False before_hook_context = ToolCallHookContext( tool_name=func_name, @@ -1973,12 +1975,13 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): if hook_blocked: result = f"Tool execution blocked by hook. Tool: {func_name}" - elif not from_cache and not max_usage_reached: - result = "Tool not found" + raw_tool_result = result + elif not from_cache and not max_usage_reached and output_tool is not None: if func_name in self._available_functions: try: tool_func = self._available_functions[func_name] raw_result = tool_func(**args_dict) + raw_tool_result = raw_result # Add to cache after successful execution (before string conversion) if self.tools_handler and self.tools_handler.cache: @@ -1992,14 +1995,12 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): tool=func_name, input=input_str, output=raw_result ) - # Convert to string for message - result = ( - str(raw_result) - if not isinstance(raw_result, str) - else raw_result + result = format_native_tool_output_for_agent( + output_tool, raw_result ) except Exception as e: result = f"Error executing tool: {e}" + raw_tool_result = result if self.task: self.task.increment_tools_errors() # Emit tool usage error event @@ -2021,6 +2022,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore." else: result = f"Tool '{func_name}' has reached its maximum usage limit and cannot be used anymore." + raw_tool_result = result # Execute after_tool_call hooks (even if blocked, to allow logging/monitoring) after_hook_context = ToolCallHookContext( @@ -2031,6 +2033,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor): task=self.task, crew=self.crew, tool_result=result, + raw_tool_result=raw_tool_result, ) after_hooks = get_after_tool_call_hooks() try: diff --git a/lib/crewai/src/crewai/hooks/tool_hooks.py b/lib/crewai/src/crewai/hooks/tool_hooks.py index b54ca7b10..a860c01d4 100644 --- a/lib/crewai/src/crewai/hooks/tool_hooks.py +++ b/lib/crewai/src/crewai/hooks/tool_hooks.py @@ -40,6 +40,8 @@ class ToolCallHookContext: crew: Crew instance (may be None) tool_result: Tool execution result (only set for after_tool_call hooks). Can be modified by returning a new string from after_tool_call hook. + raw_tool_result: Raw Python tool execution result (only set for + after_tool_call hooks). This is not modified by after hooks. """ def __init__( @@ -51,6 +53,7 @@ class ToolCallHookContext: task: Task | None = None, crew: Crew | None = None, tool_result: str | None = None, + raw_tool_result: Any | None = None, ) -> None: """Initialize tool call hook context. @@ -62,6 +65,7 @@ class ToolCallHookContext: task: Optional current task crew: Optional crew instance tool_result: Optional tool result (for after hooks) + raw_tool_result: Optional raw tool result (for after hooks) """ self.tool_name = tool_name self.tool_input = tool_input @@ -70,6 +74,7 @@ class ToolCallHookContext: self.task = task self.crew = crew self.tool_result = tool_result + self.raw_tool_result = raw_tool_result def request_human_input( self, diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py index 31c5009bd..c6c3dba15 100644 --- a/lib/crewai/src/crewai/tools/base_tool.py +++ b/lib/crewai/src/crewai/tools/base_tool.py @@ -33,6 +33,8 @@ from typing_extensions import TypeIs from crewai.tools.structured_tool import ( CrewStructuredTool, _deserialize_schema, + _format_tool_output_for_agent, + _infer_result_schema_from_callable, _serialize_schema, build_schema_hint, ) @@ -149,6 +151,11 @@ class BaseTool(BaseModel, ABC): validate_default=True, description="The schema for the arguments that the tool accepts.", ) + result_schema: type[PydanticBaseModel] | None = Field( + default=None, + validate_default=True, + description="The schema for the output that the tool returns.", + ) @field_serializer("args_schema", when_used="json") def _serialize_args_schema( @@ -156,6 +163,12 @@ class BaseTool(BaseModel, ABC): ) -> dict[str, Any] | None: return _serialize_schema(schema) + @field_serializer("result_schema", when_used="json") + def _serialize_result_schema( + self, schema: type[PydanticBaseModel] | None + ) -> dict[str, Any] | None: + return _serialize_schema(schema) + description_updated: bool = Field( default=False, description="Flag to check if the description has been updated." ) @@ -233,6 +246,17 @@ class BaseTool(BaseModel, ABC): return create_model(f"{cls.__name__}Schema", **fields) + @field_validator("result_schema", mode="before") + @classmethod + def _default_result_schema( + cls, v: type[PydanticBaseModel] | dict[str, Any] | None + ) -> type[PydanticBaseModel] | None: + if isinstance(v, dict): + return _deserialize_schema(v) + if v is not None: + return v + return _infer_result_schema_from_callable(cls._run) + @field_validator("max_usage_count", mode="before") @classmethod def validate_max_usage_count(cls, v: int | None) -> int | None: @@ -340,6 +364,10 @@ class BaseTool(BaseModel, ABC): "Override _arun for async support or use run() for sync execution." ) + def format_output_for_agent(self, raw_result: Any) -> str: + """Format a raw tool result into the string representation sent to an agent.""" + return _format_tool_output_for_agent(self, raw_result) + def reset_usage_count(self) -> None: """Reset the current usage count to zero.""" self.current_usage_count = 0 @@ -369,6 +397,7 @@ class BaseTool(BaseModel, ABC): name=self.name, description=self.description, args_schema=self.args_schema, + result_schema=self.result_schema, func=self._run, result_as_answer=self.result_as_answer, max_usage_count=self.max_usage_count, @@ -390,6 +419,9 @@ class BaseTool(BaseModel, ABC): raise ValueError("The provided tool must have a callable 'func' attribute.") args_schema = getattr(tool, "args_schema", None) + result_schema = getattr(tool, "result_schema", None) + if result_schema is None: + result_schema = _infer_result_schema_from_callable(tool.func) if args_schema is None: func_signature = signature(tool.func) @@ -420,6 +452,7 @@ class BaseTool(BaseModel, ABC): description=getattr(tool, "description", ""), func=tool.func, args_schema=args_schema, + result_schema=result_schema, ) def _set_args_schema(self) -> None: @@ -568,6 +601,9 @@ class Tool(BaseTool, Generic[P, R]): raise ValueError("The provided tool must have a callable 'func' attribute.") args_schema = getattr(tool, "args_schema", None) + result_schema = getattr(tool, "result_schema", None) + if result_schema is None: + result_schema = _infer_result_schema_from_callable(tool.func) if args_schema is None: func_signature = signature(tool.func) @@ -598,6 +634,7 @@ class Tool(BaseTool, Generic[P, R]): description=getattr(tool, "description", ""), func=tool.func, args_schema=args_schema, + result_schema=result_schema, ) @@ -621,6 +658,7 @@ def tool( name: str, /, *, + result_schema: type[BaseModel] | None = ..., result_as_answer: bool = ..., max_usage_count: int | None = ..., ) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ... @@ -629,6 +667,7 @@ def tool( @overload def tool( *, + result_schema: type[BaseModel] | None = ..., result_as_answer: bool = ..., max_usage_count: int | None = ..., ) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ... @@ -636,6 +675,7 @@ def tool( def tool( *args: Callable[P2, R2] | str, + result_schema: type[BaseModel] | None = None, result_as_answer: bool = False, max_usage_count: int | None = None, ) -> Tool[P2, R2] | Callable[[Callable[P2, R2]], Tool[P2, R2]]: @@ -649,6 +689,7 @@ def tool( Args: *args: Either the function to decorate or a custom tool name. result_as_answer: If True, the tool result becomes the final agent answer. + result_schema: Optional schema for the output that the tool returns. max_usage_count: Maximum times this tool can be used. None means unlimited. Returns: @@ -690,12 +731,16 @@ def tool( class_name = "".join(tool_name.split()).title() args_schema = create_model(class_name, **fields) + resolved_result_schema = ( + result_schema or _infer_result_schema_from_callable(f) + ) return Tool( name=tool_name, description=f.__doc__, func=f, args_schema=args_schema, + result_schema=resolved_result_schema, result_as_answer=result_as_answer, max_usage_count=max_usage_count, current_usage_count=0, diff --git a/lib/crewai/src/crewai/tools/structured_tool.py b/lib/crewai/src/crewai/tools/structured_tool.py index 6c24f52dc..8ecba8549 100644 --- a/lib/crewai/src/crewai/tools/structured_tool.py +++ b/lib/crewai/src/crewai/tools/structured_tool.py @@ -5,7 +5,8 @@ from collections.abc import Callable import inspect import json import textwrap -from typing import TYPE_CHECKING, Annotated, Any, get_type_hints +from typing import TYPE_CHECKING, Annotated, Any, cast, get_type_hints +import warnings from pydantic import ( BaseModel, @@ -36,6 +37,52 @@ def _deserialize_schema(v: Any) -> type[BaseModel] | None: return None +def _infer_result_schema_from_callable( + func: Callable[..., Any], +) -> type[BaseModel] | None: + try: + return_annotation = get_type_hints(func).get("return", inspect.Signature.empty) + except Exception: + return_annotation = inspect.signature(func).return_annotation + + if isinstance(return_annotation, type) and issubclass(return_annotation, BaseModel): + return return_annotation + + return None + + +def _format_tool_output_for_agent(tool: Any, raw_result: Any) -> str: + original_tool = getattr(tool, "_original_tool", None) + if original_tool is not None: + return cast(str, original_tool.format_output_for_agent(raw_result)) + + result_schema = getattr(tool, "result_schema", None) + if not (isinstance(result_schema, type) and issubclass(result_schema, BaseModel)): + return str(raw_result) + + try: + validation_input = raw_result + if isinstance(raw_result, BaseModel) and not isinstance( + raw_result, result_schema + ): + validation_input = raw_result.model_dump() + + validated = result_schema.model_validate(validation_input) + return validated.model_dump_json() + except Exception as exc: + warnings.warn( + ( + f"Failed to validate or serialize output from tool " + f"'{getattr(tool, 'name', '')}' using result_schema " + f"'{result_schema.__name__}': {exc.__class__.__name__}. " + "Falling back to str(raw_result)." + ), + RuntimeWarning, + stacklevel=2, + ) + return str(raw_result) + + if TYPE_CHECKING: pass @@ -81,6 +128,11 @@ class CrewStructuredTool(BaseModel): BeforeValidator(_deserialize_schema), PlainSerializer(_serialize_schema), ] = Field(default=None) + result_schema: Annotated[ + type[BaseModel] | None, + BeforeValidator(_deserialize_schema), + PlainSerializer(_serialize_schema), + ] = Field(default=None) func: Any = Field(default=None, exclude=True) result_as_answer: bool = Field(default=False) max_usage_count: int | None = Field(default=None) @@ -103,6 +155,7 @@ class CrewStructuredTool(BaseModel): description: str | None = None, return_direct: bool = False, args_schema: type[BaseModel] | None = None, + result_schema: type[BaseModel] | None = None, infer_schema: bool = True, **kwargs: Any, ) -> CrewStructuredTool: @@ -114,6 +167,7 @@ class CrewStructuredTool(BaseModel): description: The description of the tool. Defaults to the function docstring return_direct: Whether to return the output directly args_schema: Optional schema for the function arguments + result_schema: Optional schema for the function output infer_schema: Whether to infer the schema from the function signature **kwargs: Additional arguments to pass to the tool @@ -149,10 +203,16 @@ class CrewStructuredTool(BaseModel): name=name, description=description, args_schema=schema, + result_schema=result_schema or _infer_result_schema_from_callable(func), func=func, result_as_answer=return_direct, + **kwargs, ) + def format_output_for_agent(self, raw_result: Any) -> str: + """Format a raw tool result into the string representation sent to an agent.""" + return _format_tool_output_for_agent(self, raw_result) + @staticmethod def _create_schema_from_function( name: str, diff --git a/lib/crewai/src/crewai/tools/tool_usage.py b/lib/crewai/src/crewai/tools/tool_usage.py index b34921839..e92ba03ee 100644 --- a/lib/crewai/src/crewai/tools/tool_usage.py +++ b/lib/crewai/src/crewai/tools/tool_usage.py @@ -62,6 +62,9 @@ OPENAI_BIGGER_MODELS: list[ ] +_RAW_RESULT_UNSET = object() + + class ToolUsageError(Exception): """Exception raised for errors in the tool usage.""" @@ -106,6 +109,7 @@ class ToolUsage: self.action = action self.function_calling_llm = function_calling_llm self.fingerprint_context = fingerprint_context or {} + self.last_raw_result: Any = _RAW_RESULT_UNSET if ( self.function_calling_llm @@ -120,6 +124,11 @@ class ToolUsage: """Parse the tool string and return the tool calling.""" return self._tool_calling(tool_string) + def get_last_raw_result(self, fallback: Any) -> Any: + if self.last_raw_result is _RAW_RESULT_UNSET: + return fallback + return self.last_raw_result + def use( self, calling: ToolCalling | InstructorToolCalling, tool_string: str ) -> str: @@ -231,6 +240,7 @@ class ToolUsage: result = I18N_DEFAULT.errors("task_repeated_usage").format( tool_names=self.tools_names ) + self.last_raw_result = result self._telemetry.tool_repeated_usage( llm=self.function_calling_llm, tool_name=sanitize_tool_name(tool.name), @@ -298,6 +308,7 @@ class ToolUsage: ) if usage_limit_error: result = usage_limit_error + self.last_raw_result = result self._telemetry.tool_usage_error(llm=self.function_calling_llm) result = self._format_result(result=result) elif result is None: @@ -359,7 +370,10 @@ class ToolUsage: tool_name=sanitize_tool_name(tool.name), attempts=self._run_attempts, ) - result = self._format_result(result=result) + self.last_raw_result = result + result = self._format_result( + result=tool.format_output_for_agent(result) + ) data = { "result": result, "tool_name": sanitize_tool_name(tool.name), @@ -421,6 +435,7 @@ class ToolUsage: result = ToolUsageError( f"\n{error_message}.\nMoving on then. {I18N_DEFAULT.slice('format').format(tool_names=self.tools_names)}" ).message + self.last_raw_result = result if self.task: self.task.increment_tools_errors() if self.agent and self.agent.verbose: @@ -430,7 +445,10 @@ class ToolUsage: self.task.increment_tools_errors() should_retry = True else: - result = self._format_result(result=result) + self.last_raw_result = result + result = self._format_result( + result=tool.format_output_for_agent(result) + ) finally: if started_event_emitted and not error_event_emitted: @@ -460,6 +478,7 @@ class ToolUsage: result = I18N_DEFAULT.errors("task_repeated_usage").format( tool_names=self.tools_names ) + self.last_raw_result = result self._telemetry.tool_repeated_usage( llm=self.function_calling_llm, tool_name=sanitize_tool_name(tool.name), @@ -529,6 +548,7 @@ class ToolUsage: ) if usage_limit_error: result = usage_limit_error + self.last_raw_result = result self._telemetry.tool_usage_error(llm=self.function_calling_llm) result = self._format_result(result=result) elif result is None: @@ -590,7 +610,10 @@ class ToolUsage: tool_name=sanitize_tool_name(tool.name), attempts=self._run_attempts, ) - result = self._format_result(result=result) + self.last_raw_result = result + result = self._format_result( + result=tool.format_output_for_agent(result) + ) data = { "result": result, "tool_name": sanitize_tool_name(tool.name), @@ -652,6 +675,7 @@ class ToolUsage: result = ToolUsageError( f"\n{error_message}.\nMoving on then. {I18N_DEFAULT.slice('format').format(tool_names=self.tools_names)}" ).message + self.last_raw_result = result if self.task: self.task.increment_tools_errors() if self.agent and self.agent.verbose: @@ -661,7 +685,10 @@ class ToolUsage: self.task.increment_tools_errors() should_retry = True else: - result = self._format_result(result=result) + self.last_raw_result = result + result = self._format_result( + result=tool.format_output_for_agent(result) + ) finally: if started_event_emitted and not error_event_emitted: diff --git a/lib/crewai/src/crewai/utilities/agent_utils.py b/lib/crewai/src/crewai/utilities/agent_utils.py index 80f8ab242..e933a38a8 100644 --- a/lib/crewai/src/crewai/utilities/agent_utils.py +++ b/lib/crewai/src/crewai/utilities/agent_utils.py @@ -1383,6 +1383,19 @@ class NativeToolCallResult: tool_message: LLMMessage = field(default_factory=dict) # type: ignore[assignment] +def format_native_tool_output_for_agent(tool: Any, raw_result: Any) -> str: + """Format native tool output when a tool explicitly defines a formatter.""" + formatter = inspect.getattr_static(tool, "format_output_for_agent", None) + if formatter is None: + return str(raw_result) + + runtime_formatter = getattr(tool, "format_output_for_agent", None) + if not callable(runtime_formatter): + return str(raw_result) + + return str(runtime_formatter(raw_result)) + + def execute_single_native_tool_call( tool_call: Any, *, @@ -1456,18 +1469,24 @@ def execute_single_native_tool_call( original_tool = tool break + structured_tool: CrewStructuredTool | None = None + for structured in structured_tools or []: + if sanitize_tool_name(structured.name) == func_name: + structured_tool = structured + break + + output_tool = original_tool or structured_tool + from_cache = False input_str = json.dumps(args_dict) if args_dict else "" result = "Tool not found" + raw_tool_result: Any = result - if tools_handler and tools_handler.cache: + if tools_handler and tools_handler.cache and output_tool is not None: cached_result = tools_handler.cache.read(tool=func_name, input=input_str) if cached_result is not None: - result = ( - str(cached_result) - if not isinstance(cached_result, str) - else cached_result - ) + raw_tool_result = cached_result + result = format_native_tool_output_for_agent(output_tool, cached_result) from_cache = True started_at = datetime.now() @@ -1486,12 +1505,6 @@ def execute_single_native_tool_call( track_delegation_if_needed(func_name, args_dict, task) - structured_tool: CrewStructuredTool | None = None - for structured in structured_tools or []: - if sanitize_tool_name(structured.name) == func_name: - structured_tool = structured - break - hook_blocked = False before_hook_context = ToolCallHookContext( tool_name=func_name, @@ -1512,11 +1525,13 @@ def execute_single_native_tool_call( error_event_emitted = False if hook_blocked: result = f"Tool execution blocked by hook. Tool: {func_name}" + raw_tool_result = result elif not from_cache: - if func_name in available_functions: + if func_name in available_functions and output_tool is not None: try: tool_func = available_functions[func_name] raw_result = tool_func(**args_dict) + raw_tool_result = raw_result if tools_handler and tools_handler.cache: should_cache = True @@ -1529,11 +1544,10 @@ def execute_single_native_tool_call( tool=func_name, input=input_str, output=raw_result ) - result = ( - str(raw_result) if not isinstance(raw_result, str) else raw_result - ) + result = format_native_tool_output_for_agent(output_tool, raw_result) except Exception as e: result = f"Error executing tool: {e}" + raw_tool_result = result if task: task.increment_tools_errors() crewai_event_bus.emit( @@ -1559,6 +1573,7 @@ def execute_single_native_tool_call( task=task, crew=crew, tool_result=result, + raw_tool_result=raw_tool_result, ) try: for after_hook in get_after_tool_call_hooks(): diff --git a/lib/crewai/src/crewai/utilities/tool_utils.py b/lib/crewai/src/crewai/utilities/tool_utils.py index e19c3c81a..dcb25594c 100644 --- a/lib/crewai/src/crewai/utilities/tool_utils.py +++ b/lib/crewai/src/crewai/utilities/tool_utils.py @@ -116,6 +116,7 @@ async def aexecute_tool_and_check_finality( logger.log("error", f"Error in before_tool_call hook: {e}") tool_result = await tool_usage.ause(tool_calling, agent_action.text) + raw_tool_result = tool_usage.get_last_raw_result(tool_result) after_hook_context = ToolCallHookContext( tool_name=sanitized_tool_name, @@ -125,6 +126,7 @@ async def aexecute_tool_and_check_finality( task=task, crew=crew, tool_result=tool_result, + raw_tool_result=raw_tool_result, ) after_hooks = get_after_tool_call_hooks() @@ -234,6 +236,7 @@ def execute_tool_and_check_finality( logger.log("error", f"Error in before_tool_call hook: {e}") tool_result = tool_usage.use(tool_calling, agent_action.text) + raw_tool_result = tool_usage.get_last_raw_result(tool_result) after_hook_context = ToolCallHookContext( tool_name=sanitized_tool_name, @@ -243,6 +246,7 @@ def execute_tool_and_check_finality( task=task, crew=crew, tool_result=tool_result, + raw_tool_result=raw_tool_result, ) after_hooks = get_after_tool_call_hooks() diff --git a/lib/crewai/tests/agents/test_native_tool_calling.py b/lib/crewai/tests/agents/test_native_tool_calling.py index b7e0df199..894c0bd45 100644 --- a/lib/crewai/tests/agents/test_native_tool_calling.py +++ b/lib/crewai/tests/agents/test_native_tool_calling.py @@ -7,6 +7,7 @@ when the LLM supports it, across multiple providers. from __future__ import annotations from collections.abc import Generator +import json import os import threading import time @@ -20,7 +21,7 @@ from crewai import Agent, Crew, Task from crewai.agents.parser import AgentFinish from crewai.events import crewai_event_bus from crewai.hooks import register_after_tool_call_hook, register_before_tool_call_hook -from crewai.hooks.tool_hooks import ToolCallHookContext +from crewai.hooks.tool_hooks import ToolCallHookContext, clear_after_tool_call_hooks from crewai.llm import LLM from crewai.tools.base_tool import BaseTool @@ -1197,6 +1198,76 @@ class TestNativeToolCallingJsonParseError: assert result["result"] == "ran: print(1)" + def test_typed_output_is_json_agent_text(self) -> None: + class SearchOutput(BaseModel): + query: str + score: float + + class TypedSearchTool(BaseTool): + name: str = "typed_search" + description: str = "Search for information" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.8) + + tool = TypedSearchTool() + executor = self._make_executor([tool]) + + from crewai.utilities.agent_utils import convert_tools_to_openai_schema + + _, available_functions, _ = convert_tools_to_openai_schema([tool]) + + result = executor._execute_single_native_tool_call( + call_id="call_typed", + func_name="typed_search", + func_args='{"query": "crew"}', + available_functions=available_functions, + ) + + assert json.loads(result["result"]) == {"query": "crew", "score": 0.8} + + def test_typed_output_after_hook_includes_raw_tool_result(self) -> None: + from crewai.utilities.agent_utils import convert_tools_to_openai_schema + + class SearchOutput(BaseModel): + query: str + score: float + + class TypedSearchTool(BaseTool): + name: str = "typed_search" + description: str = "Search for information" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.8) + + seen_results: list[tuple[str | None, object]] = [] + + def after_hook(context: ToolCallHookContext) -> None: + seen_results.append((context.tool_result, context.raw_tool_result)) + + tool = TypedSearchTool() + executor = self._make_executor([tool]) + _, available_functions, _ = convert_tools_to_openai_schema([tool]) + + clear_after_tool_call_hooks() + register_after_tool_call_hook(after_hook) + try: + result = executor._execute_single_native_tool_call( + call_id="call_typed", + func_name="typed_search", + func_args='{"query": "crew"}', + available_functions=available_functions, + ) + finally: + clear_after_tool_call_hooks() + + assert json.loads(result["result"]) == {"query": "crew", "score": 0.8} + assert seen_results == [ + ('{"query":"crew","score":0.8}', SearchOutput(query="crew", score=0.8)) + ] + def test_native_tool_loop_falls_back_when_provider_rejects_tools(self) -> None: """Unsupported native tools errors should continue through ReAct.""" diff --git a/lib/crewai/tests/hooks/test_tool_hooks.py b/lib/crewai/tests/hooks/test_tool_hooks.py index 347eb56e5..14c6848a8 100644 --- a/lib/crewai/tests/hooks/test_tool_hooks.py +++ b/lib/crewai/tests/hooks/test_tool_hooks.py @@ -91,20 +91,24 @@ class TestToolCallHookContext: assert context.task == mock_task assert context.crew == mock_crew assert context.tool_result is None + assert context.raw_tool_result is None def test_context_with_result(self, mock_tool): """Test that context includes result when provided.""" tool_input = {"arg1": "value1"} tool_result = "Test tool result" + raw_tool_result = {"value": 42} context = ToolCallHookContext( tool_name="test_tool", tool_input=tool_input, tool=mock_tool, tool_result=tool_result, + raw_tool_result=raw_tool_result, ) assert context.tool_result == tool_result + assert context.raw_tool_result == raw_tool_result def test_tool_input_is_mutable_reference(self, mock_tool): """Test that modifying context.tool_input modifies the original dict.""" diff --git a/lib/crewai/tests/tools/test_base_tool.py b/lib/crewai/tests/tools/test_base_tool.py index 7648ad73b..52661fffc 100644 --- a/lib/crewai/tests/tools/test_base_tool.py +++ b/lib/crewai/tests/tools/test_base_tool.py @@ -1,12 +1,13 @@ import asyncio from collections.abc import Callable +import json from unittest.mock import patch from crewai.agent import Agent from crewai.crew import Crew from crewai.task import Task from crewai.tools import BaseTool, tool -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, RootModel import pytest @@ -351,6 +352,262 @@ class TestToolDecoratorRunValidation: assert result == "Hello, World!" +class SearchOutput(BaseModel): + query: str + score: float + + +class SearchResults(RootModel[list[SearchOutput]]): + pass + + +class ExplicitSearchTool(BaseTool): + name: str = "search" + description: str = "Search for a query" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> dict[str, object]: + return {"query": query, "score": 0.8} + + +class InferredSearchTool(BaseTool): + name: str = "search" + description: str = "Search for a query" + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.7) + + +class RootSearchTool(BaseTool): + name: str = "search" + description: str = "Search for a query" + + def _run(self, query: str) -> SearchResults: + return SearchResults([SearchOutput(query=query, score=1.0)]) + + +class DictAnnotatedSearchTool(BaseTool): + name: str = "search" + description: str = "Search for a query" + + def _run(self, query: str) -> dict[str, object]: + return {"query": query, "score": 0.5} + + +def _make_explicit_decorator_tool() -> BaseTool: + @tool("search", result_schema=SearchOutput) + def search(query: str) -> dict[str, object]: + """Search for a query.""" + return {"query": query, "score": 0.8} + + return search + + +def _make_inferred_decorator_tool() -> BaseTool: + @tool("search") + def search(query: str) -> SearchOutput: + """Search for a query.""" + return SearchOutput(query=query, score=0.6) + + return search + + +def _make_root_decorator_tool() -> BaseTool: + @tool("search") + def search(query: str) -> SearchResults: + """Search for a query.""" + return SearchResults([SearchOutput(query=query, score=1.0)]) + + return search + + +class TestToolOutputSchema: + @pytest.mark.parametrize( + ("tool_cls", "expected_raw", "expected_agent_payload"), + [ + pytest.param( + ExplicitSearchTool, + {"query": "crew", "score": 0.8}, + {"query": "crew", "score": 0.8}, + id="explicit-schema", + ), + pytest.param( + InferredSearchTool, + SearchOutput(query="crew", score=0.7), + {"query": "crew", "score": 0.7}, + id="inferred-base-model", + ), + pytest.param( + RootSearchTool, + SearchResults([SearchOutput(query="crew", score=1.0)]), + [{"query": "crew", "score": 1.0}], + id="inferred-root-model", + ), + ], + ) + def test_base_tools_return_raw_result_and_json_agent_text( + self, + tool_cls: type[BaseTool], + expected_raw: object, + expected_agent_payload: object, + ) -> None: + t = tool_cls() + + raw_result = t.run(query="crew") + + assert raw_result == expected_raw + assert json.loads(t.format_output_for_agent(raw_result)) == ( + expected_agent_payload + ) + + def test_base_tool_does_not_infer_non_pydantic_return_annotation(self) -> None: + t = DictAnnotatedSearchTool() + + raw_result = t.run(query="crew") + + assert raw_result == {"query": "crew", "score": 0.5} + assert t.format_output_for_agent(raw_result) == str(raw_result) + + @pytest.mark.parametrize( + ("make_tool", "expected_raw", "expected_agent_payload"), + [ + pytest.param( + _make_explicit_decorator_tool, + {"query": "crew", "score": 0.8}, + {"query": "crew", "score": 0.8}, + id="explicit-schema", + ), + pytest.param( + _make_inferred_decorator_tool, + SearchOutput(query="crew", score=0.6), + {"query": "crew", "score": 0.6}, + id="inferred-base-model", + ), + pytest.param( + _make_root_decorator_tool, + SearchResults([SearchOutput(query="crew", score=1.0)]), + [{"query": "crew", "score": 1.0}], + id="inferred-root-model", + ), + ], + ) + def test_decorator_tools_return_raw_result_and_json_agent_text( + self, + make_tool: Callable[[], BaseTool], + expected_raw: object, + expected_agent_payload: object, + ) -> None: + search = make_tool() + + raw_result = search.run(query="crew") + + assert raw_result == expected_raw + assert json.loads(search.format_output_for_agent(raw_result)) == ( + expected_agent_payload + ) + + def test_decorator_tool_does_not_infer_non_pydantic_return_annotation( + self, + ) -> None: + @tool("search") + def search(query: str) -> dict[str, object]: + """Search for a query.""" + return {"query": query, "score": 0.5} + + raw_result = search.run(query="crew") + + assert raw_result == {"query": "crew", "score": 0.5} + assert search.format_output_for_agent(raw_result) == str(raw_result) + + def test_explicit_result_schema_wins_over_return_annotation(self) -> None: + class AlternateOutput(BaseModel): + value: str + + @tool("search", result_schema=AlternateOutput) + def search(query: str) -> SearchOutput: + """Search for a query.""" + return SearchOutput(query=query, score=0.6) + + raw_result = search.run(query="crew") + + with pytest.warns(RuntimeWarning, match="AlternateOutput"): + agent_text = search.format_output_for_agent(raw_result) + + assert raw_result == SearchOutput(query="crew", score=0.6) + assert agent_text == str(raw_result) + + def test_invalid_typed_output_warns_and_uses_string_agent_text( + self, + ) -> None: + @tool("search", result_schema=SearchOutput) + def search(query: str) -> dict[str, object]: + """Search for a query.""" + return {"query": query, "score": "not-a-float"} + + raw_result = search.run(query="crew") + + with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"): + agent_text = search.format_output_for_agent(raw_result) + + assert raw_result == {"query": "crew", "score": "not-a-float"} + assert agent_text == str(raw_result) + + def test_unserializable_typed_output_warns_and_uses_string_agent_text( + self, + ) -> None: + class OpaqueOutput(BaseModel): + value: object + + raw_result = OpaqueOutput(value=object()) + + @tool("opaque", result_schema=OpaqueOutput) + def opaque() -> OpaqueOutput: + """Return an opaque object.""" + return raw_result + + result = opaque.run() + + with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"): + agent_text = opaque.format_output_for_agent(result) + + assert result is raw_result + assert agent_text == str(raw_result) + + def test_result_schema_behavior_carries_over_to_structured_tool(self) -> None: + structured = ExplicitSearchTool().to_structured_tool() + + raw_result = structured.invoke({"query": "crew"}) + + assert raw_result == {"query": "crew", "score": 0.8} + assert json.loads(structured.format_output_for_agent(raw_result)) == { + "query": "crew", + "score": 0.8, + } + + def test_custom_agent_output_formatter_carries_over_to_structured_tool( + self, + ) -> None: + class MarkdownSearchTool(BaseTool): + name: str = "markdown_search" + description: str = "Search for information" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.8) + + def format_output_for_agent(self, raw_result: object) -> str: + result = self.result_schema.model_validate(raw_result) + return f"### Search result\n\n- Query: `{result.query}`\n- Score: {result.score}" + + structured = MarkdownSearchTool().to_structured_tool() + + raw_result = structured.invoke({"query": "crew"}) + + assert raw_result == SearchOutput(query="crew", score=0.8) + assert structured.format_output_for_agent(raw_result) == ( + "### Search result\n\n- Query: `crew`\n- Score: 0.8" + ) + # Async arun() Schema Validation Tests diff --git a/lib/crewai/tests/tools/test_structured_tool.py b/lib/crewai/tests/tools/test_structured_tool.py index 27c463d47..0241abbcf 100644 --- a/lib/crewai/tests/tools/test_structured_tool.py +++ b/lib/crewai/tests/tools/test_structured_tool.py @@ -1,5 +1,7 @@ +import json + from crewai.tools.structured_tool import CrewStructuredTool -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, RootModel import pytest @@ -86,6 +88,118 @@ def test_from_function(basic_function): assert isinstance(tool.args_schema, type(BaseModel)) +class StructuredOutput(BaseModel): + value: str + count: int + + +class StructuredOutputList(RootModel[list[StructuredOutput]]): + pass + + +def _build_explicit_structured_value(value: str) -> dict[str, object]: + """Build a value.""" + return {"value": value, "count": 1} + + +def _build_inferred_structured_value(value: str) -> StructuredOutput: + """Build a value.""" + return StructuredOutput(value=value, count=1) + + +def _build_structured_values(value: str) -> StructuredOutputList: + """Build values.""" + return StructuredOutputList([StructuredOutput(value=value, count=1)]) + + +def _build_plain_structured_value(value: str) -> dict[str, object]: + """Build a value.""" + return {"value": value, "count": 1} + + +@pytest.mark.parametrize( + ("func", "result_schema", "expected_raw", "expected_agent_payload"), + [ + pytest.param( + _build_explicit_structured_value, + StructuredOutput, + {"value": "crew", "count": 1}, + {"value": "crew", "count": 1}, + id="explicit-schema", + ), + pytest.param( + _build_inferred_structured_value, + None, + StructuredOutput(value="crew", count=1), + {"value": "crew", "count": 1}, + id="inferred-base-model", + ), + pytest.param( + _build_structured_values, + None, + StructuredOutputList([StructuredOutput(value="crew", count=1)]), + [{"value": "crew", "count": 1}], + id="inferred-root-model", + ), + ], +) +def test_from_function_returns_raw_result_and_json_agent_text( + func, + result_schema, + expected_raw, + expected_agent_payload, +): + kwargs = {"result_schema": result_schema} if result_schema is not None else {} + tool = CrewStructuredTool.from_function( + func=func, + name="build_value", + **kwargs, + ) + + raw_result = tool.invoke({"value": "crew"}) + + assert raw_result == expected_raw + assert json.loads(tool.format_output_for_agent(raw_result)) == ( + expected_agent_payload + ) + + +def test_from_function_does_not_infer_non_pydantic_result_schema(): + tool = CrewStructuredTool.from_function( + func=_build_plain_structured_value, + name="build_value", + ) + + raw_result = tool.invoke({"value": "crew"}) + + assert raw_result == {"value": "crew", "count": 1} + assert tool.format_output_for_agent(raw_result) == str(raw_result) + + +def test_invalid_typed_output_warns_and_uses_string_agent_text(): + def build_value(value: str) -> dict[str, object]: + """Build a value.""" + return {"value": value, "count": "wrong"} + + tool = CrewStructuredTool.from_function( + func=build_value, + name="build_value", + result_schema=StructuredOutput, + ) + raw_result = tool.invoke({"value": "crew"}) + + with pytest.warns( + RuntimeWarning, match="Failed to validate or serialize" + ) as warnings: + agent_text = tool.format_output_for_agent(raw_result) + + assert raw_result == {"value": "crew", "count": "wrong"} + assert agent_text == str(raw_result) + warning_message = str(warnings[0].message) + assert "ValidationError" in warning_message + assert "wrong" not in warning_message + + def test_validate_function_signature(basic_function, schema_class): """Test function signature validation""" tool = CrewStructuredTool( diff --git a/lib/crewai/tests/tools/test_tool_usage.py b/lib/crewai/tests/tools/test_tool_usage.py index ba4fe72dd..9d61c93a9 100644 --- a/lib/crewai/tests/tools/test_tool_usage.py +++ b/lib/crewai/tests/tools/test_tool_usage.py @@ -1,4 +1,5 @@ import datetime +from collections.abc import Callable import json import random import threading @@ -6,6 +7,9 @@ import time from unittest.mock import MagicMock, patch from crewai import Agent, Task +from crewai.agents.cache.cache_handler import CacheHandler +from crewai.agents.parser import AgentAction +from crewai.agents.tools_handler import ToolsHandler from crewai.events.event_bus import crewai_event_bus from crewai.events.types.tool_usage_events import ( ToolSelectionErrorEvent, @@ -14,8 +18,15 @@ from crewai.events.types.tool_usage_events import ( ToolUsageStartedEvent, ToolValidateInputErrorEvent, ) +from crewai.hooks.tool_hooks import ( + ToolCallHookContext, + clear_after_tool_call_hooks, + register_after_tool_call_hook, +) from crewai.tools import BaseTool +from crewai.tools.tool_calling import ToolCalling from crewai.tools.tool_usage import ToolUsage +from crewai.utilities.tool_utils import execute_tool_and_check_finality from pydantic import BaseModel, Field import pytest @@ -38,6 +49,19 @@ class RandomNumberTool(BaseTool): return random.randint(min_value, max_value) # noqa: S311 +class SearchOutput(BaseModel): + query: str + score: float + + +class TypedSearchTool(BaseTool): + name: str = "typed_search" + description: str = "Search for a query" + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.7) + + # Example agent and task example_agent = Agent( role="Number Generator", @@ -117,6 +141,126 @@ def test_tool_usage_render(): assert '"description": "The maximum value of the range (inclusive)"' in rendered +def test_tool_usage_returns_json_agent_text_for_typed_output(): + tool = TypedSearchTool().to_structured_tool() + tool_usage = ToolUsage( + tools_handler=None, + tools=[tool], + task=None, + function_calling_llm=MagicMock(), + agent=None, + action=MagicMock(), + ) + + result = tool_usage.use( + calling=ToolCalling( + tool_name="typed_search", + arguments={"query": "crew"}, + ), + tool_string='Action: typed_search\nAction Input: {"query": "crew"}', + ) + + assert json.loads(result) == {"query": "crew", "score": 0.7} + + +def test_tool_usage_cache_callback_receives_raw_typed_output(): + raw_results: list[object] = [] + + def cache_result(_args: object, result: object) -> bool: + raw_results.append(result) + return True + + class CacheAwareTypedSearchTool(TypedSearchTool): + cache_function: Callable = cache_result + + tools_handler = MagicMock() + tools_handler.cache = None + tools_handler.last_used_tool = None + tool = CacheAwareTypedSearchTool().to_structured_tool() + tool_usage = ToolUsage( + tools_handler=tools_handler, + tools=[tool], + task=None, + function_calling_llm=MagicMock(), + agent=None, + action=MagicMock(), + ) + + result = tool_usage.use( + calling=ToolCalling( + tool_name="typed_search", + arguments={"query": "crew"}, + ), + tool_string='Action: typed_search\nAction Input: {"query": "crew"}', + ) + + assert json.loads(result) == {"query": "crew", "score": 0.7} + assert raw_results == [SearchOutput(query="crew", score=0.7)] + tools_handler.on_tool_use.assert_called_once() + assert tools_handler.on_tool_use.call_args.kwargs["output"] == SearchOutput( + query="crew", + score=0.7, + ) + + +def test_react_tool_hooks_receive_agent_text_and_raw_cached_typed_output(): + structured_tool = TypedSearchTool().to_structured_tool() + tools_handler = ToolsHandler(cache=CacheHandler()) + seen_results: list[tuple[str | None, object]] = [] + + def after_hook(context: ToolCallHookContext) -> None: + seen_results.append((context.tool_result, context.raw_tool_result)) + + clear_after_tool_call_hooks() + register_after_tool_call_hook(after_hook) + + action = AgentAction( + thought="", + tool="typed_search", + tool_input='{"query": "crew"}', + text='Action: typed_search\nAction Input: {"query": "crew"}', + ) + + try: + first = execute_tool_and_check_finality( + agent_action=action, + tools=[structured_tool], + tools_handler=tools_handler, + ) + tools_handler.last_used_tool = None + second = execute_tool_and_check_finality( + agent_action=action, + tools=[structured_tool], + tools_handler=tools_handler, + ) + finally: + clear_after_tool_call_hooks() + + assert json.loads(first.result) == {"query": "crew", "score": 0.7} + assert json.loads(second.result) == {"query": "crew", "score": 0.7} + assert seen_results == [ + ('{"query":"crew","score":0.7}', SearchOutput(query="crew", score=0.7)), + ('{"query":"crew","score":0.7}', SearchOutput(query="crew", score=0.7)), + ] + + +def test_last_raw_result_falls_back_only_until_recorded(): + tool_usage = ToolUsage( + tools_handler=None, + tools=[], + task=None, + function_calling_llm=MagicMock(), + agent=None, + action=MagicMock(), + ) + + assert tool_usage.get_last_raw_result("formatted result") == "formatted result" + + tool_usage.last_raw_result = None + + assert tool_usage.get_last_raw_result("formatted result") is None + + def test_validate_tool_input_booleans_and_none(): tool_usage = ToolUsage( tools_handler=MagicMock(), diff --git a/lib/crewai/tests/utilities/test_agent_utils.py b/lib/crewai/tests/utilities/test_agent_utils.py index de3ed411b..9cf4a2d2a 100644 --- a/lib/crewai/tests/utilities/test_agent_utils.py +++ b/lib/crewai/tests/utilities/test_agent_utils.py @@ -3,12 +3,19 @@ from __future__ import annotations import asyncio +import json from typing import Any, Literal, Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest from pydantic import BaseModel, Field +from crewai.hooks.tool_hooks import ( + ToolCallHookContext, + clear_after_tool_call_hooks, + clear_before_tool_call_hooks, + register_after_tool_call_hook, +) from crewai.tools.base_tool import BaseTool from crewai.utilities.agent_utils import ( _asummarize_chunks, @@ -1030,6 +1037,142 @@ class TestParseToolCallArgs: class TestExecuteSingleNativeToolCall: """Tests for execute_single_native_tool_call.""" + def test_typed_tool_output_is_json_agent_text(self) -> None: + clear_before_tool_call_hooks() + clear_after_tool_call_hooks() + + class SearchOutput(BaseModel): + query: str + score: float + + class TypedSearchTool(BaseTool): + name: str = "typed_search" + description: str = "Search for a query" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.9) + + tool = TypedSearchTool() + tool_call = MagicMock() + tool_call.id = "call_1" + tool_call.function.name = "typed_search" + tool_call.function.arguments = '{"query": "crew"}' + + result = execute_single_native_tool_call( + tool_call, + available_functions={"typed_search": tool._run}, + original_tools=[tool], + structured_tools=[tool.to_structured_tool()], + tools_handler=None, + agent=None, + task=None, + crew=None, + event_source=MagicMock(), + printer=None, + verbose=False, + ) + + assert json.loads(result.result) == {"query": "crew", "score": 0.9} + assert json.loads(result.tool_message["content"]) == { + "query": "crew", + "score": 0.9, + } + + def test_custom_agent_output_formatter_is_used_from_structured_tool( + self, + ) -> None: + clear_before_tool_call_hooks() + clear_after_tool_call_hooks() + + class SearchOutput(BaseModel): + query: str + score: float + + class MarkdownSearchTool(BaseTool): + name: str = "markdown_search" + description: str = "Search for a query" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.9) + + def format_output_for_agent(self, raw_result: Any) -> str: + result = self.result_schema.model_validate(raw_result) + return f"### {result.query}\n\nScore: **{result.score}**" + + tool = MarkdownSearchTool() + tool_call = MagicMock() + tool_call.id = "call_1" + tool_call.function.name = "markdown_search" + tool_call.function.arguments = '{"query": "crew"}' + + result = execute_single_native_tool_call( + tool_call, + available_functions={"markdown_search": tool._run}, + original_tools=[], + structured_tools=[tool.to_structured_tool()], + tools_handler=None, + agent=None, + task=None, + crew=None, + event_source=MagicMock(), + printer=None, + verbose=False, + ) + + assert result.result == "### crew\n\nScore: **0.9**" + assert result.tool_message["content"] == "### crew\n\nScore: **0.9**" + + def test_after_hook_includes_raw_tool_result_for_typed_output(self) -> None: + clear_after_tool_call_hooks() + + class SearchOutput(BaseModel): + query: str + score: float + + class TypedSearchTool(BaseTool): + name: str = "typed_search" + description: str = "Search for a query" + result_schema: type[BaseModel] = SearchOutput + + def _run(self, query: str) -> SearchOutput: + return SearchOutput(query=query, score=0.9) + + seen_results: list[tuple[str | None, object]] = [] + + def after_hook(context: ToolCallHookContext) -> None: + seen_results.append((context.tool_result, context.raw_tool_result)) + + tool = TypedSearchTool() + tool_call = MagicMock() + tool_call.id = "call_1" + tool_call.function.name = "typed_search" + tool_call.function.arguments = '{"query": "crew"}' + + register_after_tool_call_hook(after_hook) + try: + result = execute_single_native_tool_call( + tool_call, + available_functions={"typed_search": tool._run}, + original_tools=[tool], + structured_tools=[tool.to_structured_tool()], + tools_handler=None, + agent=None, + task=None, + crew=None, + event_source=MagicMock(), + printer=None, + verbose=False, + ) + finally: + clear_after_tool_call_hooks() + + assert json.loads(result.result) == {"query": "crew", "score": 0.9} + assert seen_results == [ + ('{"query":"crew","score":0.9}', SearchOutput(query="crew", score=0.9)) + ] + def test_result_as_answer_false_on_tool_error(self) -> None: """When a tool with result_as_answer=True raises, result_as_answer must be False.