mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-27 00:58:13 +00:00
feat: add async support for tools, add async tool tests
This commit is contained in:
@@ -22,6 +22,11 @@ from crewai.utilities.printer import Printer
|
|||||||
_printer = Printer()
|
_printer = Printer()
|
||||||
|
|
||||||
|
|
||||||
|
def _is_async_callable(func: Callable[..., Any]) -> bool:
|
||||||
|
"""Check if a callable is async."""
|
||||||
|
return asyncio.iscoroutinefunction(func)
|
||||||
|
|
||||||
|
|
||||||
class EnvVar(BaseModel):
|
class EnvVar(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
@@ -55,7 +60,7 @@ class BaseTool(BaseModel, ABC):
|
|||||||
default=False, description="Flag to check if the description has been updated."
|
default=False, description="Flag to check if the description has been updated."
|
||||||
)
|
)
|
||||||
|
|
||||||
cache_function: Callable = Field(
|
cache_function: Callable[..., bool] = Field(
|
||||||
default=lambda _args=None, _result=None: True,
|
default=lambda _args=None, _result=None: True,
|
||||||
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
|
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
|
||||||
)
|
)
|
||||||
@@ -123,6 +128,35 @@ class BaseTool(BaseModel, ABC):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def arun(
|
||||||
|
self,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""Execute the tool asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: Positional arguments to pass to the tool.
|
||||||
|
**kwargs: Keyword arguments to pass to the tool.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the tool execution.
|
||||||
|
"""
|
||||||
|
result = await self._arun(*args, **kwargs)
|
||||||
|
self.current_usage_count += 1
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _arun(
|
||||||
|
self,
|
||||||
|
*args: Any,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
|
"""Async implementation of the tool. Override for async support."""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} does not implement _arun. "
|
||||||
|
"Override _arun for async support or use run() for sync execution."
|
||||||
|
)
|
||||||
|
|
||||||
def reset_usage_count(self) -> None:
|
def reset_usage_count(self) -> None:
|
||||||
"""Reset the current usage count to zero."""
|
"""Reset the current usage count to zero."""
|
||||||
self.current_usage_count = 0
|
self.current_usage_count = 0
|
||||||
@@ -133,7 +167,17 @@ class BaseTool(BaseModel, ABC):
|
|||||||
*args: Any,
|
*args: Any,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Here goes the actual implementation of the tool."""
|
"""Sync implementation of the tool.
|
||||||
|
|
||||||
|
Subclasses must implement this method for synchronous execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: Positional arguments for the tool.
|
||||||
|
**kwargs: Keyword arguments for the tool.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the tool execution.
|
||||||
|
"""
|
||||||
|
|
||||||
def to_structured_tool(self) -> CrewStructuredTool:
|
def to_structured_tool(self) -> CrewStructuredTool:
|
||||||
"""Convert this tool to a CrewStructuredTool instance."""
|
"""Convert this tool to a CrewStructuredTool instance."""
|
||||||
@@ -239,19 +283,32 @@ class BaseTool(BaseModel, ABC):
|
|||||||
|
|
||||||
if args:
|
if args:
|
||||||
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
|
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
|
||||||
return f"{origin.__name__}[{args_str}]"
|
return str(f"{origin.__name__}[{args_str}]")
|
||||||
|
|
||||||
return origin.__name__
|
return str(origin.__name__)
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseTool):
|
class Tool(BaseTool):
|
||||||
"""The function that will be executed when the tool is called."""
|
"""Tool that wraps a callable function.
|
||||||
|
|
||||||
func: Callable
|
The function can be either synchronous or asynchronous.
|
||||||
|
"""
|
||||||
|
|
||||||
|
func: Callable[..., Any]
|
||||||
|
|
||||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
"""Execute the wrapped function."""
|
||||||
return self.func(*args, **kwargs)
|
return self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
async def _arun(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
"""Execute the wrapped function asynchronously."""
|
||||||
|
if _is_async_callable(self.func):
|
||||||
|
return await self.func(*args, **kwargs)
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.name} does not have an async function. "
|
||||||
|
"Use run() for sync execution or provide an async function."
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_langchain(cls, tool: Any) -> Tool:
|
def from_langchain(cls, tool: Any) -> Tool:
|
||||||
"""Create a Tool instance from a CrewStructuredTool.
|
"""Create a Tool instance from a CrewStructuredTool.
|
||||||
@@ -312,19 +369,23 @@ def to_langchain(
|
|||||||
|
|
||||||
|
|
||||||
def tool(
|
def tool(
|
||||||
*args, result_as_answer: bool = False, max_usage_count: int | None = None
|
*args: Callable[..., Any] | str,
|
||||||
) -> Callable:
|
result_as_answer: bool = False,
|
||||||
"""
|
max_usage_count: int | None = None,
|
||||||
Decorator to create a tool from a function.
|
) -> Callable[[Callable[..., Any]], Tool] | Tool:
|
||||||
|
"""Decorator to create a tool from a function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*args: Positional arguments, either the function to decorate or the tool name.
|
*args: Positional arguments, either the function to decorate or the tool name.
|
||||||
result_as_answer: Flag to indicate if the tool result should be used as the final agent answer.
|
result_as_answer: Flag to indicate if the tool result should be used as the final agent answer.
|
||||||
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
|
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A Tool instance or a decorator that creates a Tool instance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _make_with_name(tool_name: str) -> Callable:
|
def _make_with_name(tool_name: str) -> Callable[[Callable[..., Any]], Tool]:
|
||||||
def _make_tool(f: Callable) -> BaseTool:
|
def _make_tool(f: Callable[..., Any]) -> Tool:
|
||||||
if f.__doc__ is None:
|
if f.__doc__ is None:
|
||||||
raise ValueError("Function must have a docstring")
|
raise ValueError("Function must have a docstring")
|
||||||
if f.__annotations__ is None:
|
if f.__annotations__ is None:
|
||||||
|
|||||||
@@ -160,6 +160,251 @@ class ToolUsage:
|
|||||||
|
|
||||||
return f"{self._use(tool_string=tool_string, tool=tool, calling=calling)}"
|
return f"{self._use(tool_string=tool_string, tool=tool, calling=calling)}"
|
||||||
|
|
||||||
|
async def ause(
|
||||||
|
self, calling: ToolCalling | InstructorToolCalling, tool_string: str
|
||||||
|
) -> str:
|
||||||
|
"""Execute a tool asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
calling: The tool calling information.
|
||||||
|
tool_string: The raw tool string from the agent.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the tool execution as a string.
|
||||||
|
"""
|
||||||
|
if isinstance(calling, ToolUsageError):
|
||||||
|
error = calling.message
|
||||||
|
if self.agent and self.agent.verbose:
|
||||||
|
self._printer.print(content=f"\n\n{error}\n", color="red")
|
||||||
|
if self.task:
|
||||||
|
self.task.increment_tools_errors()
|
||||||
|
return error
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool = self._select_tool(calling.tool_name)
|
||||||
|
except Exception as e:
|
||||||
|
error = getattr(e, "message", str(e))
|
||||||
|
if self.task:
|
||||||
|
self.task.increment_tools_errors()
|
||||||
|
if self.agent and self.agent.verbose:
|
||||||
|
self._printer.print(content=f"\n\n{error}\n", color="red")
|
||||||
|
return error
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(tool, CrewStructuredTool)
|
||||||
|
and tool.name == self._i18n.tools("add_image")["name"] # type: ignore
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
return await self._ause(
|
||||||
|
tool_string=tool_string, tool=tool, calling=calling
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
error = getattr(e, "message", str(e))
|
||||||
|
if self.task:
|
||||||
|
self.task.increment_tools_errors()
|
||||||
|
if self.agent and self.agent.verbose:
|
||||||
|
self._printer.print(content=f"\n\n{error}\n", color="red")
|
||||||
|
return error
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"{await self._ause(tool_string=tool_string, tool=tool, calling=calling)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _ause(
|
||||||
|
self,
|
||||||
|
tool_string: str,
|
||||||
|
tool: CrewStructuredTool,
|
||||||
|
calling: ToolCalling | InstructorToolCalling,
|
||||||
|
) -> str:
|
||||||
|
"""Internal async tool execution implementation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_string: The raw tool string from the agent.
|
||||||
|
tool: The tool to execute.
|
||||||
|
calling: The tool calling information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of the tool execution as a string.
|
||||||
|
"""
|
||||||
|
if self._check_tool_repeated_usage(calling=calling):
|
||||||
|
try:
|
||||||
|
result = self._i18n.errors("task_repeated_usage").format(
|
||||||
|
tool_names=self.tools_names
|
||||||
|
)
|
||||||
|
self._telemetry.tool_repeated_usage(
|
||||||
|
llm=self.function_calling_llm,
|
||||||
|
tool_name=tool.name,
|
||||||
|
attempts=self._run_attempts,
|
||||||
|
)
|
||||||
|
return self._format_result(result=result)
|
||||||
|
except Exception:
|
||||||
|
if self.task:
|
||||||
|
self.task.increment_tools_errors()
|
||||||
|
|
||||||
|
if self.agent:
|
||||||
|
event_data = {
|
||||||
|
"agent_key": self.agent.key,
|
||||||
|
"agent_role": self.agent.role,
|
||||||
|
"tool_name": self.action.tool,
|
||||||
|
"tool_args": self.action.tool_input,
|
||||||
|
"tool_class": self.action.tool,
|
||||||
|
"agent": self.agent,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.agent.fingerprint: # type: ignore
|
||||||
|
event_data.update(self.agent.fingerprint) # type: ignore
|
||||||
|
if self.task:
|
||||||
|
event_data["task_name"] = self.task.name or self.task.description
|
||||||
|
event_data["task_id"] = str(self.task.id)
|
||||||
|
crewai_event_bus.emit(self, ToolUsageStartedEvent(**event_data))
|
||||||
|
|
||||||
|
started_at = time.time()
|
||||||
|
from_cache = False
|
||||||
|
result = None # type: ignore
|
||||||
|
|
||||||
|
if self.tools_handler and self.tools_handler.cache:
|
||||||
|
input_str = ""
|
||||||
|
if calling.arguments:
|
||||||
|
if isinstance(calling.arguments, dict):
|
||||||
|
input_str = json.dumps(calling.arguments)
|
||||||
|
else:
|
||||||
|
input_str = str(calling.arguments)
|
||||||
|
|
||||||
|
result = self.tools_handler.cache.read(
|
||||||
|
tool=calling.tool_name, input=input_str
|
||||||
|
) # type: ignore
|
||||||
|
from_cache = result is not None
|
||||||
|
|
||||||
|
available_tool = next(
|
||||||
|
(
|
||||||
|
available_tool
|
||||||
|
for available_tool in self.tools
|
||||||
|
if available_tool.name == tool.name
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
usage_limit_error = self._check_usage_limit(available_tool, tool.name)
|
||||||
|
if usage_limit_error:
|
||||||
|
try:
|
||||||
|
result = usage_limit_error
|
||||||
|
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||||
|
return self._format_result(result=result)
|
||||||
|
except Exception:
|
||||||
|
if self.task:
|
||||||
|
self.task.increment_tools_errors()
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
try:
|
||||||
|
if calling.tool_name in [
|
||||||
|
"Delegate work to coworker",
|
||||||
|
"Ask question to coworker",
|
||||||
|
]:
|
||||||
|
coworker = (
|
||||||
|
calling.arguments.get("coworker") if calling.arguments else None
|
||||||
|
)
|
||||||
|
if self.task:
|
||||||
|
self.task.increment_delegations(coworker)
|
||||||
|
|
||||||
|
if calling.arguments:
|
||||||
|
try:
|
||||||
|
acceptable_args = tool.args_schema.model_json_schema()[
|
||||||
|
"properties"
|
||||||
|
].keys()
|
||||||
|
arguments = {
|
||||||
|
k: v
|
||||||
|
for k, v in calling.arguments.items()
|
||||||
|
if k in acceptable_args
|
||||||
|
}
|
||||||
|
arguments = self._add_fingerprint_metadata(arguments)
|
||||||
|
result = await tool.ainvoke(input=arguments)
|
||||||
|
except Exception:
|
||||||
|
arguments = calling.arguments
|
||||||
|
arguments = self._add_fingerprint_metadata(arguments)
|
||||||
|
result = await tool.ainvoke(input=arguments)
|
||||||
|
else:
|
||||||
|
arguments = self._add_fingerprint_metadata({})
|
||||||
|
result = await tool.ainvoke(input=arguments)
|
||||||
|
except Exception as e:
|
||||||
|
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
|
||||||
|
self._run_attempts += 1
|
||||||
|
if self._run_attempts > self._max_parsing_attempts:
|
||||||
|
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||||
|
error_message = self._i18n.errors("tool_usage_exception").format(
|
||||||
|
error=e, tool=tool.name, tool_inputs=tool.description
|
||||||
|
)
|
||||||
|
error = ToolUsageError(
|
||||||
|
f"\n{error_message}.\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
|
||||||
|
).message
|
||||||
|
if self.task:
|
||||||
|
self.task.increment_tools_errors()
|
||||||
|
if self.agent and self.agent.verbose:
|
||||||
|
self._printer.print(
|
||||||
|
content=f"\n\n{error_message}\n", color="red"
|
||||||
|
)
|
||||||
|
return error
|
||||||
|
|
||||||
|
if self.task:
|
||||||
|
self.task.increment_tools_errors()
|
||||||
|
return await self.ause(calling=calling, tool_string=tool_string)
|
||||||
|
|
||||||
|
if self.tools_handler:
|
||||||
|
should_cache = True
|
||||||
|
if (
|
||||||
|
hasattr(available_tool, "cache_function")
|
||||||
|
and available_tool.cache_function
|
||||||
|
):
|
||||||
|
should_cache = available_tool.cache_function(
|
||||||
|
calling.arguments, result
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tools_handler.on_tool_use(
|
||||||
|
calling=calling, output=result, should_cache=should_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
self._telemetry.tool_usage(
|
||||||
|
llm=self.function_calling_llm,
|
||||||
|
tool_name=tool.name,
|
||||||
|
attempts=self._run_attempts,
|
||||||
|
)
|
||||||
|
result = self._format_result(result=result)
|
||||||
|
data = {
|
||||||
|
"result": result,
|
||||||
|
"tool_name": tool.name,
|
||||||
|
"tool_args": calling.arguments,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.on_tool_use_finished(
|
||||||
|
tool=tool,
|
||||||
|
tool_calling=calling,
|
||||||
|
from_cache=from_cache,
|
||||||
|
started_at=started_at,
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(available_tool, "result_as_answer")
|
||||||
|
and available_tool.result_as_answer # type: ignore
|
||||||
|
):
|
||||||
|
result_as_answer = available_tool.result_as_answer # type: ignore
|
||||||
|
data["result_as_answer"] = result_as_answer # type: ignore
|
||||||
|
|
||||||
|
if self.agent and hasattr(self.agent, "tools_results"):
|
||||||
|
self.agent.tools_results.append(data)
|
||||||
|
|
||||||
|
if available_tool and hasattr(available_tool, "current_usage_count"):
|
||||||
|
available_tool.current_usage_count += 1
|
||||||
|
if (
|
||||||
|
hasattr(available_tool, "max_usage_count")
|
||||||
|
and available_tool.max_usage_count is not None
|
||||||
|
):
|
||||||
|
self._printer.print(
|
||||||
|
content=f"Tool '{available_tool.name}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
|
||||||
|
color="blue",
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def _use(
|
def _use(
|
||||||
self,
|
self,
|
||||||
tool_string: str,
|
tool_string: str,
|
||||||
|
|||||||
@@ -26,6 +26,138 @@ if TYPE_CHECKING:
|
|||||||
from crewai.task import Task
|
from crewai.task import Task
|
||||||
|
|
||||||
|
|
||||||
|
async def aexecute_tool_and_check_finality(
|
||||||
|
agent_action: AgentAction,
|
||||||
|
tools: list[CrewStructuredTool],
|
||||||
|
i18n: I18N,
|
||||||
|
agent_key: str | None = None,
|
||||||
|
agent_role: str | None = None,
|
||||||
|
tools_handler: ToolsHandler | None = None,
|
||||||
|
task: Task | None = None,
|
||||||
|
agent: Agent | BaseAgent | None = None,
|
||||||
|
function_calling_llm: BaseLLM | LLM | None = None,
|
||||||
|
fingerprint_context: dict[str, str] | None = None,
|
||||||
|
crew: Crew | None = None,
|
||||||
|
) -> ToolResult:
|
||||||
|
"""Execute a tool asynchronously and check if the result should be a final answer.
|
||||||
|
|
||||||
|
This is the async version of execute_tool_and_check_finality. It integrates tool
|
||||||
|
hooks for before and after tool execution, allowing programmatic interception
|
||||||
|
and modification of tool calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_action: The action containing the tool to execute.
|
||||||
|
tools: List of available tools.
|
||||||
|
i18n: Internationalization settings.
|
||||||
|
agent_key: Optional key for event emission.
|
||||||
|
agent_role: Optional role for event emission.
|
||||||
|
tools_handler: Optional tools handler for tool execution.
|
||||||
|
task: Optional task for tool execution.
|
||||||
|
agent: Optional agent instance for tool execution.
|
||||||
|
function_calling_llm: Optional LLM for function calling.
|
||||||
|
fingerprint_context: Optional context for fingerprinting.
|
||||||
|
crew: Optional crew instance for hook context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ToolResult containing the execution result and whether it should be
|
||||||
|
treated as a final answer.
|
||||||
|
"""
|
||||||
|
logger = Logger(verbose=crew.verbose if crew else False)
|
||||||
|
tool_name_to_tool_map = {tool.name: tool for tool in tools}
|
||||||
|
|
||||||
|
if agent_key and agent_role and agent:
|
||||||
|
fingerprint_context = fingerprint_context or {}
|
||||||
|
if agent:
|
||||||
|
if hasattr(agent, "set_fingerprint") and callable(agent.set_fingerprint):
|
||||||
|
if isinstance(fingerprint_context, dict):
|
||||||
|
try:
|
||||||
|
fingerprint_obj = Fingerprint.from_dict(fingerprint_context)
|
||||||
|
agent.set_fingerprint(fingerprint=fingerprint_obj)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Failed to set fingerprint: {e}") from e
|
||||||
|
|
||||||
|
tool_usage = ToolUsage(
|
||||||
|
tools_handler=tools_handler,
|
||||||
|
tools=tools,
|
||||||
|
function_calling_llm=function_calling_llm, # type: ignore[arg-type]
|
||||||
|
task=task,
|
||||||
|
agent=agent,
|
||||||
|
action=agent_action,
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_calling = tool_usage.parse_tool_calling(agent_action.text)
|
||||||
|
|
||||||
|
if isinstance(tool_calling, ToolUsageError):
|
||||||
|
return ToolResult(tool_calling.message, False)
|
||||||
|
|
||||||
|
if tool_calling.tool_name.casefold().strip() in [
|
||||||
|
name.casefold().strip() for name in tool_name_to_tool_map
|
||||||
|
] or tool_calling.tool_name.casefold().replace("_", " ") in [
|
||||||
|
name.casefold().strip() for name in tool_name_to_tool_map
|
||||||
|
]:
|
||||||
|
tool = tool_name_to_tool_map.get(tool_calling.tool_name)
|
||||||
|
if not tool:
|
||||||
|
tool_result = i18n.errors("wrong_tool_name").format(
|
||||||
|
tool=tool_calling.tool_name,
|
||||||
|
tools=", ".join([t.name.casefold() for t in tools]),
|
||||||
|
)
|
||||||
|
return ToolResult(result=tool_result, result_as_answer=False)
|
||||||
|
|
||||||
|
tool_input = tool_calling.arguments if tool_calling.arguments else {}
|
||||||
|
hook_context = ToolCallHookContext(
|
||||||
|
tool_name=tool_calling.tool_name,
|
||||||
|
tool_input=tool_input,
|
||||||
|
tool=tool,
|
||||||
|
agent=agent,
|
||||||
|
task=task,
|
||||||
|
crew=crew,
|
||||||
|
)
|
||||||
|
|
||||||
|
before_hooks = get_before_tool_call_hooks()
|
||||||
|
try:
|
||||||
|
for hook in before_hooks:
|
||||||
|
result = hook(hook_context)
|
||||||
|
if result is False:
|
||||||
|
blocked_message = (
|
||||||
|
f"Tool execution blocked by hook. "
|
||||||
|
f"Tool: {tool_calling.tool_name}"
|
||||||
|
)
|
||||||
|
return ToolResult(blocked_message, False)
|
||||||
|
except Exception as e:
|
||||||
|
logger.log("error", f"Error in before_tool_call hook: {e}")
|
||||||
|
|
||||||
|
tool_result = await tool_usage.ause(tool_calling, agent_action.text)
|
||||||
|
|
||||||
|
after_hook_context = ToolCallHookContext(
|
||||||
|
tool_name=tool_calling.tool_name,
|
||||||
|
tool_input=tool_input,
|
||||||
|
tool=tool,
|
||||||
|
agent=agent,
|
||||||
|
task=task,
|
||||||
|
crew=crew,
|
||||||
|
tool_result=tool_result,
|
||||||
|
)
|
||||||
|
|
||||||
|
after_hooks = get_after_tool_call_hooks()
|
||||||
|
modified_result: str = tool_result
|
||||||
|
try:
|
||||||
|
for after_hook in after_hooks:
|
||||||
|
hook_result = after_hook(after_hook_context)
|
||||||
|
if hook_result is not None:
|
||||||
|
modified_result = hook_result
|
||||||
|
after_hook_context.tool_result = modified_result
|
||||||
|
except Exception as e:
|
||||||
|
logger.log("error", f"Error in after_tool_call hook: {e}")
|
||||||
|
|
||||||
|
return ToolResult(modified_result, tool.result_as_answer)
|
||||||
|
|
||||||
|
tool_result = i18n.errors("wrong_tool_name").format(
|
||||||
|
tool=tool_calling.tool_name,
|
||||||
|
tools=", ".join([tool.name.casefold() for tool in tools]),
|
||||||
|
)
|
||||||
|
return ToolResult(result=tool_result, result_as_answer=False)
|
||||||
|
|
||||||
|
|
||||||
def execute_tool_and_check_finality(
|
def execute_tool_and_check_finality(
|
||||||
agent_action: AgentAction,
|
agent_action: AgentAction,
|
||||||
tools: list[CrewStructuredTool],
|
tools: list[CrewStructuredTool],
|
||||||
@@ -141,10 +273,10 @@ def execute_tool_and_check_finality(
|
|||||||
|
|
||||||
# Execute after_tool_call hooks
|
# Execute after_tool_call hooks
|
||||||
after_hooks = get_after_tool_call_hooks()
|
after_hooks = get_after_tool_call_hooks()
|
||||||
modified_result = tool_result
|
modified_result: str = tool_result
|
||||||
try:
|
try:
|
||||||
for hook in after_hooks:
|
for after_hook in after_hooks:
|
||||||
hook_result = hook(after_hook_context)
|
hook_result = after_hook(after_hook_context)
|
||||||
if hook_result is not None:
|
if hook_result is not None:
|
||||||
modified_result = hook_result
|
modified_result = hook_result
|
||||||
after_hook_context.tool_result = modified_result
|
after_hook_context.tool_result = modified_result
|
||||||
|
|||||||
196
lib/crewai/tests/tools/test_async_tools.py
Normal file
196
lib/crewai/tests/tools/test_async_tools.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
"""Tests for async tool functionality."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool, tool
|
||||||
|
|
||||||
|
|
||||||
|
class SyncTool(BaseTool):
|
||||||
|
"""Test tool with synchronous _run method."""
|
||||||
|
|
||||||
|
name: str = "sync_tool"
|
||||||
|
description: str = "A synchronous tool for testing"
|
||||||
|
|
||||||
|
def _run(self, input_text: str) -> str:
|
||||||
|
"""Process input text synchronously."""
|
||||||
|
return f"Sync processed: {input_text}"
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncTool(BaseTool):
|
||||||
|
"""Test tool with both sync and async implementations."""
|
||||||
|
|
||||||
|
name: str = "async_tool"
|
||||||
|
description: str = "An asynchronous tool for testing"
|
||||||
|
|
||||||
|
def _run(self, input_text: str) -> str:
|
||||||
|
"""Process input text synchronously."""
|
||||||
|
return f"Sync processed: {input_text}"
|
||||||
|
|
||||||
|
async def _arun(self, input_text: str) -> str:
|
||||||
|
"""Process input text asynchronously."""
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return f"Async processed: {input_text}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseTool:
|
||||||
|
"""Tests for BaseTool async functionality."""
|
||||||
|
|
||||||
|
def test_sync_tool_run_returns_result(self) -> None:
|
||||||
|
"""Test that sync tool run() returns correct result."""
|
||||||
|
tool = SyncTool()
|
||||||
|
result = tool.run(input_text="hello")
|
||||||
|
assert result == "Sync processed: hello"
|
||||||
|
|
||||||
|
def test_async_tool_run_returns_result(self) -> None:
|
||||||
|
"""Test that async tool run() works."""
|
||||||
|
tool = AsyncTool()
|
||||||
|
result = tool.run(input_text="hello")
|
||||||
|
assert result == "Sync processed: hello"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sync_tool_arun_raises_not_implemented(self) -> None:
|
||||||
|
"""Test that sync tool arun() raises NotImplementedError."""
|
||||||
|
tool = SyncTool()
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
await tool.arun(input_text="hello")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_tool_arun_returns_result(self) -> None:
|
||||||
|
"""Test that async tool arun() awaits directly."""
|
||||||
|
tool = AsyncTool()
|
||||||
|
result = await tool.arun(input_text="hello")
|
||||||
|
assert result == "Async processed: hello"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_arun_increments_usage_count(self) -> None:
|
||||||
|
"""Test that arun increments the usage count."""
|
||||||
|
tool = AsyncTool()
|
||||||
|
assert tool.current_usage_count == 0
|
||||||
|
|
||||||
|
await tool.arun(input_text="test")
|
||||||
|
assert tool.current_usage_count == 1
|
||||||
|
|
||||||
|
await tool.arun(input_text="test2")
|
||||||
|
assert tool.current_usage_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_async_tools_run_concurrently(self) -> None:
|
||||||
|
"""Test that multiple async tools can run concurrently."""
|
||||||
|
tool1 = AsyncTool()
|
||||||
|
tool2 = AsyncTool()
|
||||||
|
|
||||||
|
results = await asyncio.gather(
|
||||||
|
tool1.arun(input_text="first"),
|
||||||
|
tool2.arun(input_text="second"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert results[0] == "Async processed: first"
|
||||||
|
assert results[1] == "Async processed: second"
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolDecorator:
|
||||||
|
"""Tests for @tool decorator with async functions."""
|
||||||
|
|
||||||
|
def test_sync_decorated_tool_run(self) -> None:
|
||||||
|
"""Test sync decorated tool works with run()."""
|
||||||
|
|
||||||
|
@tool("sync_decorated")
|
||||||
|
def sync_func(value: str) -> str:
|
||||||
|
"""A sync decorated tool."""
|
||||||
|
return f"sync: {value}"
|
||||||
|
|
||||||
|
result = sync_func.run(value="test")
|
||||||
|
assert result == "sync: test"
|
||||||
|
|
||||||
|
def test_async_decorated_tool_run(self) -> None:
|
||||||
|
"""Test async decorated tool works with run()."""
|
||||||
|
|
||||||
|
@tool("async_decorated")
|
||||||
|
async def async_func(value: str) -> str:
|
||||||
|
"""An async decorated tool."""
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return f"async: {value}"
|
||||||
|
|
||||||
|
result = async_func.run(value="test")
|
||||||
|
assert result == "async: test"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sync_decorated_tool_arun_raises(self) -> None:
|
||||||
|
"""Test sync decorated tool arun() raises NotImplementedError."""
|
||||||
|
|
||||||
|
@tool("sync_decorated_arun")
|
||||||
|
def sync_func(value: str) -> str:
|
||||||
|
"""A sync decorated tool."""
|
||||||
|
return f"sync: {value}"
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
await sync_func.arun(value="test")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_decorated_tool_arun(self) -> None:
|
||||||
|
"""Test async decorated tool works with arun()."""
|
||||||
|
|
||||||
|
@tool("async_decorated_arun")
|
||||||
|
async def async_func(value: str) -> str:
|
||||||
|
"""An async decorated tool."""
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return f"async: {value}"
|
||||||
|
|
||||||
|
result = await async_func.arun(value="test")
|
||||||
|
assert result == "async: test"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncToolWithIO:
|
||||||
|
"""Tests for async tools with simulated I/O operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_tool_simulated_io(self) -> None:
|
||||||
|
"""Test async tool with simulated I/O delay."""
|
||||||
|
|
||||||
|
class SlowAsyncTool(BaseTool):
|
||||||
|
name: str = "slow_async"
|
||||||
|
description: str = "Simulates slow I/O"
|
||||||
|
|
||||||
|
def _run(self, delay: float) -> str:
|
||||||
|
return f"Completed after {delay}s"
|
||||||
|
|
||||||
|
async def _arun(self, delay: float) -> str:
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
return f"Completed after {delay}s"
|
||||||
|
|
||||||
|
tool = SlowAsyncTool()
|
||||||
|
result = await tool.arun(delay=0.05)
|
||||||
|
assert result == "Completed after 0.05s"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_slow_tools_concurrent(self) -> None:
|
||||||
|
"""Test that slow async tools benefit from concurrency."""
|
||||||
|
|
||||||
|
class SlowAsyncTool(BaseTool):
|
||||||
|
name: str = "slow_async"
|
||||||
|
description: str = "Simulates slow I/O"
|
||||||
|
|
||||||
|
def _run(self, task_id: int, delay: float) -> str:
|
||||||
|
return f"Task {task_id} done"
|
||||||
|
|
||||||
|
async def _arun(self, task_id: int, delay: float) -> str:
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
return f"Task {task_id} done"
|
||||||
|
|
||||||
|
tool = SlowAsyncTool()
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
results = await asyncio.gather(
|
||||||
|
tool.arun(task_id=1, delay=0.1),
|
||||||
|
tool.arun(task_id=2, delay=0.1),
|
||||||
|
tool.arun(task_id=3, delay=0.1),
|
||||||
|
)
|
||||||
|
elapsed = time.time() - start
|
||||||
|
|
||||||
|
assert len(results) == 3
|
||||||
|
assert all("done" in r for r in results)
|
||||||
|
assert elapsed < 0.25, f"Expected concurrent execution, took {elapsed}s"
|
||||||
Reference in New Issue
Block a user