mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-06-19 15:18:12 +00:00
Compare commits
10 Commits
fix/creden
...
tools-sche
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
754d1a93f3 | ||
|
|
3e74832ec9 | ||
|
|
7d8334268e | ||
|
|
76bcf6d6a7 | ||
|
|
5fbab5a50a | ||
|
|
9b8ecc7df5 | ||
|
|
8c35dedfb5 | ||
|
|
58fc692b03 | ||
|
|
267b519896 | ||
|
|
ba7533ed9d |
@@ -39,6 +39,7 @@ The Enterprise Tools Repository includes:
|
||||
- **Error Handling**: Incorporates robust error handling mechanisms to ensure smooth operation.
|
||||
- **Caching Mechanism**: Features intelligent caching to optimize performance and reduce redundant operations.
|
||||
- **Asynchronous Support**: Handles both synchronous and asynchronous tools, enabling non-blocking operations.
|
||||
- **Typed Outputs**: Uses optional Pydantic models to give agents clear JSON fields while direct Python calls still receive the tool's normal return value.
|
||||
|
||||
## Using CrewAI Tools
|
||||
|
||||
@@ -184,6 +185,55 @@ class MyCustomTool(BaseTool):
|
||||
return "Tool's result"
|
||||
```
|
||||
|
||||
### Typed Tool Outputs
|
||||
|
||||
When a tool returns structured data, define a Pydantic output model. This gives the agent field names it can trust, such as `sku`, `quantity`, or `needs_reorder`.
|
||||
|
||||
Direct Python calls still receive the value your tool returns. When an agent uses the tool, CrewAI sends the agent a JSON string based on the output model.
|
||||
|
||||
```python Code
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
class InventoryResult(BaseModel):
|
||||
sku: str
|
||||
quantity: int
|
||||
needs_reorder: bool
|
||||
|
||||
class InventoryTool(BaseTool):
|
||||
name: str = "Inventory Check"
|
||||
description: str = "Checks current stock for a product SKU."
|
||||
|
||||
def _run(self, sku: str) -> InventoryResult:
|
||||
quantity = {"SKU-123": 14, "SKU-456": 0}.get(sku, 0)
|
||||
return InventoryResult(sku=sku, quantity=quantity, needs_reorder=quantity < 5)
|
||||
|
||||
tool = InventoryTool()
|
||||
|
||||
# Direct calls receive the raw Pydantic object.
|
||||
result = tool.run(sku="SKU-123")
|
||||
print(result.quantity)
|
||||
```
|
||||
|
||||
To send Markdown or another short text format to the agent, override `format_output_for_agent`. Direct calls to `tool.run(...)` still return the normal Python value.
|
||||
|
||||
```python Code
|
||||
class InventoryTool(BaseTool):
|
||||
name: str = "Inventory Check"
|
||||
description: str = "Checks current stock for a product SKU."
|
||||
|
||||
def _run(self, sku: str) -> InventoryResult:
|
||||
quantity = {"SKU-123": 14, "SKU-456": 0}.get(sku, 0)
|
||||
return InventoryResult(sku=sku, quantity=quantity, needs_reorder=quantity < 5)
|
||||
|
||||
def format_output_for_agent(self, raw_result: object) -> str:
|
||||
result = InventoryResult.model_validate(raw_result)
|
||||
status = "reorder needed" if result.needs_reorder else "stock is healthy"
|
||||
return f"{result.sku}: {result.quantity} units. {status}."
|
||||
```
|
||||
|
||||
If you do not override `format_output_for_agent`, typed outputs are sent to the agent as JSON. Plain string results work as before.
|
||||
|
||||
## Asynchronous Tool Support
|
||||
|
||||
CrewAI supports asynchronous tools, allowing you to implement tools that perform non-blocking operations like network requests, file I/O, or other async operations without blocking the main execution thread.
|
||||
|
||||
@@ -65,7 +65,7 @@ Regardless of which approach you use, your tool must:
|
||||
- Have a **`description`** — tells the agent when and how to use the tool. This directly affects how well agents use your tool, so be clear and specific.
|
||||
- Implement **`_run`** (BaseTool) or provide a **function body** (@tool) — the synchronous execution logic.
|
||||
- Use **type annotations** on all parameters and return values.
|
||||
- Return a **string** result (or something that can be meaningfully converted to one).
|
||||
- Return a **string** result, or define an optional Pydantic output schema for structured results.
|
||||
|
||||
### Optional: Async Support
|
||||
|
||||
@@ -104,6 +104,67 @@ class TranslateInput(BaseModel):
|
||||
|
||||
Explicit schemas are recommended for published tools — they produce better agent behavior and clearer documentation for your users.
|
||||
|
||||
### Optional: Typed Outputs with `output_schema`
|
||||
|
||||
If your tool returns structured data, define a Pydantic output model. This is a good default for published tools because users and agents can rely on named fields.
|
||||
|
||||
Direct Python calls still receive the value your tool returns. When an agent uses the tool, CrewAI sends the agent JSON based on the output model.
|
||||
|
||||
CrewAI can infer the output schema from a Pydantic return annotation:
|
||||
|
||||
```python
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GeolocateResult(BaseModel):
|
||||
latitude: float = Field(..., description="Latitude in decimal degrees.")
|
||||
longitude: float = Field(..., description="Longitude in decimal degrees.")
|
||||
|
||||
|
||||
class GeolocateTool(BaseTool):
|
||||
name: str = "Geolocate"
|
||||
description: str = "Converts a street address into latitude/longitude coordinates."
|
||||
|
||||
def _run(self, address: str) -> GeolocateResult:
|
||||
if "1600 Pennsylvania" in address:
|
||||
return GeolocateResult(latitude=38.8977, longitude=-77.0365)
|
||||
return GeolocateResult(latitude=40.7128, longitude=-74.0060)
|
||||
```
|
||||
|
||||
Set `output_schema` explicitly when your tool returns a dictionary:
|
||||
|
||||
```python
|
||||
class GeolocateTool(BaseTool):
|
||||
name: str = "Geolocate"
|
||||
description: str = "Converts a street address into latitude/longitude coordinates."
|
||||
output_schema: type[BaseModel] = GeolocateResult
|
||||
|
||||
def _run(self, address: str) -> dict[str, float]:
|
||||
if "1600 Pennsylvania" in address:
|
||||
return {"latitude": 38.8977, "longitude": -77.0365}
|
||||
return {"latitude": 40.7128, "longitude": -74.0060}
|
||||
```
|
||||
|
||||
If agents should receive a short text summary instead of JSON, override `format_output_for_agent` on your `BaseTool` subclass.
|
||||
|
||||
```python
|
||||
class GeolocateTool(BaseTool):
|
||||
name: str = "Geolocate"
|
||||
description: str = "Converts a street address into latitude/longitude coordinates."
|
||||
|
||||
def _run(self, address: str) -> GeolocateResult:
|
||||
if "1600 Pennsylvania" in address:
|
||||
return GeolocateResult(latitude=38.8977, longitude=-77.0365)
|
||||
return GeolocateResult(latitude=40.7128, longitude=-74.0060)
|
||||
|
||||
def format_output_for_agent(self, raw_result: object) -> str:
|
||||
result = GeolocateResult.model_validate(raw_result)
|
||||
return f"Latitude {result.latitude}, longitude {result.longitude}"
|
||||
```
|
||||
|
||||
The override only changes what the agent sees. Direct users of your package still receive the normal value from `tool.run(...)`.
|
||||
|
||||
### Optional: Environment Variables
|
||||
|
||||
If your tool requires API keys or other configuration, declare them with `env_vars` so users know what to set:
|
||||
@@ -241,4 +302,4 @@ agent = Agent(
|
||||
tools=[GeolocateTool()],
|
||||
# ...
|
||||
)
|
||||
```
|
||||
```
|
||||
|
||||
@@ -53,6 +53,111 @@ def my_simple_tool(question: str) -> str:
|
||||
return "Tool output"
|
||||
```
|
||||
|
||||
### Best Practice: Define Typed Outputs
|
||||
|
||||
When a tool returns structured data, define a Pydantic output model. This helps the agent read the result as clear fields instead of guessing from plain text.
|
||||
|
||||
Typed outputs are useful for results with stable fields, such as IDs, status values, scores, prices, or lists. Plain strings are still fine for short prose results.
|
||||
|
||||
Direct Python calls still receive the value your tool returns. When an agent uses a typed tool, CrewAI sends the agent JSON based on the output model.
|
||||
|
||||
#### Return a Pydantic Model
|
||||
|
||||
CrewAI infers the output schema when your `BaseTool` has a Pydantic return annotation.
|
||||
|
||||
```python Code
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class InventoryResult(BaseModel):
|
||||
sku: str = Field(description="The product SKU.")
|
||||
quantity: int = Field(description="Units available.")
|
||||
needs_reorder: bool = Field(description="Whether the item should be reordered.")
|
||||
|
||||
class InventoryTool(BaseTool):
|
||||
name: str = "Inventory Check"
|
||||
description: str = "Check current stock for a product SKU."
|
||||
|
||||
def _run(self, sku: str) -> InventoryResult:
|
||||
quantity = {"SKU-123": 14, "SKU-456": 0}.get(sku, 0)
|
||||
return InventoryResult(sku=sku, quantity=quantity, needs_reorder=quantity < 5)
|
||||
|
||||
tool = InventoryTool()
|
||||
result = tool.run(sku="SKU-123")
|
||||
|
||||
# Direct Python calls receive the raw Pydantic object.
|
||||
print(result.quantity)
|
||||
```
|
||||
|
||||
When an agent calls `InventoryTool`, it receives JSON like this:
|
||||
|
||||
```json
|
||||
{"sku":"SKU-123","quantity":14,"needs_reorder":false}
|
||||
```
|
||||
|
||||
#### Use `output_schema` with Dictionary Results
|
||||
|
||||
If your tool returns a dictionary, set `output_schema` explicitly. You can do this on a `BaseTool` subclass or with the `@tool` decorator:
|
||||
|
||||
```python Code
|
||||
from crewai.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class ProductResult(BaseModel):
|
||||
sku: str = Field(description="The product SKU.")
|
||||
name: str = Field(description="The product name.")
|
||||
in_stock: bool = Field(description="Whether the product is available.")
|
||||
|
||||
@tool("Product Lookup", output_schema=ProductResult)
|
||||
def product_lookup(sku: str) -> dict[str, object]:
|
||||
"""Look up product availability by SKU."""
|
||||
catalog = {
|
||||
"SKU-123": ("Noise-canceling headset", True),
|
||||
"SKU-456": ("USB-C dock", False),
|
||||
}
|
||||
name, in_stock = catalog.get(sku, ("Unknown product", False))
|
||||
return {
|
||||
"sku": sku,
|
||||
"name": name,
|
||||
"in_stock": in_stock,
|
||||
}
|
||||
```
|
||||
|
||||
#### Customize the Text Sent to the Agent
|
||||
|
||||
By default, typed tool outputs are sent to the agent as JSON. If the agent should receive a short summary instead, subclass `BaseTool` and override `format_output_for_agent`.
|
||||
|
||||
```python Code
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class InventoryResult(BaseModel):
|
||||
sku: str = Field(description="The product SKU.")
|
||||
quantity: int = Field(description="Units available.")
|
||||
needs_reorder: bool = Field(description="Whether the item should be reordered.")
|
||||
|
||||
class InventoryTool(BaseTool):
|
||||
name: str = "Inventory Check"
|
||||
description: str = "Check current stock for a product SKU."
|
||||
|
||||
def _run(self, sku: str) -> InventoryResult:
|
||||
quantity = {"SKU-123": 14, "SKU-456": 0}.get(sku, 0)
|
||||
return InventoryResult(sku=sku, quantity=quantity, needs_reorder=quantity < 5)
|
||||
|
||||
def format_output_for_agent(self, raw_result: object) -> str:
|
||||
result = InventoryResult.model_validate(raw_result)
|
||||
status = "reorder needed" if result.needs_reorder else "stock is healthy"
|
||||
return f"{result.sku}: {result.quantity} units. {status}."
|
||||
|
||||
tool = InventoryTool()
|
||||
result = tool.run(sku="SKU-123")
|
||||
|
||||
# Direct Python calls receive the raw Pydantic object.
|
||||
print(result.quantity)
|
||||
```
|
||||
|
||||
The override only changes what the agent sees. Direct calls to `tool.run(...)` still return the normal Python value.
|
||||
|
||||
### Defining a Cache Function for the Tool
|
||||
|
||||
To optimize tool performance with caching, define custom caching strategies using the `cache_function` attribute.
|
||||
|
||||
@@ -195,9 +195,12 @@ class ToolCallHookContext:
|
||||
agent: Agent | None # Agent executing
|
||||
task: Task | None # Current task
|
||||
crew: Crew | None # Crew instance
|
||||
tool_result: str | None # Tool result (after hooks)
|
||||
tool_result: str | None # Agent-facing result string (after hooks)
|
||||
raw_tool_result: Any | None # Raw Python result (after hooks)
|
||||
```
|
||||
|
||||
For typed tool outputs, `tool_result` is the string the agent sees. By default, this is JSON. If the tool uses custom formatting, it can be Markdown or another string. `raw_tool_result` is the original Python value returned by the tool.
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Safety and Validation
|
||||
|
||||
@@ -60,9 +60,12 @@ class ToolCallHookContext:
|
||||
agent: Agent | BaseAgent | None # Agent executing the tool
|
||||
task: Task | None # Current task
|
||||
crew: Crew | None # Crew instance
|
||||
tool_result: str | None # Tool result (after hooks only)
|
||||
tool_result: str | None # Agent-facing result string (after hooks only)
|
||||
raw_tool_result: Any | None # Raw Python result (after hooks only)
|
||||
```
|
||||
|
||||
For typed tool outputs, `tool_result` is the string the agent sees. By default, this is JSON. If the tool uses custom formatting, it can be Markdown or another string. Use `raw_tool_result` when your hook needs the typed object or dictionary.
|
||||
|
||||
### Modifying Tool Inputs
|
||||
|
||||
**Important:** Always modify tool inputs in-place:
|
||||
|
||||
@@ -63,11 +63,12 @@ print(crew.kickoff())
|
||||
|
||||
## Configuration Options
|
||||
|
||||
The `TavilyResearchTool` accepts the following arguments — all can be set on the tool instance (defaults for every call) or per-call via the agent's tool input:
|
||||
The `TavilyResearchTool` accepts the following arguments. Set `model`, `tavily_output_schema`, `stream`, and `citation_format` on the tool instance as defaults; pass `input`, `model`, `output_schema`, `stream`, and `citation_format` per call via the agent's tool input:
|
||||
|
||||
- `input` (str): **Required.** The research task or question to investigate.
|
||||
- `model` (Literal["mini", "pro", "auto"]): The Tavily research model. `"auto"` lets Tavily pick; `"mini"` is faster/cheaper; `"pro"` is the most capable. Defaults to `"auto"`.
|
||||
- `output_schema` (dict | None): Optional JSON Schema that structures the research output. Useful when you want strictly typed results.
|
||||
- `output_schema` (dict | None): Optional per-call JSON Schema that structures the research output. Useful when you want strictly typed results.
|
||||
- `tavily_output_schema` (dict | None): Optional default JSON Schema for Tavily research output when the tool is instantiated.
|
||||
- `stream` (bool): When `True`, the tool returns an iterator of SSE chunks emitting research progress and the final result instead of a single string. Defaults to `False`.
|
||||
- `citation_format` (Literal["numbered", "mla", "apa", "chicago"]): Citation format for the report. Defaults to `"numbered"`.
|
||||
|
||||
@@ -97,7 +98,7 @@ for chunk in tavily_tool.run(input="Summarize recent advances in retrieval-augme
|
||||
|
||||
### Structured output via JSON Schema
|
||||
|
||||
Pass an `output_schema` when you need a typed result instead of a free-form report:
|
||||
Pass an `output_schema` per call, or set `tavily_output_schema` on the tool instance, when you need a typed result instead of a free-form report:
|
||||
|
||||
```python
|
||||
output_schema = {
|
||||
@@ -110,7 +111,7 @@ output_schema = {
|
||||
"required": ["summary", "key_points", "sources"],
|
||||
}
|
||||
|
||||
tavily_tool = TavilyResearchTool(output_schema=output_schema)
|
||||
tavily_tool = TavilyResearchTool(tavily_output_schema=output_schema)
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import stat
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
@@ -149,55 +146,3 @@ class TestSettings(unittest.TestCase):
|
||||
|
||||
settings = Settings(config_path=self.config_path)
|
||||
self.assertIsNone(settings.tool_repository_username)
|
||||
|
||||
|
||||
class TestSettingsFilePermissions(unittest.TestCase):
|
||||
"""Regression tests: credentials in settings.json must not be world-readable."""
|
||||
|
||||
def setUp(self):
|
||||
self.test_dir = Path(tempfile.mkdtemp())
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir, ignore_errors=True)
|
||||
|
||||
@unittest.skipIf(sys.platform == "win32", "POSIX permission semantics")
|
||||
def test_dump_writes_owner_only_file(self):
|
||||
config_path = self.test_dir / "settings.json"
|
||||
old_umask = os.umask(0o022)
|
||||
try:
|
||||
settings = Settings(
|
||||
config_path=config_path, tool_repository_password="hunter2"
|
||||
)
|
||||
settings.dump()
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
|
||||
mode = stat.S_IMODE(config_path.stat().st_mode)
|
||||
self.assertEqual(mode, 0o600, f"expected 0o600, got {oct(mode)}")
|
||||
|
||||
@unittest.skipIf(sys.platform == "win32", "POSIX permission semantics")
|
||||
def test_dedicated_config_dir_is_owner_only(self):
|
||||
config_path = self.test_dir / "crewai" / "settings.json"
|
||||
old_umask = os.umask(0o022)
|
||||
try:
|
||||
Settings(config_path=config_path, tool_repository_username="u")
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
|
||||
mode = stat.S_IMODE(config_path.parent.stat().st_mode)
|
||||
self.assertEqual(mode, 0o700, f"expected 0o700, got {oct(mode)}")
|
||||
|
||||
@unittest.skipIf(sys.platform == "win32", "POSIX permission semantics")
|
||||
def test_shared_fallback_dir_is_not_chmodded(self):
|
||||
"""The system temp dir (a fallback parent) must never be globally chmod'd."""
|
||||
from crewai_core.settings import _ensure_dir_mode
|
||||
|
||||
tmp_root = Path(tempfile.gettempdir())
|
||||
before = stat.S_IMODE(tmp_root.stat().st_mode)
|
||||
_ensure_dir_mode(tmp_root)
|
||||
after = stat.S_IMODE(tmp_root.stat().st_mode)
|
||||
self.assertEqual(before, after)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
"""Tests for TokenManager with atomic file operations."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import stat
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
@@ -288,50 +285,5 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
tm._delete_secure_file("nonexistent.txt")
|
||||
|
||||
|
||||
class TestSecureStoragePathPermissions(unittest.TestCase):
|
||||
"""Test that the credential directory is created with restrictive permissions."""
|
||||
|
||||
@unittest.skipIf(sys.platform == "win32", "POSIX permission semantics")
|
||||
def test_storage_path_is_owner_only(self) -> None:
|
||||
"""The credential directory must be mode 0o700 even under a permissive umask."""
|
||||
with tempfile.TemporaryDirectory() as base:
|
||||
old_umask = os.umask(0o022)
|
||||
try:
|
||||
with (
|
||||
patch("crewai_core.token_manager.sys.platform", "linux"),
|
||||
patch(
|
||||
"crewai_core.token_manager.os.path.expanduser",
|
||||
return_value=base,
|
||||
),
|
||||
):
|
||||
storage_path = TokenManager._get_secure_storage_path()
|
||||
finally:
|
||||
os.umask(old_umask)
|
||||
|
||||
self.assertTrue(storage_path.is_dir())
|
||||
mode = stat.S_IMODE(storage_path.stat().st_mode)
|
||||
self.assertEqual(mode, 0o700, f"expected 0o700, got {oct(mode)}")
|
||||
|
||||
@unittest.skipIf(sys.platform == "win32", "POSIX permission semantics")
|
||||
def test_existing_loose_dir_is_tightened(self) -> None:
|
||||
"""A pre-existing world-traversable directory is corrected to 0o700."""
|
||||
with tempfile.TemporaryDirectory() as base:
|
||||
loose = Path(base) / "crewai" / "credentials"
|
||||
loose.mkdir(parents=True)
|
||||
loose.chmod(0o755)
|
||||
|
||||
with (
|
||||
patch("crewai_core.token_manager.sys.platform", "linux"),
|
||||
patch(
|
||||
"crewai_core.token_manager.os.path.expanduser",
|
||||
return_value=base,
|
||||
),
|
||||
):
|
||||
storage_path = TokenManager._get_secure_storage_path()
|
||||
|
||||
mode = stat.S_IMODE(storage_path.stat().st_mode)
|
||||
self.assertEqual(mode, 0o700, f"expected 0o700, got {oct(mode)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from logging import getLogger
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from typing import Any
|
||||
@@ -26,41 +25,6 @@ logger = getLogger(__name__)
|
||||
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
||||
|
||||
|
||||
def _ensure_dir_mode(directory: Path) -> None:
|
||||
"""Tighten a dedicated config directory to 0o700.
|
||||
|
||||
Skips directories shared with other users or content (the system temp dir
|
||||
and the current working directory), which are used as best-effort fallbacks
|
||||
by :func:`get_writable_config_path` and must not be globally chmod'd. Secret
|
||||
files written there are still protected by their own 0o600 mode.
|
||||
"""
|
||||
try:
|
||||
shared = {Path(tempfile.gettempdir()).resolve(), Path.cwd().resolve()}
|
||||
if directory.resolve() in shared:
|
||||
return
|
||||
directory.chmod(0o700)
|
||||
except OSError as e:
|
||||
logger.debug(
|
||||
"Could not enforce 0o700 on config directory %s (best-effort): %s",
|
||||
directory,
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
def _write_secure_json(path: Path, data: dict[str, Any]) -> None:
|
||||
"""Atomically write ``data`` as JSON to ``path`` with owner-only (0o600) mode."""
|
||||
fd, tmp = tempfile.mkstemp(dir=path.parent, prefix=f".{path.name}.")
|
||||
try:
|
||||
with os.fdopen(fd, "w") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
os.chmod(tmp, 0o600)
|
||||
os.replace(tmp, path)
|
||||
except BaseException:
|
||||
if os.path.exists(tmp):
|
||||
os.unlink(tmp)
|
||||
raise
|
||||
|
||||
|
||||
def get_writable_config_path() -> Path | None:
|
||||
"""Find a writable location for the config file with fallback options.
|
||||
|
||||
@@ -79,7 +43,6 @@ def get_writable_config_path() -> Path | None:
|
||||
for config_path in fallback_paths:
|
||||
try:
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
_ensure_dir_mode(config_path.parent)
|
||||
test_file = config_path.parent / ".crewai_write_test"
|
||||
try:
|
||||
test_file.write_text("test")
|
||||
@@ -190,7 +153,6 @@ class Settings(BaseModel):
|
||||
|
||||
try:
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
_ensure_dir_mode(config_path.parent)
|
||||
except Exception:
|
||||
merged_data = {**data}
|
||||
super().__init__(config_path=Path("/dev/null"), **merged_data)
|
||||
@@ -232,7 +194,8 @@ class Settings(BaseModel):
|
||||
existing_data = {}
|
||||
|
||||
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
|
||||
_write_secure_json(self.config_path, updated_data)
|
||||
with self.config_path.open("w") as f:
|
||||
json.dump(updated_data, f, indent=4)
|
||||
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
@@ -95,14 +95,6 @@ class TokenManager:
|
||||
storage_path = Path(base_path) / app_name
|
||||
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
# Enforce the documented 0o700 mode: mkdir is subject to umask and does
|
||||
# not adjust the mode of a pre-existing directory, so chmod explicitly.
|
||||
try:
|
||||
storage_path.chmod(0o700)
|
||||
except OSError:
|
||||
# Best-effort permission hardening only: some platforms/filesystems
|
||||
# may reject chmod here, and token operations should still proceed.
|
||||
pass
|
||||
|
||||
return storage_path
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ class TavilyResearchTool(BaseTool):
|
||||
default="auto",
|
||||
description="Default model used for new Tavily research tasks.",
|
||||
)
|
||||
output_schema: dict[str, Any] | None = Field(
|
||||
tavily_output_schema: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Default JSON Schema used to structure research output.",
|
||||
)
|
||||
@@ -87,6 +87,10 @@ class TavilyResearchTool(BaseTool):
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
output_schema = kwargs.get("output_schema")
|
||||
if isinstance(output_schema, dict) and "tavily_output_schema" not in kwargs:
|
||||
kwargs["tavily_output_schema"] = kwargs.pop("output_schema")
|
||||
|
||||
super().__init__(**kwargs)
|
||||
if TAVILY_AVAILABLE:
|
||||
api_key = os.getenv("TAVILY_API_KEY")
|
||||
@@ -152,7 +156,7 @@ class TavilyResearchTool(BaseTool):
|
||||
result = self._client.research(
|
||||
input=input,
|
||||
model=self.model if model is None else model,
|
||||
output_schema=self.output_schema
|
||||
output_schema=self.tavily_output_schema
|
||||
if output_schema is None
|
||||
else output_schema,
|
||||
stream=use_stream,
|
||||
@@ -185,7 +189,7 @@ class TavilyResearchTool(BaseTool):
|
||||
result = await self._async_client.research(
|
||||
input=input,
|
||||
model=self.model if model is None else model,
|
||||
output_schema=self.output_schema
|
||||
output_schema=self.tavily_output_schema
|
||||
if output_schema is None
|
||||
else output_schema,
|
||||
stream=use_stream,
|
||||
|
||||
@@ -25386,7 +25386,13 @@
|
||||
"title": "Model",
|
||||
"type": "string"
|
||||
},
|
||||
"output_schema": {
|
||||
"stream": {
|
||||
"default": false,
|
||||
"description": "Whether new Tavily research tasks should stream responses by default.",
|
||||
"title": "Stream",
|
||||
"type": "boolean"
|
||||
},
|
||||
"tavily_output_schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"additionalProperties": true,
|
||||
@@ -25398,13 +25404,7 @@
|
||||
],
|
||||
"default": null,
|
||||
"description": "Default JSON Schema used to structure research output.",
|
||||
"title": "Output Schema"
|
||||
},
|
||||
"stream": {
|
||||
"default": false,
|
||||
"description": "Whether new Tavily research tasks should stream responses by default.",
|
||||
"title": "Stream",
|
||||
"type": "boolean"
|
||||
"title": "Tavily Output Schema"
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
|
||||
@@ -57,6 +57,7 @@ from crewai.utilities.agent_utils import (
|
||||
convert_tools_to_openai_schema,
|
||||
enforce_rpm_limit,
|
||||
format_message_for_llm,
|
||||
format_native_tool_output_for_agent,
|
||||
get_llm_response,
|
||||
handle_agent_action_core,
|
||||
handle_context_length,
|
||||
@@ -907,19 +908,31 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
):
|
||||
max_usage_reached = True
|
||||
|
||||
structured_tool: CrewStructuredTool | None = None
|
||||
if original_tool is not None:
|
||||
for structured in self.tools or []:
|
||||
if getattr(structured, "_original_tool", None) is original_tool:
|
||||
structured_tool = structured
|
||||
break
|
||||
if structured_tool is None:
|
||||
for structured in self.tools or []:
|
||||
if sanitize_tool_name(structured.name) == func_name:
|
||||
structured_tool = structured
|
||||
break
|
||||
|
||||
output_tool = original_tool or structured_tool
|
||||
|
||||
from_cache = False
|
||||
result: str = "Tool not found"
|
||||
raw_tool_result: Any = result
|
||||
input_str = json.dumps(args_dict) if args_dict else ""
|
||||
if self.tools_handler and self.tools_handler.cache:
|
||||
if self.tools_handler and self.tools_handler.cache and output_tool is not None:
|
||||
cached_result = self.tools_handler.cache.read(
|
||||
tool=func_name, input=input_str
|
||||
)
|
||||
if cached_result is not None:
|
||||
result = (
|
||||
str(cached_result)
|
||||
if not isinstance(cached_result, str)
|
||||
else cached_result
|
||||
)
|
||||
raw_tool_result = cached_result
|
||||
result = format_native_tool_output_for_agent(output_tool, cached_result)
|
||||
from_cache = True
|
||||
|
||||
agent_key = getattr(self.agent, "key", "unknown") if self.agent else "unknown"
|
||||
@@ -938,18 +951,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
|
||||
track_delegation_if_needed(func_name, args_dict or {}, self.task)
|
||||
|
||||
structured_tool: CrewStructuredTool | None = None
|
||||
if original_tool is not None:
|
||||
for structured in self.tools or []:
|
||||
if getattr(structured, "_original_tool", None) is original_tool:
|
||||
structured_tool = structured
|
||||
break
|
||||
if structured_tool is None:
|
||||
for structured in self.tools or []:
|
||||
if sanitize_tool_name(structured.name) == func_name:
|
||||
structured_tool = structured
|
||||
break
|
||||
|
||||
hook_blocked = False
|
||||
before_hook_context = ToolCallHookContext(
|
||||
tool_name=func_name,
|
||||
@@ -975,11 +976,18 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
|
||||
if hook_blocked:
|
||||
result = f"Tool execution blocked by hook. Tool: {func_name}"
|
||||
raw_tool_result = result
|
||||
elif max_usage_reached and original_tool:
|
||||
result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore."
|
||||
elif not from_cache and func_name in available_functions:
|
||||
raw_tool_result = result
|
||||
elif (
|
||||
not from_cache
|
||||
and func_name in available_functions
|
||||
and output_tool is not None
|
||||
):
|
||||
try:
|
||||
raw_result = available_functions[func_name](**(args_dict or {}))
|
||||
raw_tool_result = raw_result
|
||||
|
||||
if self.tools_handler and self.tools_handler.cache:
|
||||
should_cache = True
|
||||
@@ -996,11 +1004,10 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
tool=func_name, input=input_str, output=raw_result
|
||||
)
|
||||
|
||||
result = (
|
||||
str(raw_result) if not isinstance(raw_result, str) else raw_result
|
||||
)
|
||||
result = format_native_tool_output_for_agent(output_tool, raw_result)
|
||||
except Exception as e:
|
||||
result = f"Error executing tool: {e}"
|
||||
raw_tool_result = result
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
crewai_event_bus.emit(
|
||||
@@ -1024,6 +1031,7 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
task=self.task,
|
||||
crew=self.crew,
|
||||
tool_result=result,
|
||||
raw_tool_result=raw_tool_result,
|
||||
)
|
||||
after_hooks = get_after_tool_call_hooks()
|
||||
try:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -25,14 +26,14 @@ class ToolsHandler(BaseModel):
|
||||
def on_tool_use(
|
||||
self,
|
||||
calling: ToolCalling | InstructorToolCalling,
|
||||
output: str,
|
||||
output: Any,
|
||||
should_cache: bool = True,
|
||||
) -> None:
|
||||
"""Run when tool ends running.
|
||||
|
||||
Args:
|
||||
calling: The tool calling instance.
|
||||
output: The output from the tool execution.
|
||||
output: The raw output from the tool execution.
|
||||
should_cache: Whether to cache the tool output.
|
||||
"""
|
||||
self.last_used_tool = calling
|
||||
|
||||
@@ -80,6 +80,7 @@ from crewai.utilities.agent_utils import (
|
||||
enforce_rpm_limit,
|
||||
extract_tool_call_info,
|
||||
format_message_for_llm,
|
||||
format_native_tool_output_for_agent,
|
||||
get_llm_response,
|
||||
handle_agent_action_core,
|
||||
handle_context_length,
|
||||
@@ -1905,19 +1906,32 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
):
|
||||
max_usage_reached = True
|
||||
|
||||
structured_tool: CrewStructuredTool | None = None
|
||||
if original_tool is not None:
|
||||
for structured in self.tools or []:
|
||||
if getattr(structured, "_original_tool", None) is original_tool:
|
||||
structured_tool = structured
|
||||
break
|
||||
if structured_tool is None:
|
||||
for structured in self.tools or []:
|
||||
if sanitize_tool_name(structured.name) == func_name:
|
||||
structured_tool = structured
|
||||
break
|
||||
|
||||
output_tool = original_tool or structured_tool
|
||||
|
||||
# Check cache before executing
|
||||
from_cache = False
|
||||
result = "Tool not found"
|
||||
raw_tool_result: Any = result
|
||||
input_str = json.dumps(args_dict) if args_dict else ""
|
||||
if self.tools_handler and self.tools_handler.cache:
|
||||
if self.tools_handler and self.tools_handler.cache and output_tool is not None:
|
||||
cached_result = self.tools_handler.cache.read(
|
||||
tool=func_name, input=input_str
|
||||
)
|
||||
if cached_result is not None:
|
||||
result = (
|
||||
str(cached_result)
|
||||
if not isinstance(cached_result, str)
|
||||
else cached_result
|
||||
)
|
||||
raw_tool_result = cached_result
|
||||
result = format_native_tool_output_for_agent(output_tool, cached_result)
|
||||
from_cache = True
|
||||
|
||||
# Emit tool usage started event
|
||||
@@ -1936,18 +1950,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
|
||||
track_delegation_if_needed(func_name, args_dict, self.task)
|
||||
|
||||
structured_tool: CrewStructuredTool | None = None
|
||||
if original_tool is not None:
|
||||
for structured in self.tools or []:
|
||||
if getattr(structured, "_original_tool", None) is original_tool:
|
||||
structured_tool = structured
|
||||
break
|
||||
if structured_tool is None:
|
||||
for structured in self.tools or []:
|
||||
if sanitize_tool_name(structured.name) == func_name:
|
||||
structured_tool = structured
|
||||
break
|
||||
|
||||
hook_blocked = False
|
||||
before_hook_context = ToolCallHookContext(
|
||||
tool_name=func_name,
|
||||
@@ -1973,12 +1975,13 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
|
||||
if hook_blocked:
|
||||
result = f"Tool execution blocked by hook. Tool: {func_name}"
|
||||
elif not from_cache and not max_usage_reached:
|
||||
result = "Tool not found"
|
||||
raw_tool_result = result
|
||||
elif not from_cache and not max_usage_reached and output_tool is not None:
|
||||
if func_name in self._available_functions:
|
||||
try:
|
||||
tool_func = self._available_functions[func_name]
|
||||
raw_result = tool_func(**args_dict)
|
||||
raw_tool_result = raw_result
|
||||
|
||||
# Add to cache after successful execution (before string conversion)
|
||||
if self.tools_handler and self.tools_handler.cache:
|
||||
@@ -1992,14 +1995,12 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
tool=func_name, input=input_str, output=raw_result
|
||||
)
|
||||
|
||||
# Convert to string for message
|
||||
result = (
|
||||
str(raw_result)
|
||||
if not isinstance(raw_result, str)
|
||||
else raw_result
|
||||
result = format_native_tool_output_for_agent(
|
||||
output_tool, raw_result
|
||||
)
|
||||
except Exception as e:
|
||||
result = f"Error executing tool: {e}"
|
||||
raw_tool_result = result
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
# Emit tool usage error event
|
||||
@@ -2021,6 +2022,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
result = f"Tool '{func_name}' has reached its usage limit of {original_tool.max_usage_count} times and cannot be used anymore."
|
||||
else:
|
||||
result = f"Tool '{func_name}' has reached its maximum usage limit and cannot be used anymore."
|
||||
raw_tool_result = result
|
||||
|
||||
# Execute after_tool_call hooks (even if blocked, to allow logging/monitoring)
|
||||
after_hook_context = ToolCallHookContext(
|
||||
@@ -2031,6 +2033,7 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
task=self.task,
|
||||
crew=self.crew,
|
||||
tool_result=result,
|
||||
raw_tool_result=raw_tool_result,
|
||||
)
|
||||
after_hooks = get_after_tool_call_hooks()
|
||||
try:
|
||||
|
||||
@@ -40,6 +40,8 @@ class ToolCallHookContext:
|
||||
crew: Crew instance (may be None)
|
||||
tool_result: Tool execution result (only set for after_tool_call hooks).
|
||||
Can be modified by returning a new string from after_tool_call hook.
|
||||
raw_tool_result: Raw Python tool execution result (only set for
|
||||
after_tool_call hooks). This is not modified by after hooks.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -51,6 +53,7 @@ class ToolCallHookContext:
|
||||
task: Task | None = None,
|
||||
crew: Crew | None = None,
|
||||
tool_result: str | None = None,
|
||||
raw_tool_result: Any | None = None,
|
||||
) -> None:
|
||||
"""Initialize tool call hook context.
|
||||
|
||||
@@ -62,6 +65,7 @@ class ToolCallHookContext:
|
||||
task: Optional current task
|
||||
crew: Optional crew instance
|
||||
tool_result: Optional tool result (for after hooks)
|
||||
raw_tool_result: Optional raw tool result (for after hooks)
|
||||
"""
|
||||
self.tool_name = tool_name
|
||||
self.tool_input = tool_input
|
||||
@@ -70,6 +74,7 @@ class ToolCallHookContext:
|
||||
self.task = task
|
||||
self.crew = crew
|
||||
self.tool_result = tool_result
|
||||
self.raw_tool_result = raw_tool_result
|
||||
|
||||
def request_human_input(
|
||||
self,
|
||||
|
||||
@@ -33,6 +33,8 @@ from typing_extensions import TypeIs
|
||||
from crewai.tools.structured_tool import (
|
||||
CrewStructuredTool,
|
||||
_deserialize_schema,
|
||||
_format_tool_output_for_agent,
|
||||
_infer_output_schema_from_callable,
|
||||
_serialize_schema,
|
||||
build_schema_hint,
|
||||
)
|
||||
@@ -149,6 +151,11 @@ class BaseTool(BaseModel, ABC):
|
||||
validate_default=True,
|
||||
description="The schema for the arguments that the tool accepts.",
|
||||
)
|
||||
output_schema: type[PydanticBaseModel] | None = Field(
|
||||
default=None,
|
||||
validate_default=True,
|
||||
description="The schema for the output that the tool returns.",
|
||||
)
|
||||
|
||||
@field_serializer("args_schema", when_used="json")
|
||||
def _serialize_args_schema(
|
||||
@@ -156,6 +163,12 @@ class BaseTool(BaseModel, ABC):
|
||||
) -> dict[str, Any] | None:
|
||||
return _serialize_schema(schema)
|
||||
|
||||
@field_serializer("output_schema", when_used="json")
|
||||
def _serialize_output_schema(
|
||||
self, schema: type[PydanticBaseModel] | None
|
||||
) -> dict[str, Any] | None:
|
||||
return _serialize_schema(schema)
|
||||
|
||||
description_updated: bool = Field(
|
||||
default=False, description="Flag to check if the description has been updated."
|
||||
)
|
||||
@@ -233,6 +246,17 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
return create_model(f"{cls.__name__}Schema", **fields)
|
||||
|
||||
@field_validator("output_schema", mode="before")
|
||||
@classmethod
|
||||
def _default_output_schema(
|
||||
cls, v: type[PydanticBaseModel] | dict[str, Any] | None
|
||||
) -> type[PydanticBaseModel] | None:
|
||||
if isinstance(v, dict):
|
||||
return _deserialize_schema(v)
|
||||
if v is not None:
|
||||
return v
|
||||
return _infer_output_schema_from_callable(cls._run)
|
||||
|
||||
@field_validator("max_usage_count", mode="before")
|
||||
@classmethod
|
||||
def validate_max_usage_count(cls, v: int | None) -> int | None:
|
||||
@@ -340,6 +364,10 @@ class BaseTool(BaseModel, ABC):
|
||||
"Override _arun for async support or use run() for sync execution."
|
||||
)
|
||||
|
||||
def format_output_for_agent(self, raw_result: Any) -> str:
|
||||
"""Format a raw tool result into the string representation sent to an agent."""
|
||||
return _format_tool_output_for_agent(self, raw_result)
|
||||
|
||||
def reset_usage_count(self) -> None:
|
||||
"""Reset the current usage count to zero."""
|
||||
self.current_usage_count = 0
|
||||
@@ -369,6 +397,7 @@ class BaseTool(BaseModel, ABC):
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
args_schema=self.args_schema,
|
||||
output_schema=self.output_schema,
|
||||
func=self._run,
|
||||
result_as_answer=self.result_as_answer,
|
||||
max_usage_count=self.max_usage_count,
|
||||
@@ -390,6 +419,9 @@ class BaseTool(BaseModel, ABC):
|
||||
raise ValueError("The provided tool must have a callable 'func' attribute.")
|
||||
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
output_schema = getattr(tool, "output_schema", None)
|
||||
if output_schema is None:
|
||||
output_schema = _infer_output_schema_from_callable(tool.func)
|
||||
|
||||
if args_schema is None:
|
||||
func_signature = signature(tool.func)
|
||||
@@ -420,6 +452,7 @@ class BaseTool(BaseModel, ABC):
|
||||
description=getattr(tool, "description", ""),
|
||||
func=tool.func,
|
||||
args_schema=args_schema,
|
||||
output_schema=output_schema,
|
||||
)
|
||||
|
||||
def _set_args_schema(self) -> None:
|
||||
@@ -568,6 +601,9 @@ class Tool(BaseTool, Generic[P, R]):
|
||||
raise ValueError("The provided tool must have a callable 'func' attribute.")
|
||||
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
output_schema = getattr(tool, "output_schema", None)
|
||||
if output_schema is None:
|
||||
output_schema = _infer_output_schema_from_callable(tool.func)
|
||||
|
||||
if args_schema is None:
|
||||
func_signature = signature(tool.func)
|
||||
@@ -598,6 +634,7 @@ class Tool(BaseTool, Generic[P, R]):
|
||||
description=getattr(tool, "description", ""),
|
||||
func=tool.func,
|
||||
args_schema=args_schema,
|
||||
output_schema=output_schema,
|
||||
)
|
||||
|
||||
|
||||
@@ -621,6 +658,7 @@ def tool(
|
||||
name: str,
|
||||
/,
|
||||
*,
|
||||
output_schema: type[BaseModel] | None = ...,
|
||||
result_as_answer: bool = ...,
|
||||
max_usage_count: int | None = ...,
|
||||
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
|
||||
@@ -629,6 +667,7 @@ def tool(
|
||||
@overload
|
||||
def tool(
|
||||
*,
|
||||
output_schema: type[BaseModel] | None = ...,
|
||||
result_as_answer: bool = ...,
|
||||
max_usage_count: int | None = ...,
|
||||
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
|
||||
@@ -636,6 +675,7 @@ def tool(
|
||||
|
||||
def tool(
|
||||
*args: Callable[P2, R2] | str,
|
||||
output_schema: type[BaseModel] | None = None,
|
||||
result_as_answer: bool = False,
|
||||
max_usage_count: int | None = None,
|
||||
) -> Tool[P2, R2] | Callable[[Callable[P2, R2]], Tool[P2, R2]]:
|
||||
@@ -649,6 +689,7 @@ def tool(
|
||||
Args:
|
||||
*args: Either the function to decorate or a custom tool name.
|
||||
result_as_answer: If True, the tool result becomes the final agent answer.
|
||||
output_schema: Optional schema for the output that the tool returns.
|
||||
max_usage_count: Maximum times this tool can be used. None means unlimited.
|
||||
|
||||
Returns:
|
||||
@@ -690,12 +731,16 @@ def tool(
|
||||
|
||||
class_name = "".join(tool_name.split()).title()
|
||||
args_schema = create_model(class_name, **fields)
|
||||
resolved_output_schema = (
|
||||
output_schema or _infer_output_schema_from_callable(f)
|
||||
)
|
||||
|
||||
return Tool(
|
||||
name=tool_name,
|
||||
description=f.__doc__,
|
||||
func=f,
|
||||
args_schema=args_schema,
|
||||
output_schema=resolved_output_schema,
|
||||
result_as_answer=result_as_answer,
|
||||
max_usage_count=max_usage_count,
|
||||
current_usage_count=0,
|
||||
|
||||
@@ -5,7 +5,8 @@ from collections.abc import Callable
|
||||
import inspect
|
||||
import json
|
||||
import textwrap
|
||||
from typing import TYPE_CHECKING, Annotated, Any, get_type_hints
|
||||
from typing import TYPE_CHECKING, Annotated, Any, cast, get_type_hints
|
||||
import warnings
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -36,6 +37,52 @@ def _deserialize_schema(v: Any) -> type[BaseModel] | None:
|
||||
return None
|
||||
|
||||
|
||||
def _infer_output_schema_from_callable(
|
||||
func: Callable[..., Any],
|
||||
) -> type[BaseModel] | None:
|
||||
try:
|
||||
return_annotation = get_type_hints(func).get("return", inspect.Signature.empty)
|
||||
except Exception:
|
||||
return_annotation = inspect.signature(func).return_annotation
|
||||
|
||||
if isinstance(return_annotation, type) and issubclass(return_annotation, BaseModel):
|
||||
return return_annotation
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _format_tool_output_for_agent(tool: Any, raw_result: Any) -> str:
|
||||
original_tool = getattr(tool, "_original_tool", None)
|
||||
if original_tool is not None:
|
||||
return cast(str, original_tool.format_output_for_agent(raw_result))
|
||||
|
||||
output_schema = getattr(tool, "output_schema", None)
|
||||
if output_schema is None:
|
||||
return str(raw_result)
|
||||
|
||||
try:
|
||||
validation_input = raw_result
|
||||
if isinstance(raw_result, BaseModel) and not isinstance(
|
||||
raw_result, output_schema
|
||||
):
|
||||
validation_input = raw_result.model_dump()
|
||||
|
||||
validated = output_schema.model_validate(validation_input)
|
||||
return cast(str, validated.model_dump_json())
|
||||
except Exception as exc:
|
||||
warnings.warn(
|
||||
(
|
||||
f"Failed to validate or serialize output from tool "
|
||||
f"'{getattr(tool, 'name', '<unknown>')}' using output_schema "
|
||||
f"'{output_schema.__name__}': {exc.__class__.__name__}. "
|
||||
"Falling back to str(raw_result)."
|
||||
),
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return str(raw_result)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
@@ -81,6 +128,11 @@ class CrewStructuredTool(BaseModel):
|
||||
BeforeValidator(_deserialize_schema),
|
||||
PlainSerializer(_serialize_schema),
|
||||
] = Field(default=None)
|
||||
output_schema: Annotated[
|
||||
type[BaseModel] | None,
|
||||
BeforeValidator(_deserialize_schema),
|
||||
PlainSerializer(_serialize_schema),
|
||||
] = Field(default=None)
|
||||
func: Any = Field(default=None, exclude=True)
|
||||
result_as_answer: bool = Field(default=False)
|
||||
max_usage_count: int | None = Field(default=None)
|
||||
@@ -103,6 +155,7 @@ class CrewStructuredTool(BaseModel):
|
||||
description: str | None = None,
|
||||
return_direct: bool = False,
|
||||
args_schema: type[BaseModel] | None = None,
|
||||
output_schema: type[BaseModel] | None = None,
|
||||
infer_schema: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> CrewStructuredTool:
|
||||
@@ -114,6 +167,7 @@ class CrewStructuredTool(BaseModel):
|
||||
description: The description of the tool. Defaults to the function docstring
|
||||
return_direct: Whether to return the output directly
|
||||
args_schema: Optional schema for the function arguments
|
||||
output_schema: Optional schema for the function output
|
||||
infer_schema: Whether to infer the schema from the function signature
|
||||
**kwargs: Additional arguments to pass to the tool
|
||||
|
||||
@@ -149,10 +203,16 @@ class CrewStructuredTool(BaseModel):
|
||||
name=name,
|
||||
description=description,
|
||||
args_schema=schema,
|
||||
output_schema=output_schema or _infer_output_schema_from_callable(func),
|
||||
func=func,
|
||||
result_as_answer=return_direct,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def format_output_for_agent(self, raw_result: Any) -> str:
|
||||
"""Format a raw tool result into the string representation sent to an agent."""
|
||||
return _format_tool_output_for_agent(self, raw_result)
|
||||
|
||||
@staticmethod
|
||||
def _create_schema_from_function(
|
||||
name: str,
|
||||
|
||||
@@ -106,6 +106,7 @@ class ToolUsage:
|
||||
self.action = action
|
||||
self.function_calling_llm = function_calling_llm
|
||||
self.fingerprint_context = fingerprint_context or {}
|
||||
self.last_raw_result: Any = None
|
||||
|
||||
if (
|
||||
self.function_calling_llm
|
||||
@@ -231,6 +232,7 @@ class ToolUsage:
|
||||
result = I18N_DEFAULT.errors("task_repeated_usage").format(
|
||||
tool_names=self.tools_names
|
||||
)
|
||||
self.last_raw_result = result
|
||||
self._telemetry.tool_repeated_usage(
|
||||
llm=self.function_calling_llm,
|
||||
tool_name=sanitize_tool_name(tool.name),
|
||||
@@ -298,6 +300,7 @@ class ToolUsage:
|
||||
)
|
||||
if usage_limit_error:
|
||||
result = usage_limit_error
|
||||
self.last_raw_result = result
|
||||
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||
result = self._format_result(result=result)
|
||||
elif result is None:
|
||||
@@ -359,7 +362,10 @@ class ToolUsage:
|
||||
tool_name=sanitize_tool_name(tool.name),
|
||||
attempts=self._run_attempts,
|
||||
)
|
||||
result = self._format_result(result=result)
|
||||
self.last_raw_result = result
|
||||
result = self._format_result(
|
||||
result=tool.format_output_for_agent(result)
|
||||
)
|
||||
data = {
|
||||
"result": result,
|
||||
"tool_name": sanitize_tool_name(tool.name),
|
||||
@@ -421,6 +427,7 @@ class ToolUsage:
|
||||
result = ToolUsageError(
|
||||
f"\n{error_message}.\nMoving on then. {I18N_DEFAULT.slice('format').format(tool_names=self.tools_names)}"
|
||||
).message
|
||||
self.last_raw_result = result
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
if self.agent and self.agent.verbose:
|
||||
@@ -430,7 +437,10 @@ class ToolUsage:
|
||||
self.task.increment_tools_errors()
|
||||
should_retry = True
|
||||
else:
|
||||
result = self._format_result(result=result)
|
||||
self.last_raw_result = result
|
||||
result = self._format_result(
|
||||
result=tool.format_output_for_agent(result)
|
||||
)
|
||||
|
||||
finally:
|
||||
if started_event_emitted and not error_event_emitted:
|
||||
@@ -460,6 +470,7 @@ class ToolUsage:
|
||||
result = I18N_DEFAULT.errors("task_repeated_usage").format(
|
||||
tool_names=self.tools_names
|
||||
)
|
||||
self.last_raw_result = result
|
||||
self._telemetry.tool_repeated_usage(
|
||||
llm=self.function_calling_llm,
|
||||
tool_name=sanitize_tool_name(tool.name),
|
||||
@@ -529,6 +540,7 @@ class ToolUsage:
|
||||
)
|
||||
if usage_limit_error:
|
||||
result = usage_limit_error
|
||||
self.last_raw_result = result
|
||||
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||
result = self._format_result(result=result)
|
||||
elif result is None:
|
||||
@@ -590,7 +602,10 @@ class ToolUsage:
|
||||
tool_name=sanitize_tool_name(tool.name),
|
||||
attempts=self._run_attempts,
|
||||
)
|
||||
result = self._format_result(result=result)
|
||||
self.last_raw_result = result
|
||||
result = self._format_result(
|
||||
result=tool.format_output_for_agent(result)
|
||||
)
|
||||
data = {
|
||||
"result": result,
|
||||
"tool_name": sanitize_tool_name(tool.name),
|
||||
@@ -652,6 +667,7 @@ class ToolUsage:
|
||||
result = ToolUsageError(
|
||||
f"\n{error_message}.\nMoving on then. {I18N_DEFAULT.slice('format').format(tool_names=self.tools_names)}"
|
||||
).message
|
||||
self.last_raw_result = result
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
if self.agent and self.agent.verbose:
|
||||
@@ -661,7 +677,10 @@ class ToolUsage:
|
||||
self.task.increment_tools_errors()
|
||||
should_retry = True
|
||||
else:
|
||||
result = self._format_result(result=result)
|
||||
self.last_raw_result = result
|
||||
result = self._format_result(
|
||||
result=tool.format_output_for_agent(result)
|
||||
)
|
||||
|
||||
finally:
|
||||
if started_event_emitted and not error_event_emitted:
|
||||
|
||||
@@ -1383,6 +1383,19 @@ class NativeToolCallResult:
|
||||
tool_message: LLMMessage = field(default_factory=dict) # type: ignore[assignment]
|
||||
|
||||
|
||||
def format_native_tool_output_for_agent(tool: Any, raw_result: Any) -> str:
|
||||
"""Format native tool output when a tool explicitly defines a formatter."""
|
||||
formatter = inspect.getattr_static(tool, "format_output_for_agent", None)
|
||||
if formatter is None:
|
||||
return str(raw_result)
|
||||
|
||||
runtime_formatter = getattr(tool, "format_output_for_agent", None)
|
||||
if not callable(runtime_formatter):
|
||||
return str(raw_result)
|
||||
|
||||
return str(runtime_formatter(raw_result))
|
||||
|
||||
|
||||
def execute_single_native_tool_call(
|
||||
tool_call: Any,
|
||||
*,
|
||||
@@ -1456,18 +1469,24 @@ def execute_single_native_tool_call(
|
||||
original_tool = tool
|
||||
break
|
||||
|
||||
structured_tool: CrewStructuredTool | None = None
|
||||
for structured in structured_tools or []:
|
||||
if sanitize_tool_name(structured.name) == func_name:
|
||||
structured_tool = structured
|
||||
break
|
||||
|
||||
output_tool = original_tool or structured_tool
|
||||
|
||||
from_cache = False
|
||||
input_str = json.dumps(args_dict) if args_dict else ""
|
||||
result = "Tool not found"
|
||||
raw_tool_result: Any = result
|
||||
|
||||
if tools_handler and tools_handler.cache:
|
||||
if tools_handler and tools_handler.cache and output_tool is not None:
|
||||
cached_result = tools_handler.cache.read(tool=func_name, input=input_str)
|
||||
if cached_result is not None:
|
||||
result = (
|
||||
str(cached_result)
|
||||
if not isinstance(cached_result, str)
|
||||
else cached_result
|
||||
)
|
||||
raw_tool_result = cached_result
|
||||
result = format_native_tool_output_for_agent(output_tool, cached_result)
|
||||
from_cache = True
|
||||
|
||||
started_at = datetime.now()
|
||||
@@ -1486,12 +1505,6 @@ def execute_single_native_tool_call(
|
||||
|
||||
track_delegation_if_needed(func_name, args_dict, task)
|
||||
|
||||
structured_tool: CrewStructuredTool | None = None
|
||||
for structured in structured_tools or []:
|
||||
if sanitize_tool_name(structured.name) == func_name:
|
||||
structured_tool = structured
|
||||
break
|
||||
|
||||
hook_blocked = False
|
||||
before_hook_context = ToolCallHookContext(
|
||||
tool_name=func_name,
|
||||
@@ -1512,11 +1525,13 @@ def execute_single_native_tool_call(
|
||||
error_event_emitted = False
|
||||
if hook_blocked:
|
||||
result = f"Tool execution blocked by hook. Tool: {func_name}"
|
||||
raw_tool_result = result
|
||||
elif not from_cache:
|
||||
if func_name in available_functions:
|
||||
if func_name in available_functions and output_tool is not None:
|
||||
try:
|
||||
tool_func = available_functions[func_name]
|
||||
raw_result = tool_func(**args_dict)
|
||||
raw_tool_result = raw_result
|
||||
|
||||
if tools_handler and tools_handler.cache:
|
||||
should_cache = True
|
||||
@@ -1529,11 +1544,10 @@ def execute_single_native_tool_call(
|
||||
tool=func_name, input=input_str, output=raw_result
|
||||
)
|
||||
|
||||
result = (
|
||||
str(raw_result) if not isinstance(raw_result, str) else raw_result
|
||||
)
|
||||
result = format_native_tool_output_for_agent(output_tool, raw_result)
|
||||
except Exception as e:
|
||||
result = f"Error executing tool: {e}"
|
||||
raw_tool_result = result
|
||||
if task:
|
||||
task.increment_tools_errors()
|
||||
crewai_event_bus.emit(
|
||||
@@ -1559,6 +1573,7 @@ def execute_single_native_tool_call(
|
||||
task=task,
|
||||
crew=crew,
|
||||
tool_result=result,
|
||||
raw_tool_result=raw_tool_result,
|
||||
)
|
||||
try:
|
||||
for after_hook in get_after_tool_call_hooks():
|
||||
|
||||
@@ -116,6 +116,7 @@ async def aexecute_tool_and_check_finality(
|
||||
logger.log("error", f"Error in before_tool_call hook: {e}")
|
||||
|
||||
tool_result = await tool_usage.ause(tool_calling, agent_action.text)
|
||||
raw_tool_result = getattr(tool_usage, "last_raw_result", tool_result)
|
||||
|
||||
after_hook_context = ToolCallHookContext(
|
||||
tool_name=sanitized_tool_name,
|
||||
@@ -125,6 +126,7 @@ async def aexecute_tool_and_check_finality(
|
||||
task=task,
|
||||
crew=crew,
|
||||
tool_result=tool_result,
|
||||
raw_tool_result=raw_tool_result,
|
||||
)
|
||||
|
||||
after_hooks = get_after_tool_call_hooks()
|
||||
@@ -234,6 +236,7 @@ def execute_tool_and_check_finality(
|
||||
logger.log("error", f"Error in before_tool_call hook: {e}")
|
||||
|
||||
tool_result = tool_usage.use(tool_calling, agent_action.text)
|
||||
raw_tool_result = getattr(tool_usage, "last_raw_result", tool_result)
|
||||
|
||||
after_hook_context = ToolCallHookContext(
|
||||
tool_name=sanitized_tool_name,
|
||||
@@ -243,6 +246,7 @@ def execute_tool_and_check_finality(
|
||||
task=task,
|
||||
crew=crew,
|
||||
tool_result=tool_result,
|
||||
raw_tool_result=raw_tool_result,
|
||||
)
|
||||
|
||||
after_hooks = get_after_tool_call_hooks()
|
||||
|
||||
@@ -7,6 +7,7 @@ when the LLM supports it, across multiple providers.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
@@ -20,7 +21,7 @@ from crewai import Agent, Crew, Task
|
||||
from crewai.agents.parser import AgentFinish
|
||||
from crewai.events import crewai_event_bus
|
||||
from crewai.hooks import register_after_tool_call_hook, register_before_tool_call_hook
|
||||
from crewai.hooks.tool_hooks import ToolCallHookContext
|
||||
from crewai.hooks.tool_hooks import ToolCallHookContext, clear_after_tool_call_hooks
|
||||
from crewai.llm import LLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
@@ -1197,6 +1198,76 @@ class TestNativeToolCallingJsonParseError:
|
||||
|
||||
assert result["result"] == "ran: print(1)"
|
||||
|
||||
def test_typed_output_is_json_agent_text(self) -> None:
|
||||
class SearchOutput(BaseModel):
|
||||
query: str
|
||||
score: float
|
||||
|
||||
class TypedSearchTool(BaseTool):
|
||||
name: str = "typed_search"
|
||||
description: str = "Search for information"
|
||||
output_schema: type[BaseModel] = SearchOutput
|
||||
|
||||
def _run(self, query: str) -> SearchOutput:
|
||||
return SearchOutput(query=query, score=0.8)
|
||||
|
||||
tool = TypedSearchTool()
|
||||
executor = self._make_executor([tool])
|
||||
|
||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema
|
||||
|
||||
_, available_functions, _ = convert_tools_to_openai_schema([tool])
|
||||
|
||||
result = executor._execute_single_native_tool_call(
|
||||
call_id="call_typed",
|
||||
func_name="typed_search",
|
||||
func_args='{"query": "crew"}',
|
||||
available_functions=available_functions,
|
||||
)
|
||||
|
||||
assert json.loads(result["result"]) == {"query": "crew", "score": 0.8}
|
||||
|
||||
def test_typed_output_after_hook_includes_raw_tool_result(self) -> None:
|
||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema
|
||||
|
||||
class SearchOutput(BaseModel):
|
||||
query: str
|
||||
score: float
|
||||
|
||||
class TypedSearchTool(BaseTool):
|
||||
name: str = "typed_search"
|
||||
description: str = "Search for information"
|
||||
output_schema: type[BaseModel] = SearchOutput
|
||||
|
||||
def _run(self, query: str) -> SearchOutput:
|
||||
return SearchOutput(query=query, score=0.8)
|
||||
|
||||
seen_results: list[tuple[str | None, object]] = []
|
||||
|
||||
def after_hook(context: ToolCallHookContext) -> None:
|
||||
seen_results.append((context.tool_result, context.raw_tool_result))
|
||||
|
||||
tool = TypedSearchTool()
|
||||
executor = self._make_executor([tool])
|
||||
_, available_functions, _ = convert_tools_to_openai_schema([tool])
|
||||
|
||||
clear_after_tool_call_hooks()
|
||||
register_after_tool_call_hook(after_hook)
|
||||
try:
|
||||
result = executor._execute_single_native_tool_call(
|
||||
call_id="call_typed",
|
||||
func_name="typed_search",
|
||||
func_args='{"query": "crew"}',
|
||||
available_functions=available_functions,
|
||||
)
|
||||
finally:
|
||||
clear_after_tool_call_hooks()
|
||||
|
||||
assert json.loads(result["result"]) == {"query": "crew", "score": 0.8}
|
||||
assert seen_results == [
|
||||
('{"query":"crew","score":0.8}', SearchOutput(query="crew", score=0.8))
|
||||
]
|
||||
|
||||
def test_native_tool_loop_falls_back_when_provider_rejects_tools(self) -> None:
|
||||
"""Unsupported native tools errors should continue through ReAct."""
|
||||
|
||||
|
||||
@@ -91,20 +91,24 @@ class TestToolCallHookContext:
|
||||
assert context.task == mock_task
|
||||
assert context.crew == mock_crew
|
||||
assert context.tool_result is None
|
||||
assert context.raw_tool_result is None
|
||||
|
||||
def test_context_with_result(self, mock_tool):
|
||||
"""Test that context includes result when provided."""
|
||||
tool_input = {"arg1": "value1"}
|
||||
tool_result = "Test tool result"
|
||||
raw_tool_result = {"value": 42}
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
tool_result=tool_result,
|
||||
raw_tool_result=raw_tool_result,
|
||||
)
|
||||
|
||||
assert context.tool_result == tool_result
|
||||
assert context.raw_tool_result == raw_tool_result
|
||||
|
||||
def test_tool_input_is_mutable_reference(self, mock_tool):
|
||||
"""Test that modifying context.tool_input modifies the original dict."""
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
from crewai.tools import BaseTool, tool
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -351,6 +352,262 @@ class TestToolDecoratorRunValidation:
|
||||
assert result == "Hello, World!"
|
||||
|
||||
|
||||
class SearchOutput(BaseModel):
|
||||
query: str
|
||||
score: float
|
||||
|
||||
|
||||
class SearchResults(RootModel[list[SearchOutput]]):
|
||||
pass
|
||||
|
||||
|
||||
class ExplicitSearchTool(BaseTool):
|
||||
name: str = "search"
|
||||
description: str = "Search for a query"
|
||||
output_schema: type[BaseModel] = SearchOutput
|
||||
|
||||
def _run(self, query: str) -> dict[str, object]:
|
||||
return {"query": query, "score": 0.8}
|
||||
|
||||
|
||||
class InferredSearchTool(BaseTool):
|
||||
name: str = "search"
|
||||
description: str = "Search for a query"
|
||||
|
||||
def _run(self, query: str) -> SearchOutput:
|
||||
return SearchOutput(query=query, score=0.7)
|
||||
|
||||
|
||||
class RootSearchTool(BaseTool):
|
||||
name: str = "search"
|
||||
description: str = "Search for a query"
|
||||
|
||||
def _run(self, query: str) -> SearchResults:
|
||||
return SearchResults([SearchOutput(query=query, score=1.0)])
|
||||
|
||||
|
||||
class DictAnnotatedSearchTool(BaseTool):
|
||||
name: str = "search"
|
||||
description: str = "Search for a query"
|
||||
|
||||
def _run(self, query: str) -> dict[str, object]:
|
||||
return {"query": query, "score": 0.5}
|
||||
|
||||
|
||||
def _make_explicit_decorator_tool() -> BaseTool:
|
||||
@tool("search", output_schema=SearchOutput)
|
||||
def search(query: str) -> dict[str, object]:
|
||||
"""Search for a query."""
|
||||
return {"query": query, "score": 0.8}
|
||||
|
||||
return search
|
||||
|
||||
|
||||
def _make_inferred_decorator_tool() -> BaseTool:
|
||||
@tool("search")
|
||||
def search(query: str) -> SearchOutput:
|
||||
"""Search for a query."""
|
||||
return SearchOutput(query=query, score=0.6)
|
||||
|
||||
return search
|
||||
|
||||
|
||||
def _make_root_decorator_tool() -> BaseTool:
|
||||
@tool("search")
|
||||
def search(query: str) -> SearchResults:
|
||||
"""Search for a query."""
|
||||
return SearchResults([SearchOutput(query=query, score=1.0)])
|
||||
|
||||
return search
|
||||
|
||||
|
||||
class TestToolOutputSchema:
|
||||
@pytest.mark.parametrize(
|
||||
("tool_cls", "expected_raw", "expected_agent_payload"),
|
||||
[
|
||||
pytest.param(
|
||||
ExplicitSearchTool,
|
||||
{"query": "crew", "score": 0.8},
|
||||
{"query": "crew", "score": 0.8},
|
||||
id="explicit-schema",
|
||||
),
|
||||
pytest.param(
|
||||
InferredSearchTool,
|
||||
SearchOutput(query="crew", score=0.7),
|
||||
{"query": "crew", "score": 0.7},
|
||||
id="inferred-base-model",
|
||||
),
|
||||
pytest.param(
|
||||
RootSearchTool,
|
||||
SearchResults([SearchOutput(query="crew", score=1.0)]),
|
||||
[{"query": "crew", "score": 1.0}],
|
||||
id="inferred-root-model",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_base_tools_return_raw_result_and_json_agent_text(
|
||||
self,
|
||||
tool_cls: type[BaseTool],
|
||||
expected_raw: object,
|
||||
expected_agent_payload: object,
|
||||
) -> None:
|
||||
t = tool_cls()
|
||||
|
||||
raw_result = t.run(query="crew")
|
||||
|
||||
assert raw_result == expected_raw
|
||||
assert json.loads(t.format_output_for_agent(raw_result)) == (
|
||||
expected_agent_payload
|
||||
)
|
||||
|
||||
def test_base_tool_does_not_infer_non_pydantic_return_annotation(self) -> None:
|
||||
t = DictAnnotatedSearchTool()
|
||||
|
||||
raw_result = t.run(query="crew")
|
||||
|
||||
assert raw_result == {"query": "crew", "score": 0.5}
|
||||
assert t.format_output_for_agent(raw_result) == str(raw_result)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("make_tool", "expected_raw", "expected_agent_payload"),
|
||||
[
|
||||
pytest.param(
|
||||
_make_explicit_decorator_tool,
|
||||
{"query": "crew", "score": 0.8},
|
||||
{"query": "crew", "score": 0.8},
|
||||
id="explicit-schema",
|
||||
),
|
||||
pytest.param(
|
||||
_make_inferred_decorator_tool,
|
||||
SearchOutput(query="crew", score=0.6),
|
||||
{"query": "crew", "score": 0.6},
|
||||
id="inferred-base-model",
|
||||
),
|
||||
pytest.param(
|
||||
_make_root_decorator_tool,
|
||||
SearchResults([SearchOutput(query="crew", score=1.0)]),
|
||||
[{"query": "crew", "score": 1.0}],
|
||||
id="inferred-root-model",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_decorator_tools_return_raw_result_and_json_agent_text(
|
||||
self,
|
||||
make_tool: Callable[[], BaseTool],
|
||||
expected_raw: object,
|
||||
expected_agent_payload: object,
|
||||
) -> None:
|
||||
search = make_tool()
|
||||
|
||||
raw_result = search.run(query="crew")
|
||||
|
||||
assert raw_result == expected_raw
|
||||
assert json.loads(search.format_output_for_agent(raw_result)) == (
|
||||
expected_agent_payload
|
||||
)
|
||||
|
||||
def test_decorator_tool_does_not_infer_non_pydantic_return_annotation(
|
||||
self,
|
||||
) -> None:
|
||||
@tool("search")
|
||||
def search(query: str) -> dict[str, object]:
|
||||
"""Search for a query."""
|
||||
return {"query": query, "score": 0.5}
|
||||
|
||||
raw_result = search.run(query="crew")
|
||||
|
||||
assert raw_result == {"query": "crew", "score": 0.5}
|
||||
assert search.format_output_for_agent(raw_result) == str(raw_result)
|
||||
|
||||
def test_explicit_output_schema_wins_over_return_annotation(self) -> None:
|
||||
class AlternateOutput(BaseModel):
|
||||
value: str
|
||||
|
||||
@tool("search", output_schema=AlternateOutput)
|
||||
def search(query: str) -> SearchOutput:
|
||||
"""Search for a query."""
|
||||
return SearchOutput(query=query, score=0.6)
|
||||
|
||||
raw_result = search.run(query="crew")
|
||||
|
||||
with pytest.warns(RuntimeWarning, match="AlternateOutput"):
|
||||
agent_text = search.format_output_for_agent(raw_result)
|
||||
|
||||
assert raw_result == SearchOutput(query="crew", score=0.6)
|
||||
assert agent_text == str(raw_result)
|
||||
|
||||
def test_invalid_typed_output_warns_and_uses_string_agent_text(
|
||||
self,
|
||||
) -> None:
|
||||
@tool("search", output_schema=SearchOutput)
|
||||
def search(query: str) -> dict[str, object]:
|
||||
"""Search for a query."""
|
||||
return {"query": query, "score": "not-a-float"}
|
||||
|
||||
raw_result = search.run(query="crew")
|
||||
|
||||
with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"):
|
||||
agent_text = search.format_output_for_agent(raw_result)
|
||||
|
||||
assert raw_result == {"query": "crew", "score": "not-a-float"}
|
||||
assert agent_text == str(raw_result)
|
||||
|
||||
def test_unserializable_typed_output_warns_and_uses_string_agent_text(
|
||||
self,
|
||||
) -> None:
|
||||
class OpaqueOutput(BaseModel):
|
||||
value: object
|
||||
|
||||
raw_result = OpaqueOutput(value=object())
|
||||
|
||||
@tool("opaque", output_schema=OpaqueOutput)
|
||||
def opaque() -> OpaqueOutput:
|
||||
"""Return an opaque object."""
|
||||
return raw_result
|
||||
|
||||
result = opaque.run()
|
||||
|
||||
with pytest.warns(RuntimeWarning, match="Failed to validate or serialize"):
|
||||
agent_text = opaque.format_output_for_agent(result)
|
||||
|
||||
assert result is raw_result
|
||||
assert agent_text == str(raw_result)
|
||||
|
||||
def test_output_schema_behavior_carries_over_to_structured_tool(self) -> None:
|
||||
structured = ExplicitSearchTool().to_structured_tool()
|
||||
|
||||
raw_result = structured.invoke({"query": "crew"})
|
||||
|
||||
assert raw_result == {"query": "crew", "score": 0.8}
|
||||
assert json.loads(structured.format_output_for_agent(raw_result)) == {
|
||||
"query": "crew",
|
||||
"score": 0.8,
|
||||
}
|
||||
|
||||
def test_custom_agent_output_formatter_carries_over_to_structured_tool(
|
||||
self,
|
||||
) -> None:
|
||||
class MarkdownSearchTool(BaseTool):
|
||||
name: str = "markdown_search"
|
||||
description: str = "Search for information"
|
||||
output_schema: type[BaseModel] = SearchOutput
|
||||
|
||||
def _run(self, query: str) -> SearchOutput:
|
||||
return SearchOutput(query=query, score=0.8)
|
||||
|
||||
def format_output_for_agent(self, raw_result: object) -> str:
|
||||
result = self.output_schema.model_validate(raw_result)
|
||||
return f"### Search result\n\n- Query: `{result.query}`\n- Score: {result.score}"
|
||||
|
||||
structured = MarkdownSearchTool().to_structured_tool()
|
||||
|
||||
raw_result = structured.invoke({"query": "crew"})
|
||||
|
||||
assert raw_result == SearchOutput(query="crew", score=0.8)
|
||||
assert structured.format_output_for_agent(raw_result) == (
|
||||
"### Search result\n\n- Query: `crew`\n- Score: 0.8"
|
||||
)
|
||||
|
||||
# Async arun() Schema Validation Tests
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import json
|
||||
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -86,6 +88,118 @@ def test_from_function(basic_function):
|
||||
assert isinstance(tool.args_schema, type(BaseModel))
|
||||
|
||||
|
||||
class StructuredOutput(BaseModel):
|
||||
value: str
|
||||
count: int
|
||||
|
||||
|
||||
class StructuredOutputList(RootModel[list[StructuredOutput]]):
|
||||
pass
|
||||
|
||||
|
||||
def _build_explicit_structured_value(value: str) -> dict[str, object]:
|
||||
"""Build a value."""
|
||||
return {"value": value, "count": 1}
|
||||
|
||||
|
||||
def _build_inferred_structured_value(value: str) -> StructuredOutput:
|
||||
"""Build a value."""
|
||||
return StructuredOutput(value=value, count=1)
|
||||
|
||||
|
||||
def _build_structured_values(value: str) -> StructuredOutputList:
|
||||
"""Build values."""
|
||||
return StructuredOutputList([StructuredOutput(value=value, count=1)])
|
||||
|
||||
|
||||
def _build_plain_structured_value(value: str) -> dict[str, object]:
|
||||
"""Build a value."""
|
||||
return {"value": value, "count": 1}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("func", "output_schema", "expected_raw", "expected_agent_payload"),
|
||||
[
|
||||
pytest.param(
|
||||
_build_explicit_structured_value,
|
||||
StructuredOutput,
|
||||
{"value": "crew", "count": 1},
|
||||
{"value": "crew", "count": 1},
|
||||
id="explicit-schema",
|
||||
),
|
||||
pytest.param(
|
||||
_build_inferred_structured_value,
|
||||
None,
|
||||
StructuredOutput(value="crew", count=1),
|
||||
{"value": "crew", "count": 1},
|
||||
id="inferred-base-model",
|
||||
),
|
||||
pytest.param(
|
||||
_build_structured_values,
|
||||
None,
|
||||
StructuredOutputList([StructuredOutput(value="crew", count=1)]),
|
||||
[{"value": "crew", "count": 1}],
|
||||
id="inferred-root-model",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_from_function_returns_raw_result_and_json_agent_text(
|
||||
func,
|
||||
output_schema,
|
||||
expected_raw,
|
||||
expected_agent_payload,
|
||||
):
|
||||
kwargs = {"output_schema": output_schema} if output_schema is not None else {}
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=func,
|
||||
name="build_value",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
raw_result = tool.invoke({"value": "crew"})
|
||||
|
||||
assert raw_result == expected_raw
|
||||
assert json.loads(tool.format_output_for_agent(raw_result)) == (
|
||||
expected_agent_payload
|
||||
)
|
||||
|
||||
|
||||
def test_from_function_does_not_infer_non_pydantic_output_schema():
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=_build_plain_structured_value,
|
||||
name="build_value",
|
||||
)
|
||||
|
||||
raw_result = tool.invoke({"value": "crew"})
|
||||
|
||||
assert raw_result == {"value": "crew", "count": 1}
|
||||
assert tool.format_output_for_agent(raw_result) == str(raw_result)
|
||||
|
||||
|
||||
def test_invalid_typed_output_warns_and_uses_string_agent_text():
|
||||
def build_value(value: str) -> dict[str, object]:
|
||||
"""Build a value."""
|
||||
return {"value": value, "count": "wrong"}
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=build_value,
|
||||
name="build_value",
|
||||
output_schema=StructuredOutput,
|
||||
)
|
||||
raw_result = tool.invoke({"value": "crew"})
|
||||
|
||||
with pytest.warns(
|
||||
RuntimeWarning, match="Failed to validate or serialize"
|
||||
) as warnings:
|
||||
agent_text = tool.format_output_for_agent(raw_result)
|
||||
|
||||
assert raw_result == {"value": "crew", "count": "wrong"}
|
||||
assert agent_text == str(raw_result)
|
||||
warning_message = str(warnings[0].message)
|
||||
assert "ValidationError" in warning_message
|
||||
assert "wrong" not in warning_message
|
||||
|
||||
|
||||
def test_validate_function_signature(basic_function, schema_class):
|
||||
"""Test function signature validation"""
|
||||
tool = CrewStructuredTool(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import datetime
|
||||
from collections.abc import Callable
|
||||
import json
|
||||
import random
|
||||
import threading
|
||||
@@ -6,6 +7,9 @@ import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai import Agent, Task
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.parser import AgentAction
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.tool_usage_events import (
|
||||
ToolSelectionErrorEvent,
|
||||
@@ -14,8 +18,15 @@ from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageStartedEvent,
|
||||
ToolValidateInputErrorEvent,
|
||||
)
|
||||
from crewai.hooks.tool_hooks import (
|
||||
ToolCallHookContext,
|
||||
clear_after_tool_call_hooks,
|
||||
register_after_tool_call_hook,
|
||||
)
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.tools.tool_calling import ToolCalling
|
||||
from crewai.tools.tool_usage import ToolUsage
|
||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
from pydantic import BaseModel, Field
|
||||
import pytest
|
||||
|
||||
@@ -38,6 +49,19 @@ class RandomNumberTool(BaseTool):
|
||||
return random.randint(min_value, max_value) # noqa: S311
|
||||
|
||||
|
||||
class SearchOutput(BaseModel):
|
||||
query: str
|
||||
score: float
|
||||
|
||||
|
||||
class TypedSearchTool(BaseTool):
|
||||
name: str = "typed_search"
|
||||
description: str = "Search for a query"
|
||||
|
||||
def _run(self, query: str) -> SearchOutput:
|
||||
return SearchOutput(query=query, score=0.7)
|
||||
|
||||
|
||||
# Example agent and task
|
||||
example_agent = Agent(
|
||||
role="Number Generator",
|
||||
@@ -117,6 +141,109 @@ def test_tool_usage_render():
|
||||
assert '"description": "The maximum value of the range (inclusive)"' in rendered
|
||||
|
||||
|
||||
def test_tool_usage_returns_json_agent_text_for_typed_output():
|
||||
tool = TypedSearchTool().to_structured_tool()
|
||||
tool_usage = ToolUsage(
|
||||
tools_handler=None,
|
||||
tools=[tool],
|
||||
task=None,
|
||||
function_calling_llm=MagicMock(),
|
||||
agent=None,
|
||||
action=MagicMock(),
|
||||
)
|
||||
|
||||
result = tool_usage.use(
|
||||
calling=ToolCalling(
|
||||
tool_name="typed_search",
|
||||
arguments={"query": "crew"},
|
||||
),
|
||||
tool_string='Action: typed_search\nAction Input: {"query": "crew"}',
|
||||
)
|
||||
|
||||
assert json.loads(result) == {"query": "crew", "score": 0.7}
|
||||
|
||||
|
||||
def test_tool_usage_cache_callback_receives_raw_typed_output():
|
||||
raw_results: list[object] = []
|
||||
|
||||
def cache_result(_args: object, result: object) -> bool:
|
||||
raw_results.append(result)
|
||||
return True
|
||||
|
||||
class CacheAwareTypedSearchTool(TypedSearchTool):
|
||||
cache_function: Callable = cache_result
|
||||
|
||||
tools_handler = MagicMock()
|
||||
tools_handler.cache = None
|
||||
tools_handler.last_used_tool = None
|
||||
tool = CacheAwareTypedSearchTool().to_structured_tool()
|
||||
tool_usage = ToolUsage(
|
||||
tools_handler=tools_handler,
|
||||
tools=[tool],
|
||||
task=None,
|
||||
function_calling_llm=MagicMock(),
|
||||
agent=None,
|
||||
action=MagicMock(),
|
||||
)
|
||||
|
||||
result = tool_usage.use(
|
||||
calling=ToolCalling(
|
||||
tool_name="typed_search",
|
||||
arguments={"query": "crew"},
|
||||
),
|
||||
tool_string='Action: typed_search\nAction Input: {"query": "crew"}',
|
||||
)
|
||||
|
||||
assert json.loads(result) == {"query": "crew", "score": 0.7}
|
||||
assert raw_results == [SearchOutput(query="crew", score=0.7)]
|
||||
tools_handler.on_tool_use.assert_called_once()
|
||||
assert tools_handler.on_tool_use.call_args.kwargs["output"] == SearchOutput(
|
||||
query="crew",
|
||||
score=0.7,
|
||||
)
|
||||
|
||||
|
||||
def test_react_tool_hooks_receive_agent_text_and_raw_cached_typed_output():
|
||||
structured_tool = TypedSearchTool().to_structured_tool()
|
||||
tools_handler = ToolsHandler(cache=CacheHandler())
|
||||
seen_results: list[tuple[str | None, object]] = []
|
||||
|
||||
def after_hook(context: ToolCallHookContext) -> None:
|
||||
seen_results.append((context.tool_result, context.raw_tool_result))
|
||||
|
||||
clear_after_tool_call_hooks()
|
||||
register_after_tool_call_hook(after_hook)
|
||||
|
||||
action = AgentAction(
|
||||
thought="",
|
||||
tool="typed_search",
|
||||
tool_input='{"query": "crew"}',
|
||||
text='Action: typed_search\nAction Input: {"query": "crew"}',
|
||||
)
|
||||
|
||||
try:
|
||||
first = execute_tool_and_check_finality(
|
||||
agent_action=action,
|
||||
tools=[structured_tool],
|
||||
tools_handler=tools_handler,
|
||||
)
|
||||
tools_handler.last_used_tool = None
|
||||
second = execute_tool_and_check_finality(
|
||||
agent_action=action,
|
||||
tools=[structured_tool],
|
||||
tools_handler=tools_handler,
|
||||
)
|
||||
finally:
|
||||
clear_after_tool_call_hooks()
|
||||
|
||||
assert json.loads(first.result) == {"query": "crew", "score": 0.7}
|
||||
assert json.loads(second.result) == {"query": "crew", "score": 0.7}
|
||||
assert seen_results == [
|
||||
('{"query":"crew","score":0.7}', SearchOutput(query="crew", score=0.7)),
|
||||
('{"query":"crew","score":0.7}', SearchOutput(query="crew", score=0.7)),
|
||||
]
|
||||
|
||||
|
||||
def test_validate_tool_input_booleans_and_none():
|
||||
tool_usage = ToolUsage(
|
||||
tools_handler=MagicMock(),
|
||||
|
||||
@@ -3,12 +3,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Literal, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.hooks.tool_hooks import (
|
||||
ToolCallHookContext,
|
||||
clear_after_tool_call_hooks,
|
||||
clear_before_tool_call_hooks,
|
||||
register_after_tool_call_hook,
|
||||
)
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.agent_utils import (
|
||||
_asummarize_chunks,
|
||||
@@ -1030,6 +1037,142 @@ class TestParseToolCallArgs:
|
||||
class TestExecuteSingleNativeToolCall:
|
||||
"""Tests for execute_single_native_tool_call."""
|
||||
|
||||
def test_typed_tool_output_is_json_agent_text(self) -> None:
|
||||
clear_before_tool_call_hooks()
|
||||
clear_after_tool_call_hooks()
|
||||
|
||||
class SearchOutput(BaseModel):
|
||||
query: str
|
||||
score: float
|
||||
|
||||
class TypedSearchTool(BaseTool):
|
||||
name: str = "typed_search"
|
||||
description: str = "Search for a query"
|
||||
output_schema: type[BaseModel] = SearchOutput
|
||||
|
||||
def _run(self, query: str) -> SearchOutput:
|
||||
return SearchOutput(query=query, score=0.9)
|
||||
|
||||
tool = TypedSearchTool()
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "call_1"
|
||||
tool_call.function.name = "typed_search"
|
||||
tool_call.function.arguments = '{"query": "crew"}'
|
||||
|
||||
result = execute_single_native_tool_call(
|
||||
tool_call,
|
||||
available_functions={"typed_search": tool._run},
|
||||
original_tools=[tool],
|
||||
structured_tools=[tool.to_structured_tool()],
|
||||
tools_handler=None,
|
||||
agent=None,
|
||||
task=None,
|
||||
crew=None,
|
||||
event_source=MagicMock(),
|
||||
printer=None,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
assert json.loads(result.result) == {"query": "crew", "score": 0.9}
|
||||
assert json.loads(result.tool_message["content"]) == {
|
||||
"query": "crew",
|
||||
"score": 0.9,
|
||||
}
|
||||
|
||||
def test_custom_agent_output_formatter_is_used_from_structured_tool(
|
||||
self,
|
||||
) -> None:
|
||||
clear_before_tool_call_hooks()
|
||||
clear_after_tool_call_hooks()
|
||||
|
||||
class SearchOutput(BaseModel):
|
||||
query: str
|
||||
score: float
|
||||
|
||||
class MarkdownSearchTool(BaseTool):
|
||||
name: str = "markdown_search"
|
||||
description: str = "Search for a query"
|
||||
output_schema: type[BaseModel] = SearchOutput
|
||||
|
||||
def _run(self, query: str) -> SearchOutput:
|
||||
return SearchOutput(query=query, score=0.9)
|
||||
|
||||
def format_output_for_agent(self, raw_result: Any) -> str:
|
||||
result = self.output_schema.model_validate(raw_result)
|
||||
return f"### {result.query}\n\nScore: **{result.score}**"
|
||||
|
||||
tool = MarkdownSearchTool()
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "call_1"
|
||||
tool_call.function.name = "markdown_search"
|
||||
tool_call.function.arguments = '{"query": "crew"}'
|
||||
|
||||
result = execute_single_native_tool_call(
|
||||
tool_call,
|
||||
available_functions={"markdown_search": tool._run},
|
||||
original_tools=[],
|
||||
structured_tools=[tool.to_structured_tool()],
|
||||
tools_handler=None,
|
||||
agent=None,
|
||||
task=None,
|
||||
crew=None,
|
||||
event_source=MagicMock(),
|
||||
printer=None,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
assert result.result == "### crew\n\nScore: **0.9**"
|
||||
assert result.tool_message["content"] == "### crew\n\nScore: **0.9**"
|
||||
|
||||
def test_after_hook_includes_raw_tool_result_for_typed_output(self) -> None:
|
||||
clear_after_tool_call_hooks()
|
||||
|
||||
class SearchOutput(BaseModel):
|
||||
query: str
|
||||
score: float
|
||||
|
||||
class TypedSearchTool(BaseTool):
|
||||
name: str = "typed_search"
|
||||
description: str = "Search for a query"
|
||||
output_schema: type[BaseModel] = SearchOutput
|
||||
|
||||
def _run(self, query: str) -> SearchOutput:
|
||||
return SearchOutput(query=query, score=0.9)
|
||||
|
||||
seen_results: list[tuple[str | None, object]] = []
|
||||
|
||||
def after_hook(context: ToolCallHookContext) -> None:
|
||||
seen_results.append((context.tool_result, context.raw_tool_result))
|
||||
|
||||
tool = TypedSearchTool()
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "call_1"
|
||||
tool_call.function.name = "typed_search"
|
||||
tool_call.function.arguments = '{"query": "crew"}'
|
||||
|
||||
register_after_tool_call_hook(after_hook)
|
||||
try:
|
||||
result = execute_single_native_tool_call(
|
||||
tool_call,
|
||||
available_functions={"typed_search": tool._run},
|
||||
original_tools=[tool],
|
||||
structured_tools=[tool.to_structured_tool()],
|
||||
tools_handler=None,
|
||||
agent=None,
|
||||
task=None,
|
||||
crew=None,
|
||||
event_source=MagicMock(),
|
||||
printer=None,
|
||||
verbose=False,
|
||||
)
|
||||
finally:
|
||||
clear_after_tool_call_hooks()
|
||||
|
||||
assert json.loads(result.result) == {"query": "crew", "score": 0.9}
|
||||
assert seen_results == [
|
||||
('{"query":"crew","score":0.9}', SearchOutput(query="crew", score=0.9))
|
||||
]
|
||||
|
||||
def test_result_as_answer_false_on_tool_error(self) -> None:
|
||||
"""When a tool with result_as_answer=True raises, result_as_answer must be False.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user