From be33c8e3e56b8a90e7c2acad0c9df925d7c1d54b Mon Sep 17 00:00:00 2001 From: Greyson Lalonde Date: Tue, 2 Dec 2025 00:03:28 -0500 Subject: [PATCH] feat: add async support for tools, add async tool tests --- lib/crewai/src/crewai/tools/base_tool.py | 85 +++++- lib/crewai/src/crewai/tools/tool_usage.py | 245 ++++++++++++++++++ lib/crewai/src/crewai/utilities/tool_utils.py | 138 +++++++++- lib/crewai/tests/tools/test_async_tools.py | 196 ++++++++++++++ 4 files changed, 649 insertions(+), 15 deletions(-) create mode 100644 lib/crewai/tests/tools/test_async_tools.py diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py index 19ed6b671..38ddaa7ab 100644 --- a/lib/crewai/src/crewai/tools/base_tool.py +++ b/lib/crewai/src/crewai/tools/base_tool.py @@ -22,6 +22,11 @@ from crewai.utilities.printer import Printer _printer = Printer() +def _is_async_callable(func: Callable[..., Any]) -> bool: + """Check if a callable is async.""" + return asyncio.iscoroutinefunction(func) + + class EnvVar(BaseModel): name: str description: str @@ -55,7 +60,7 @@ class BaseTool(BaseModel, ABC): default=False, description="Flag to check if the description has been updated." ) - cache_function: Callable = Field( + cache_function: Callable[..., bool] = Field( default=lambda _args=None, _result=None: True, description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.", ) @@ -123,6 +128,35 @@ class BaseTool(BaseModel, ABC): return result + async def arun( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """Execute the tool asynchronously. + + Args: + *args: Positional arguments to pass to the tool. + **kwargs: Keyword arguments to pass to the tool. + + Returns: + The result of the tool execution. + """ + result = await self._arun(*args, **kwargs) + self.current_usage_count += 1 + return result + + async def _arun( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + """Async implementation of the tool. Override for async support.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not implement _arun. " + "Override _arun for async support or use run() for sync execution." + ) + def reset_usage_count(self) -> None: """Reset the current usage count to zero.""" self.current_usage_count = 0 @@ -133,7 +167,17 @@ class BaseTool(BaseModel, ABC): *args: Any, **kwargs: Any, ) -> Any: - """Here goes the actual implementation of the tool.""" + """Sync implementation of the tool. + + Subclasses must implement this method for synchronous execution. + + Args: + *args: Positional arguments for the tool. + **kwargs: Keyword arguments for the tool. + + Returns: + The result of the tool execution. + """ def to_structured_tool(self) -> CrewStructuredTool: """Convert this tool to a CrewStructuredTool instance.""" @@ -239,19 +283,32 @@ class BaseTool(BaseModel, ABC): if args: args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args) - return f"{origin.__name__}[{args_str}]" + return str(f"{origin.__name__}[{args_str}]") - return origin.__name__ + return str(origin.__name__) class Tool(BaseTool): - """The function that will be executed when the tool is called.""" + """Tool that wraps a callable function. - func: Callable + The function can be either synchronous or asynchronous. + """ + + func: Callable[..., Any] def _run(self, *args: Any, **kwargs: Any) -> Any: + """Execute the wrapped function.""" return self.func(*args, **kwargs) + async def _arun(self, *args: Any, **kwargs: Any) -> Any: + """Execute the wrapped function asynchronously.""" + if _is_async_callable(self.func): + return await self.func(*args, **kwargs) + raise NotImplementedError( + f"{self.name} does not have an async function. " + "Use run() for sync execution or provide an async function." + ) + @classmethod def from_langchain(cls, tool: Any) -> Tool: """Create a Tool instance from a CrewStructuredTool. @@ -312,19 +369,23 @@ def to_langchain( def tool( - *args, result_as_answer: bool = False, max_usage_count: int | None = None -) -> Callable: - """ - Decorator to create a tool from a function. + *args: Callable[..., Any] | str, + result_as_answer: bool = False, + max_usage_count: int | None = None, +) -> Callable[[Callable[..., Any]], Tool] | Tool: + """Decorator to create a tool from a function. Args: *args: Positional arguments, either the function to decorate or the tool name. result_as_answer: Flag to indicate if the tool result should be used as the final agent answer. max_usage_count: Maximum number of times this tool can be used. None means unlimited usage. + + Returns: + A Tool instance or a decorator that creates a Tool instance. """ - def _make_with_name(tool_name: str) -> Callable: - def _make_tool(f: Callable) -> BaseTool: + def _make_with_name(tool_name: str) -> Callable[[Callable[..., Any]], Tool]: + def _make_tool(f: Callable[..., Any]) -> Tool: if f.__doc__ is None: raise ValueError("Function must have a docstring") if f.__annotations__ is None: diff --git a/lib/crewai/src/crewai/tools/tool_usage.py b/lib/crewai/src/crewai/tools/tool_usage.py index 6f0e92cb8..8f753f412 100644 --- a/lib/crewai/src/crewai/tools/tool_usage.py +++ b/lib/crewai/src/crewai/tools/tool_usage.py @@ -160,6 +160,251 @@ class ToolUsage: return f"{self._use(tool_string=tool_string, tool=tool, calling=calling)}" + async def ause( + self, calling: ToolCalling | InstructorToolCalling, tool_string: str + ) -> str: + """Execute a tool asynchronously. + + Args: + calling: The tool calling information. + tool_string: The raw tool string from the agent. + + Returns: + The result of the tool execution as a string. + """ + if isinstance(calling, ToolUsageError): + error = calling.message + if self.agent and self.agent.verbose: + self._printer.print(content=f"\n\n{error}\n", color="red") + if self.task: + self.task.increment_tools_errors() + return error + + try: + tool = self._select_tool(calling.tool_name) + except Exception as e: + error = getattr(e, "message", str(e)) + if self.task: + self.task.increment_tools_errors() + if self.agent and self.agent.verbose: + self._printer.print(content=f"\n\n{error}\n", color="red") + return error + + if ( + isinstance(tool, CrewStructuredTool) + and tool.name == self._i18n.tools("add_image")["name"] # type: ignore + ): + try: + return await self._ause( + tool_string=tool_string, tool=tool, calling=calling + ) + except Exception as e: + error = getattr(e, "message", str(e)) + if self.task: + self.task.increment_tools_errors() + if self.agent and self.agent.verbose: + self._printer.print(content=f"\n\n{error}\n", color="red") + return error + + return ( + f"{await self._ause(tool_string=tool_string, tool=tool, calling=calling)}" + ) + + async def _ause( + self, + tool_string: str, + tool: CrewStructuredTool, + calling: ToolCalling | InstructorToolCalling, + ) -> str: + """Internal async tool execution implementation. + + Args: + tool_string: The raw tool string from the agent. + tool: The tool to execute. + calling: The tool calling information. + + Returns: + The result of the tool execution as a string. + """ + if self._check_tool_repeated_usage(calling=calling): + try: + result = self._i18n.errors("task_repeated_usage").format( + tool_names=self.tools_names + ) + self._telemetry.tool_repeated_usage( + llm=self.function_calling_llm, + tool_name=tool.name, + attempts=self._run_attempts, + ) + return self._format_result(result=result) + except Exception: + if self.task: + self.task.increment_tools_errors() + + if self.agent: + event_data = { + "agent_key": self.agent.key, + "agent_role": self.agent.role, + "tool_name": self.action.tool, + "tool_args": self.action.tool_input, + "tool_class": self.action.tool, + "agent": self.agent, + } + + if self.agent.fingerprint: # type: ignore + event_data.update(self.agent.fingerprint) # type: ignore + if self.task: + event_data["task_name"] = self.task.name or self.task.description + event_data["task_id"] = str(self.task.id) + crewai_event_bus.emit(self, ToolUsageStartedEvent(**event_data)) + + started_at = time.time() + from_cache = False + result = None # type: ignore + + if self.tools_handler and self.tools_handler.cache: + input_str = "" + if calling.arguments: + if isinstance(calling.arguments, dict): + input_str = json.dumps(calling.arguments) + else: + input_str = str(calling.arguments) + + result = self.tools_handler.cache.read( + tool=calling.tool_name, input=input_str + ) # type: ignore + from_cache = result is not None + + available_tool = next( + ( + available_tool + for available_tool in self.tools + if available_tool.name == tool.name + ), + None, + ) + + usage_limit_error = self._check_usage_limit(available_tool, tool.name) + if usage_limit_error: + try: + result = usage_limit_error + self._telemetry.tool_usage_error(llm=self.function_calling_llm) + return self._format_result(result=result) + except Exception: + if self.task: + self.task.increment_tools_errors() + + if result is None: + try: + if calling.tool_name in [ + "Delegate work to coworker", + "Ask question to coworker", + ]: + coworker = ( + calling.arguments.get("coworker") if calling.arguments else None + ) + if self.task: + self.task.increment_delegations(coworker) + + if calling.arguments: + try: + acceptable_args = tool.args_schema.model_json_schema()[ + "properties" + ].keys() + arguments = { + k: v + for k, v in calling.arguments.items() + if k in acceptable_args + } + arguments = self._add_fingerprint_metadata(arguments) + result = await tool.ainvoke(input=arguments) + except Exception: + arguments = calling.arguments + arguments = self._add_fingerprint_metadata(arguments) + result = await tool.ainvoke(input=arguments) + else: + arguments = self._add_fingerprint_metadata({}) + result = await tool.ainvoke(input=arguments) + except Exception as e: + self.on_tool_error(tool=tool, tool_calling=calling, e=e) + self._run_attempts += 1 + if self._run_attempts > self._max_parsing_attempts: + self._telemetry.tool_usage_error(llm=self.function_calling_llm) + error_message = self._i18n.errors("tool_usage_exception").format( + error=e, tool=tool.name, tool_inputs=tool.description + ) + error = ToolUsageError( + f"\n{error_message}.\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}" + ).message + if self.task: + self.task.increment_tools_errors() + if self.agent and self.agent.verbose: + self._printer.print( + content=f"\n\n{error_message}\n", color="red" + ) + return error + + if self.task: + self.task.increment_tools_errors() + return await self.ause(calling=calling, tool_string=tool_string) + + if self.tools_handler: + should_cache = True + if ( + hasattr(available_tool, "cache_function") + and available_tool.cache_function + ): + should_cache = available_tool.cache_function( + calling.arguments, result + ) + + self.tools_handler.on_tool_use( + calling=calling, output=result, should_cache=should_cache + ) + + self._telemetry.tool_usage( + llm=self.function_calling_llm, + tool_name=tool.name, + attempts=self._run_attempts, + ) + result = self._format_result(result=result) + data = { + "result": result, + "tool_name": tool.name, + "tool_args": calling.arguments, + } + + self.on_tool_use_finished( + tool=tool, + tool_calling=calling, + from_cache=from_cache, + started_at=started_at, + result=result, + ) + + if ( + hasattr(available_tool, "result_as_answer") + and available_tool.result_as_answer # type: ignore + ): + result_as_answer = available_tool.result_as_answer # type: ignore + data["result_as_answer"] = result_as_answer # type: ignore + + if self.agent and hasattr(self.agent, "tools_results"): + self.agent.tools_results.append(data) + + if available_tool and hasattr(available_tool, "current_usage_count"): + available_tool.current_usage_count += 1 + if ( + hasattr(available_tool, "max_usage_count") + and available_tool.max_usage_count is not None + ): + self._printer.print( + content=f"Tool '{available_tool.name}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}", + color="blue", + ) + + return result + def _use( self, tool_string: str, diff --git a/lib/crewai/src/crewai/utilities/tool_utils.py b/lib/crewai/src/crewai/utilities/tool_utils.py index aac2b979c..ca588f699 100644 --- a/lib/crewai/src/crewai/utilities/tool_utils.py +++ b/lib/crewai/src/crewai/utilities/tool_utils.py @@ -26,6 +26,138 @@ if TYPE_CHECKING: from crewai.task import Task +async def aexecute_tool_and_check_finality( + agent_action: AgentAction, + tools: list[CrewStructuredTool], + i18n: I18N, + agent_key: str | None = None, + agent_role: str | None = None, + tools_handler: ToolsHandler | None = None, + task: Task | None = None, + agent: Agent | BaseAgent | None = None, + function_calling_llm: BaseLLM | LLM | None = None, + fingerprint_context: dict[str, str] | None = None, + crew: Crew | None = None, +) -> ToolResult: + """Execute a tool asynchronously and check if the result should be a final answer. + + This is the async version of execute_tool_and_check_finality. It integrates tool + hooks for before and after tool execution, allowing programmatic interception + and modification of tool calls. + + Args: + agent_action: The action containing the tool to execute. + tools: List of available tools. + i18n: Internationalization settings. + agent_key: Optional key for event emission. + agent_role: Optional role for event emission. + tools_handler: Optional tools handler for tool execution. + task: Optional task for tool execution. + agent: Optional agent instance for tool execution. + function_calling_llm: Optional LLM for function calling. + fingerprint_context: Optional context for fingerprinting. + crew: Optional crew instance for hook context. + + Returns: + ToolResult containing the execution result and whether it should be + treated as a final answer. + """ + logger = Logger(verbose=crew.verbose if crew else False) + tool_name_to_tool_map = {tool.name: tool for tool in tools} + + if agent_key and agent_role and agent: + fingerprint_context = fingerprint_context or {} + if agent: + if hasattr(agent, "set_fingerprint") and callable(agent.set_fingerprint): + if isinstance(fingerprint_context, dict): + try: + fingerprint_obj = Fingerprint.from_dict(fingerprint_context) + agent.set_fingerprint(fingerprint=fingerprint_obj) + except Exception as e: + raise ValueError(f"Failed to set fingerprint: {e}") from e + + tool_usage = ToolUsage( + tools_handler=tools_handler, + tools=tools, + function_calling_llm=function_calling_llm, # type: ignore[arg-type] + task=task, + agent=agent, + action=agent_action, + ) + + tool_calling = tool_usage.parse_tool_calling(agent_action.text) + + if isinstance(tool_calling, ToolUsageError): + return ToolResult(tool_calling.message, False) + + if tool_calling.tool_name.casefold().strip() in [ + name.casefold().strip() for name in tool_name_to_tool_map + ] or tool_calling.tool_name.casefold().replace("_", " ") in [ + name.casefold().strip() for name in tool_name_to_tool_map + ]: + tool = tool_name_to_tool_map.get(tool_calling.tool_name) + if not tool: + tool_result = i18n.errors("wrong_tool_name").format( + tool=tool_calling.tool_name, + tools=", ".join([t.name.casefold() for t in tools]), + ) + return ToolResult(result=tool_result, result_as_answer=False) + + tool_input = tool_calling.arguments if tool_calling.arguments else {} + hook_context = ToolCallHookContext( + tool_name=tool_calling.tool_name, + tool_input=tool_input, + tool=tool, + agent=agent, + task=task, + crew=crew, + ) + + before_hooks = get_before_tool_call_hooks() + try: + for hook in before_hooks: + result = hook(hook_context) + if result is False: + blocked_message = ( + f"Tool execution blocked by hook. " + f"Tool: {tool_calling.tool_name}" + ) + return ToolResult(blocked_message, False) + except Exception as e: + logger.log("error", f"Error in before_tool_call hook: {e}") + + tool_result = await tool_usage.ause(tool_calling, agent_action.text) + + after_hook_context = ToolCallHookContext( + tool_name=tool_calling.tool_name, + tool_input=tool_input, + tool=tool, + agent=agent, + task=task, + crew=crew, + tool_result=tool_result, + ) + + after_hooks = get_after_tool_call_hooks() + modified_result: str = tool_result + try: + for after_hook in after_hooks: + hook_result = after_hook(after_hook_context) + if hook_result is not None: + modified_result = hook_result + after_hook_context.tool_result = modified_result + except Exception as e: + logger.log("error", f"Error in after_tool_call hook: {e}") + + return ToolResult(modified_result, tool.result_as_answer) + + tool_result = i18n.errors("wrong_tool_name").format( + tool=tool_calling.tool_name, + tools=", ".join([tool.name.casefold() for tool in tools]), + ) + return ToolResult(result=tool_result, result_as_answer=False) + + def execute_tool_and_check_finality( agent_action: AgentAction, tools: list[CrewStructuredTool], @@ -141,10 +273,10 @@ def execute_tool_and_check_finality( # Execute after_tool_call hooks after_hooks = get_after_tool_call_hooks() - modified_result = tool_result + modified_result: str = tool_result try: - for hook in after_hooks: - hook_result = hook(after_hook_context) + for after_hook in after_hooks: + hook_result = after_hook(after_hook_context) if hook_result is not None: modified_result = hook_result after_hook_context.tool_result = modified_result diff --git a/lib/crewai/tests/tools/test_async_tools.py b/lib/crewai/tests/tools/test_async_tools.py new file mode 100644 index 000000000..d95df3b39 --- /dev/null +++ b/lib/crewai/tests/tools/test_async_tools.py @@ -0,0 +1,196 @@ +"""Tests for async tool functionality.""" + +import asyncio + +import pytest + +from crewai.tools import BaseTool, tool + + +class SyncTool(BaseTool): + """Test tool with synchronous _run method.""" + + name: str = "sync_tool" + description: str = "A synchronous tool for testing" + + def _run(self, input_text: str) -> str: + """Process input text synchronously.""" + return f"Sync processed: {input_text}" + + +class AsyncTool(BaseTool): + """Test tool with both sync and async implementations.""" + + name: str = "async_tool" + description: str = "An asynchronous tool for testing" + + def _run(self, input_text: str) -> str: + """Process input text synchronously.""" + return f"Sync processed: {input_text}" + + async def _arun(self, input_text: str) -> str: + """Process input text asynchronously.""" + await asyncio.sleep(0.01) + return f"Async processed: {input_text}" + + +class TestBaseTool: + """Tests for BaseTool async functionality.""" + + def test_sync_tool_run_returns_result(self) -> None: + """Test that sync tool run() returns correct result.""" + tool = SyncTool() + result = tool.run(input_text="hello") + assert result == "Sync processed: hello" + + def test_async_tool_run_returns_result(self) -> None: + """Test that async tool run() works.""" + tool = AsyncTool() + result = tool.run(input_text="hello") + assert result == "Sync processed: hello" + + @pytest.mark.asyncio + async def test_sync_tool_arun_raises_not_implemented(self) -> None: + """Test that sync tool arun() raises NotImplementedError.""" + tool = SyncTool() + with pytest.raises(NotImplementedError): + await tool.arun(input_text="hello") + + @pytest.mark.asyncio + async def test_async_tool_arun_returns_result(self) -> None: + """Test that async tool arun() awaits directly.""" + tool = AsyncTool() + result = await tool.arun(input_text="hello") + assert result == "Async processed: hello" + + @pytest.mark.asyncio + async def test_arun_increments_usage_count(self) -> None: + """Test that arun increments the usage count.""" + tool = AsyncTool() + assert tool.current_usage_count == 0 + + await tool.arun(input_text="test") + assert tool.current_usage_count == 1 + + await tool.arun(input_text="test2") + assert tool.current_usage_count == 2 + + @pytest.mark.asyncio + async def test_multiple_async_tools_run_concurrently(self) -> None: + """Test that multiple async tools can run concurrently.""" + tool1 = AsyncTool() + tool2 = AsyncTool() + + results = await asyncio.gather( + tool1.arun(input_text="first"), + tool2.arun(input_text="second"), + ) + + assert results[0] == "Async processed: first" + assert results[1] == "Async processed: second" + + +class TestToolDecorator: + """Tests for @tool decorator with async functions.""" + + def test_sync_decorated_tool_run(self) -> None: + """Test sync decorated tool works with run().""" + + @tool("sync_decorated") + def sync_func(value: str) -> str: + """A sync decorated tool.""" + return f"sync: {value}" + + result = sync_func.run(value="test") + assert result == "sync: test" + + def test_async_decorated_tool_run(self) -> None: + """Test async decorated tool works with run().""" + + @tool("async_decorated") + async def async_func(value: str) -> str: + """An async decorated tool.""" + await asyncio.sleep(0.01) + return f"async: {value}" + + result = async_func.run(value="test") + assert result == "async: test" + + @pytest.mark.asyncio + async def test_sync_decorated_tool_arun_raises(self) -> None: + """Test sync decorated tool arun() raises NotImplementedError.""" + + @tool("sync_decorated_arun") + def sync_func(value: str) -> str: + """A sync decorated tool.""" + return f"sync: {value}" + + with pytest.raises(NotImplementedError): + await sync_func.arun(value="test") + + @pytest.mark.asyncio + async def test_async_decorated_tool_arun(self) -> None: + """Test async decorated tool works with arun().""" + + @tool("async_decorated_arun") + async def async_func(value: str) -> str: + """An async decorated tool.""" + await asyncio.sleep(0.01) + return f"async: {value}" + + result = await async_func.arun(value="test") + assert result == "async: test" + + +class TestAsyncToolWithIO: + """Tests for async tools with simulated I/O operations.""" + + @pytest.mark.asyncio + async def test_async_tool_simulated_io(self) -> None: + """Test async tool with simulated I/O delay.""" + + class SlowAsyncTool(BaseTool): + name: str = "slow_async" + description: str = "Simulates slow I/O" + + def _run(self, delay: float) -> str: + return f"Completed after {delay}s" + + async def _arun(self, delay: float) -> str: + await asyncio.sleep(delay) + return f"Completed after {delay}s" + + tool = SlowAsyncTool() + result = await tool.arun(delay=0.05) + assert result == "Completed after 0.05s" + + @pytest.mark.asyncio + async def test_multiple_slow_tools_concurrent(self) -> None: + """Test that slow async tools benefit from concurrency.""" + + class SlowAsyncTool(BaseTool): + name: str = "slow_async" + description: str = "Simulates slow I/O" + + def _run(self, task_id: int, delay: float) -> str: + return f"Task {task_id} done" + + async def _arun(self, task_id: int, delay: float) -> str: + await asyncio.sleep(delay) + return f"Task {task_id} done" + + tool = SlowAsyncTool() + + import time + + start = time.time() + results = await asyncio.gather( + tool.arun(task_id=1, delay=0.1), + tool.arun(task_id=2, delay=0.1), + tool.arun(task_id=3, delay=0.1), + ) + elapsed = time.time() - start + + assert len(results) == 3 + assert all("done" in r for r in results) + assert elapsed < 0.25, f"Expected concurrent execution, took {elapsed}s" \ No newline at end of file