feat: native async tool support
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

- add async support for tools
- add async tool tests
- improve tool decorator typing
- fix _run backward compatibility
- update docs and improve readability of docstrings
This commit is contained in:
Greyson LaLonde
2025-12-02 16:39:58 -05:00
committed by GitHub
parent 20704742e2
commit 09f1ba6956
7 changed files with 938 additions and 33 deletions

View File

@@ -2,9 +2,18 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from inspect import signature
from typing import Any, cast, get_args, get_origin
from typing import (
Any,
Generic,
ParamSpec,
TypeVar,
cast,
get_args,
get_origin,
overload,
)
from pydantic import (
BaseModel,
@@ -14,6 +23,7 @@ from pydantic import (
create_model,
field_validator,
)
from typing_extensions import TypeIs
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities.printer import Printer
@@ -21,6 +31,19 @@ from crewai.utilities.printer import Printer
_printer = Printer()
P = ParamSpec("P")
R = TypeVar("R", covariant=True)
def _is_async_callable(func: Callable[..., Any]) -> bool:
"""Check if a callable is async."""
return asyncio.iscoroutinefunction(func)
def _is_awaitable(value: R | Awaitable[R]) -> TypeIs[Awaitable[R]]:
"""Type narrowing check for awaitable values."""
return asyncio.iscoroutine(value) or asyncio.isfuture(value)
class EnvVar(BaseModel):
name: str
@@ -55,7 +78,7 @@ class BaseTool(BaseModel, ABC):
default=False, description="Flag to check if the description has been updated."
)
cache_function: Callable = Field(
cache_function: Callable[..., bool] = Field(
default=lambda _args=None, _result=None: True,
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
)
@@ -123,6 +146,35 @@ class BaseTool(BaseModel, ABC):
return result
async def arun(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Execute the tool asynchronously.
Args:
*args: Positional arguments to pass to the tool.
**kwargs: Keyword arguments to pass to the tool.
Returns:
The result of the tool execution.
"""
result = await self._arun(*args, **kwargs)
self.current_usage_count += 1
return result
async def _arun(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Async implementation of the tool. Override for async support."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement _arun. "
"Override _arun for async support or use run() for sync execution."
)
def reset_usage_count(self) -> None:
"""Reset the current usage count to zero."""
self.current_usage_count = 0
@@ -133,7 +185,17 @@ class BaseTool(BaseModel, ABC):
*args: Any,
**kwargs: Any,
) -> Any:
"""Here goes the actual implementation of the tool."""
"""Sync implementation of the tool.
Subclasses must implement this method for synchronous execution.
Args:
*args: Positional arguments for the tool.
**kwargs: Keyword arguments for the tool.
Returns:
The result of the tool execution.
"""
def to_structured_tool(self) -> CrewStructuredTool:
"""Convert this tool to a CrewStructuredTool instance."""
@@ -239,21 +301,90 @@ class BaseTool(BaseModel, ABC):
if args:
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
return f"{origin.__name__}[{args_str}]"
return str(f"{origin.__name__}[{args_str}]")
return origin.__name__
return str(origin.__name__)
class Tool(BaseTool):
"""The function that will be executed when the tool is called."""
class Tool(BaseTool, Generic[P, R]):
"""Tool that wraps a callable function.
func: Callable
def _run(self, *args: Any, **kwargs: Any) -> Any:
return self.func(*args, **kwargs)
Type Parameters:
P: ParamSpec capturing the function's parameters.
R: The return type of the function.
"""
func: Callable[P, R | Awaitable[R]]
def run(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the tool synchronously.
Args:
*args: Positional arguments for the tool.
**kwargs: Keyword arguments for the tool.
Returns:
The result of the tool execution.
"""
_printer.print(f"Using Tool: {self.name}", color="cyan")
result = self.func(*args, **kwargs)
if asyncio.iscoroutine(result):
result = asyncio.run(result)
self.current_usage_count += 1
return result # type: ignore[return-value]
def _run(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the wrapped function.
Args:
*args: Positional arguments for the function.
**kwargs: Keyword arguments for the function.
Returns:
The result of the function execution.
"""
return self.func(*args, **kwargs) # type: ignore[return-value]
async def arun(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the tool asynchronously.
Args:
*args: Positional arguments for the tool.
**kwargs: Keyword arguments for the tool.
Returns:
The result of the tool execution.
"""
result = await self._arun(*args, **kwargs)
self.current_usage_count += 1
return result
async def _arun(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the wrapped function asynchronously.
Args:
*args: Positional arguments for the function.
**kwargs: Keyword arguments for the function.
Returns:
The result of the async function execution.
Raises:
NotImplementedError: If the wrapped function is not async.
"""
result = self.func(*args, **kwargs)
if _is_awaitable(result):
return await result
raise NotImplementedError(
f"{self.name} does not have an async function. "
"Use run() for sync execution or provide an async function."
)
@classmethod
def from_langchain(cls, tool: Any) -> Tool:
def from_langchain(cls, tool: Any) -> Tool[..., Any]:
"""Create a Tool instance from a CrewStructuredTool.
This method takes a CrewStructuredTool object and converts it into a
@@ -261,10 +392,10 @@ class Tool(BaseTool):
attribute and infers the argument schema if not explicitly provided.
Args:
tool (Any): The CrewStructuredTool object to be converted.
tool: The CrewStructuredTool object to be converted.
Returns:
Tool: A new Tool instance created from the provided CrewStructuredTool.
A new Tool instance created from the provided CrewStructuredTool.
Raises:
ValueError: If the provided tool does not have a callable 'func' attribute.
@@ -308,37 +439,83 @@ class Tool(BaseTool):
def to_langchain(
tools: list[BaseTool | CrewStructuredTool],
) -> list[CrewStructuredTool]:
"""Convert a list of tools to CrewStructuredTool instances."""
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
P2 = ParamSpec("P2")
R2 = TypeVar("R2")
@overload
def tool(func: Callable[P2, R2], /) -> Tool[P2, R2]: ...
@overload
def tool(
*args, result_as_answer: bool = False, max_usage_count: int | None = None
) -> Callable:
"""
Decorator to create a tool from a function.
name: str,
/,
*,
result_as_answer: bool = ...,
max_usage_count: int | None = ...,
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
@overload
def tool(
*,
result_as_answer: bool = ...,
max_usage_count: int | None = ...,
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
def tool(
*args: Callable[P2, R2] | str,
result_as_answer: bool = False,
max_usage_count: int | None = None,
) -> Tool[P2, R2] | Callable[[Callable[P2, R2]], Tool[P2, R2]]:
"""Decorator to create a Tool from a function.
Can be used in three ways:
1. @tool - decorator without arguments, uses function name
2. @tool("name") - decorator with custom name
3. @tool(result_as_answer=True) - decorator with options
Args:
*args: Positional arguments, either the function to decorate or the tool name.
result_as_answer: Flag to indicate if the tool result should be used as the final agent answer.
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
*args: Either the function to decorate or a custom tool name.
result_as_answer: If True, the tool result becomes the final agent answer.
max_usage_count: Maximum times this tool can be used. None means unlimited.
Returns:
A Tool instance.
Example:
@tool
def greet(name: str) -> str:
'''Greet someone.'''
return f"Hello, {name}!"
result = greet.run("World")
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(f: Callable) -> BaseTool:
def _make_with_name(tool_name: str) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]:
def _make_tool(f: Callable[P2, R2]) -> Tool[P2, R2]:
if f.__doc__ is None:
raise ValueError("Function must have a docstring")
if f.__annotations__ is None:
func_annotations = getattr(f, "__annotations__", None)
if func_annotations is None:
raise ValueError("Function must have type annotations")
class_name = "".join(tool_name.split()).title()
args_schema = cast(
tool_args_schema = cast(
type[PydanticBaseModel],
type(
class_name,
(PydanticBaseModel,),
{
"__annotations__": {
k: v for k, v in f.__annotations__.items() if k != "return"
k: v for k, v in func_annotations.items() if k != "return"
},
},
),
@@ -348,10 +525,9 @@ def tool(
name=tool_name,
description=f.__doc__,
func=f,
args_schema=args_schema,
args_schema=tool_args_schema,
result_as_answer=result_as_answer,
max_usage_count=max_usage_count,
current_usage_count=0,
)
return _make_tool
@@ -360,4 +536,10 @@ def tool(
return _make_with_name(args[0].__name__)(args[0])
if len(args) == 1 and isinstance(args[0], str):
return _make_with_name(args[0])
if len(args) == 0:
def decorator(f: Callable[P2, R2]) -> Tool[P2, R2]:
return _make_with_name(f.__name__)(f)
return decorator
raise ValueError("Invalid arguments")

View File

@@ -160,6 +160,251 @@ class ToolUsage:
return f"{self._use(tool_string=tool_string, tool=tool, calling=calling)}"
async def ause(
self, calling: ToolCalling | InstructorToolCalling, tool_string: str
) -> str:
"""Execute a tool asynchronously.
Args:
calling: The tool calling information.
tool_string: The raw tool string from the agent.
Returns:
The result of the tool execution as a string.
"""
if isinstance(calling, ToolUsageError):
error = calling.message
if self.agent and self.agent.verbose:
self._printer.print(content=f"\n\n{error}\n", color="red")
if self.task:
self.task.increment_tools_errors()
return error
try:
tool = self._select_tool(calling.tool_name)
except Exception as e:
error = getattr(e, "message", str(e))
if self.task:
self.task.increment_tools_errors()
if self.agent and self.agent.verbose:
self._printer.print(content=f"\n\n{error}\n", color="red")
return error
if (
isinstance(tool, CrewStructuredTool)
and tool.name == self._i18n.tools("add_image")["name"] # type: ignore
):
try:
return await self._ause(
tool_string=tool_string, tool=tool, calling=calling
)
except Exception as e:
error = getattr(e, "message", str(e))
if self.task:
self.task.increment_tools_errors()
if self.agent and self.agent.verbose:
self._printer.print(content=f"\n\n{error}\n", color="red")
return error
return (
f"{await self._ause(tool_string=tool_string, tool=tool, calling=calling)}"
)
async def _ause(
self,
tool_string: str,
tool: CrewStructuredTool,
calling: ToolCalling | InstructorToolCalling,
) -> str:
"""Internal async tool execution implementation.
Args:
tool_string: The raw tool string from the agent.
tool: The tool to execute.
calling: The tool calling information.
Returns:
The result of the tool execution as a string.
"""
if self._check_tool_repeated_usage(calling=calling):
try:
result = self._i18n.errors("task_repeated_usage").format(
tool_names=self.tools_names
)
self._telemetry.tool_repeated_usage(
llm=self.function_calling_llm,
tool_name=tool.name,
attempts=self._run_attempts,
)
return self._format_result(result=result)
except Exception:
if self.task:
self.task.increment_tools_errors()
if self.agent:
event_data = {
"agent_key": self.agent.key,
"agent_role": self.agent.role,
"tool_name": self.action.tool,
"tool_args": self.action.tool_input,
"tool_class": self.action.tool,
"agent": self.agent,
}
if self.agent.fingerprint: # type: ignore
event_data.update(self.agent.fingerprint) # type: ignore
if self.task:
event_data["task_name"] = self.task.name or self.task.description
event_data["task_id"] = str(self.task.id)
crewai_event_bus.emit(self, ToolUsageStartedEvent(**event_data))
started_at = time.time()
from_cache = False
result = None # type: ignore
if self.tools_handler and self.tools_handler.cache:
input_str = ""
if calling.arguments:
if isinstance(calling.arguments, dict):
input_str = json.dumps(calling.arguments)
else:
input_str = str(calling.arguments)
result = self.tools_handler.cache.read(
tool=calling.tool_name, input=input_str
) # type: ignore
from_cache = result is not None
available_tool = next(
(
available_tool
for available_tool in self.tools
if available_tool.name == tool.name
),
None,
)
usage_limit_error = self._check_usage_limit(available_tool, tool.name)
if usage_limit_error:
try:
result = usage_limit_error
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
return self._format_result(result=result)
except Exception:
if self.task:
self.task.increment_tools_errors()
if result is None:
try:
if calling.tool_name in [
"Delegate work to coworker",
"Ask question to coworker",
]:
coworker = (
calling.arguments.get("coworker") if calling.arguments else None
)
if self.task:
self.task.increment_delegations(coworker)
if calling.arguments:
try:
acceptable_args = tool.args_schema.model_json_schema()[
"properties"
].keys()
arguments = {
k: v
for k, v in calling.arguments.items()
if k in acceptable_args
}
arguments = self._add_fingerprint_metadata(arguments)
result = await tool.ainvoke(input=arguments)
except Exception:
arguments = calling.arguments
arguments = self._add_fingerprint_metadata(arguments)
result = await tool.ainvoke(input=arguments)
else:
arguments = self._add_fingerprint_metadata({})
result = await tool.ainvoke(input=arguments)
except Exception as e:
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
self._run_attempts += 1
if self._run_attempts > self._max_parsing_attempts:
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
error_message = self._i18n.errors("tool_usage_exception").format(
error=e, tool=tool.name, tool_inputs=tool.description
)
error = ToolUsageError(
f"\n{error_message}.\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
).message
if self.task:
self.task.increment_tools_errors()
if self.agent and self.agent.verbose:
self._printer.print(
content=f"\n\n{error_message}\n", color="red"
)
return error
if self.task:
self.task.increment_tools_errors()
return await self.ause(calling=calling, tool_string=tool_string)
if self.tools_handler:
should_cache = True
if (
hasattr(available_tool, "cache_function")
and available_tool.cache_function
):
should_cache = available_tool.cache_function(
calling.arguments, result
)
self.tools_handler.on_tool_use(
calling=calling, output=result, should_cache=should_cache
)
self._telemetry.tool_usage(
llm=self.function_calling_llm,
tool_name=tool.name,
attempts=self._run_attempts,
)
result = self._format_result(result=result)
data = {
"result": result,
"tool_name": tool.name,
"tool_args": calling.arguments,
}
self.on_tool_use_finished(
tool=tool,
tool_calling=calling,
from_cache=from_cache,
started_at=started_at,
result=result,
)
if (
hasattr(available_tool, "result_as_answer")
and available_tool.result_as_answer # type: ignore
):
result_as_answer = available_tool.result_as_answer # type: ignore
data["result_as_answer"] = result_as_answer # type: ignore
if self.agent and hasattr(self.agent, "tools_results"):
self.agent.tools_results.append(data)
if available_tool and hasattr(available_tool, "current_usage_count"):
available_tool.current_usage_count += 1
if (
hasattr(available_tool, "max_usage_count")
and available_tool.max_usage_count is not None
):
self._printer.print(
content=f"Tool '{available_tool.name}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
color="blue",
)
return result
def _use(
self,
tool_string: str,

View File

@@ -26,6 +26,138 @@ if TYPE_CHECKING:
from crewai.task import Task
async def aexecute_tool_and_check_finality(
agent_action: AgentAction,
tools: list[CrewStructuredTool],
i18n: I18N,
agent_key: str | None = None,
agent_role: str | None = None,
tools_handler: ToolsHandler | None = None,
task: Task | None = None,
agent: Agent | BaseAgent | None = None,
function_calling_llm: BaseLLM | LLM | None = None,
fingerprint_context: dict[str, str] | None = None,
crew: Crew | None = None,
) -> ToolResult:
"""Execute a tool asynchronously and check if the result should be a final answer.
This is the async version of execute_tool_and_check_finality. It integrates tool
hooks for before and after tool execution, allowing programmatic interception
and modification of tool calls.
Args:
agent_action: The action containing the tool to execute.
tools: List of available tools.
i18n: Internationalization settings.
agent_key: Optional key for event emission.
agent_role: Optional role for event emission.
tools_handler: Optional tools handler for tool execution.
task: Optional task for tool execution.
agent: Optional agent instance for tool execution.
function_calling_llm: Optional LLM for function calling.
fingerprint_context: Optional context for fingerprinting.
crew: Optional crew instance for hook context.
Returns:
ToolResult containing the execution result and whether it should be
treated as a final answer.
"""
logger = Logger(verbose=crew.verbose if crew else False)
tool_name_to_tool_map = {tool.name: tool for tool in tools}
if agent_key and agent_role and agent:
fingerprint_context = fingerprint_context or {}
if agent:
if hasattr(agent, "set_fingerprint") and callable(agent.set_fingerprint):
if isinstance(fingerprint_context, dict):
try:
fingerprint_obj = Fingerprint.from_dict(fingerprint_context)
agent.set_fingerprint(fingerprint=fingerprint_obj)
except Exception as e:
raise ValueError(f"Failed to set fingerprint: {e}") from e
tool_usage = ToolUsage(
tools_handler=tools_handler,
tools=tools,
function_calling_llm=function_calling_llm, # type: ignore[arg-type]
task=task,
agent=agent,
action=agent_action,
)
tool_calling = tool_usage.parse_tool_calling(agent_action.text)
if isinstance(tool_calling, ToolUsageError):
return ToolResult(tool_calling.message, False)
if tool_calling.tool_name.casefold().strip() in [
name.casefold().strip() for name in tool_name_to_tool_map
] or tool_calling.tool_name.casefold().replace("_", " ") in [
name.casefold().strip() for name in tool_name_to_tool_map
]:
tool = tool_name_to_tool_map.get(tool_calling.tool_name)
if not tool:
tool_result = i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name,
tools=", ".join([t.name.casefold() for t in tools]),
)
return ToolResult(result=tool_result, result_as_answer=False)
tool_input = tool_calling.arguments if tool_calling.arguments else {}
hook_context = ToolCallHookContext(
tool_name=tool_calling.tool_name,
tool_input=tool_input,
tool=tool,
agent=agent,
task=task,
crew=crew,
)
before_hooks = get_before_tool_call_hooks()
try:
for hook in before_hooks:
result = hook(hook_context)
if result is False:
blocked_message = (
f"Tool execution blocked by hook. "
f"Tool: {tool_calling.tool_name}"
)
return ToolResult(blocked_message, False)
except Exception as e:
logger.log("error", f"Error in before_tool_call hook: {e}")
tool_result = await tool_usage.ause(tool_calling, agent_action.text)
after_hook_context = ToolCallHookContext(
tool_name=tool_calling.tool_name,
tool_input=tool_input,
tool=tool,
agent=agent,
task=task,
crew=crew,
tool_result=tool_result,
)
after_hooks = get_after_tool_call_hooks()
modified_result: str = tool_result
try:
for after_hook in after_hooks:
hook_result = after_hook(after_hook_context)
if hook_result is not None:
modified_result = hook_result
after_hook_context.tool_result = modified_result
except Exception as e:
logger.log("error", f"Error in after_tool_call hook: {e}")
return ToolResult(modified_result, tool.result_as_answer)
tool_result = i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name,
tools=", ".join([tool.name.casefold() for tool in tools]),
)
return ToolResult(result=tool_result, result_as_answer=False)
def execute_tool_and_check_finality(
agent_action: AgentAction,
tools: list[CrewStructuredTool],
@@ -141,10 +273,10 @@ def execute_tool_and_check_finality(
# Execute after_tool_call hooks
after_hooks = get_after_tool_call_hooks()
modified_result = tool_result
modified_result: str = tool_result
try:
for hook in after_hooks:
hook_result = hook(after_hook_context)
for after_hook in after_hooks:
hook_result = after_hook(after_hook_context)
if hook_result is not None:
modified_result = hook_result
after_hook_context.tool_result = modified_result

View File

@@ -0,0 +1,196 @@
"""Tests for async tool functionality."""
import asyncio
import pytest
from crewai.tools import BaseTool, tool
class SyncTool(BaseTool):
"""Test tool with synchronous _run method."""
name: str = "sync_tool"
description: str = "A synchronous tool for testing"
def _run(self, input_text: str) -> str:
"""Process input text synchronously."""
return f"Sync processed: {input_text}"
class AsyncTool(BaseTool):
"""Test tool with both sync and async implementations."""
name: str = "async_tool"
description: str = "An asynchronous tool for testing"
def _run(self, input_text: str) -> str:
"""Process input text synchronously."""
return f"Sync processed: {input_text}"
async def _arun(self, input_text: str) -> str:
"""Process input text asynchronously."""
await asyncio.sleep(0.01)
return f"Async processed: {input_text}"
class TestBaseTool:
"""Tests for BaseTool async functionality."""
def test_sync_tool_run_returns_result(self) -> None:
"""Test that sync tool run() returns correct result."""
tool = SyncTool()
result = tool.run(input_text="hello")
assert result == "Sync processed: hello"
def test_async_tool_run_returns_result(self) -> None:
"""Test that async tool run() works."""
tool = AsyncTool()
result = tool.run(input_text="hello")
assert result == "Sync processed: hello"
@pytest.mark.asyncio
async def test_sync_tool_arun_raises_not_implemented(self) -> None:
"""Test that sync tool arun() raises NotImplementedError."""
tool = SyncTool()
with pytest.raises(NotImplementedError):
await tool.arun(input_text="hello")
@pytest.mark.asyncio
async def test_async_tool_arun_returns_result(self) -> None:
"""Test that async tool arun() awaits directly."""
tool = AsyncTool()
result = await tool.arun(input_text="hello")
assert result == "Async processed: hello"
@pytest.mark.asyncio
async def test_arun_increments_usage_count(self) -> None:
"""Test that arun increments the usage count."""
tool = AsyncTool()
assert tool.current_usage_count == 0
await tool.arun(input_text="test")
assert tool.current_usage_count == 1
await tool.arun(input_text="test2")
assert tool.current_usage_count == 2
@pytest.mark.asyncio
async def test_multiple_async_tools_run_concurrently(self) -> None:
"""Test that multiple async tools can run concurrently."""
tool1 = AsyncTool()
tool2 = AsyncTool()
results = await asyncio.gather(
tool1.arun(input_text="first"),
tool2.arun(input_text="second"),
)
assert results[0] == "Async processed: first"
assert results[1] == "Async processed: second"
class TestToolDecorator:
"""Tests for @tool decorator with async functions."""
def test_sync_decorated_tool_run(self) -> None:
"""Test sync decorated tool works with run()."""
@tool("sync_decorated")
def sync_func(value: str) -> str:
"""A sync decorated tool."""
return f"sync: {value}"
result = sync_func.run(value="test")
assert result == "sync: test"
def test_async_decorated_tool_run(self) -> None:
"""Test async decorated tool works with run()."""
@tool("async_decorated")
async def async_func(value: str) -> str:
"""An async decorated tool."""
await asyncio.sleep(0.01)
return f"async: {value}"
result = async_func.run(value="test")
assert result == "async: test"
@pytest.mark.asyncio
async def test_sync_decorated_tool_arun_raises(self) -> None:
"""Test sync decorated tool arun() raises NotImplementedError."""
@tool("sync_decorated_arun")
def sync_func(value: str) -> str:
"""A sync decorated tool."""
return f"sync: {value}"
with pytest.raises(NotImplementedError):
await sync_func.arun(value="test")
@pytest.mark.asyncio
async def test_async_decorated_tool_arun(self) -> None:
"""Test async decorated tool works with arun()."""
@tool("async_decorated_arun")
async def async_func(value: str) -> str:
"""An async decorated tool."""
await asyncio.sleep(0.01)
return f"async: {value}"
result = await async_func.arun(value="test")
assert result == "async: test"
class TestAsyncToolWithIO:
"""Tests for async tools with simulated I/O operations."""
@pytest.mark.asyncio
async def test_async_tool_simulated_io(self) -> None:
"""Test async tool with simulated I/O delay."""
class SlowAsyncTool(BaseTool):
name: str = "slow_async"
description: str = "Simulates slow I/O"
def _run(self, delay: float) -> str:
return f"Completed after {delay}s"
async def _arun(self, delay: float) -> str:
await asyncio.sleep(delay)
return f"Completed after {delay}s"
tool = SlowAsyncTool()
result = await tool.arun(delay=0.05)
assert result == "Completed after 0.05s"
@pytest.mark.asyncio
async def test_multiple_slow_tools_concurrent(self) -> None:
"""Test that slow async tools benefit from concurrency."""
class SlowAsyncTool(BaseTool):
name: str = "slow_async"
description: str = "Simulates slow I/O"
def _run(self, task_id: int, delay: float) -> str:
return f"Task {task_id} done"
async def _arun(self, task_id: int, delay: float) -> str:
await asyncio.sleep(delay)
return f"Task {task_id} done"
tool = SlowAsyncTool()
import time
start = time.time()
results = await asyncio.gather(
tool.arun(task_id=1, delay=0.1),
tool.arun(task_id=2, delay=0.1),
tool.arun(task_id=3, delay=0.1),
)
elapsed = time.time() - start
assert len(results) == 3
assert all("done" in r for r in results)
assert elapsed < 0.25, f"Expected concurrent execution, took {elapsed}s"