diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py index 38ddaa7ab..aedda4d54 100644 --- a/lib/crewai/src/crewai/tools/base_tool.py +++ b/lib/crewai/src/crewai/tools/base_tool.py @@ -2,9 +2,18 @@ from __future__ import annotations from abc import ABC, abstractmethod import asyncio -from collections.abc import Callable +from collections.abc import Awaitable, Callable from inspect import signature -from typing import Any, cast, get_args, get_origin +from typing import ( + Any, + Generic, + ParamSpec, + TypeVar, + cast, + get_args, + get_origin, + overload, +) from pydantic import ( BaseModel, @@ -14,6 +23,7 @@ from pydantic import ( create_model, field_validator, ) +from typing_extensions import TypeIs from crewai.tools.structured_tool import CrewStructuredTool from crewai.utilities.printer import Printer @@ -21,12 +31,20 @@ from crewai.utilities.printer import Printer _printer = Printer() +P = ParamSpec("P") +R = TypeVar("R", covariant=True) + def _is_async_callable(func: Callable[..., Any]) -> bool: """Check if a callable is async.""" return asyncio.iscoroutinefunction(func) +def _is_awaitable(value: R | Awaitable[R]) -> TypeIs[Awaitable[R]]: + """Type narrowing check for awaitable values.""" + return asyncio.iscoroutine(value) or asyncio.isfuture(value) + + class EnvVar(BaseModel): name: str description: str @@ -288,29 +306,55 @@ class BaseTool(BaseModel, ABC): return str(origin.__name__) -class Tool(BaseTool): +class Tool(BaseTool, Generic[P, R]): """Tool that wraps a callable function. - The function can be either synchronous or asynchronous. + + Type Parameters: + P: ParamSpec capturing the function's parameters. + R: The return type of the function. """ - func: Callable[..., Any] + func: Callable[P, R | Awaitable[R]] - def _run(self, *args: Any, **kwargs: Any) -> Any: + def run(self, *args: P.args, **kwargs: P.kwargs) -> R: + """Execute the tool synchronously with type-safe parameters.""" + _printer.print(f"Using Tool: {self.name}", color="cyan") + result = self.func(*args, **kwargs) + + if asyncio.iscoroutine(result): + result = asyncio.run(result) + + self.current_usage_count += 1 + return result # type: ignore[return-value] + + def _run(self, *args: P.args, **kwargs: P.kwargs) -> R: """Execute the wrapped function.""" - return self.func(*args, **kwargs) + result = self.func(*args, **kwargs) + if _is_awaitable(result): + raise NotImplementedError( + f"{self.name} is an async function. Use arun() for async execution." + ) + return result - async def _arun(self, *args: Any, **kwargs: Any) -> Any: + async def arun(self, *args: P.args, **kwargs: P.kwargs) -> R: + """Execute the tool asynchronously with type-safe parameters.""" + result = await self._arun(*args, **kwargs) + self.current_usage_count += 1 + return result + + async def _arun(self, *args: P.args, **kwargs: P.kwargs) -> R: """Execute the wrapped function asynchronously.""" - if _is_async_callable(self.func): - return await self.func(*args, **kwargs) + result = self.func(*args, **kwargs) + if _is_awaitable(result): + return await result raise NotImplementedError( f"{self.name} does not have an async function. " "Use run() for sync execution or provide an async function." ) @classmethod - def from_langchain(cls, tool: Any) -> Tool: + def from_langchain(cls, tool: Any) -> Tool[..., Any]: """Create a Tool instance from a CrewStructuredTool. This method takes a CrewStructuredTool object and converts it into a @@ -318,10 +362,10 @@ class Tool(BaseTool): attribute and infers the argument schema if not explicitly provided. Args: - tool (Any): The CrewStructuredTool object to be converted. + tool: The CrewStructuredTool object to be converted. Returns: - Tool: A new Tool instance created from the provided CrewStructuredTool. + A new Tool instance created from the provided CrewStructuredTool. Raises: ValueError: If the provided tool does not have a callable 'func' attribute. @@ -365,41 +409,83 @@ class Tool(BaseTool): def to_langchain( tools: list[BaseTool | CrewStructuredTool], ) -> list[CrewStructuredTool]: + """Convert a list of tools to CrewStructuredTool instances.""" return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools] +P2 = ParamSpec("P2") +R2 = TypeVar("R2") + + +@overload +def tool(func: Callable[P2, R2], /) -> Tool[P2, R2]: ... + + +@overload def tool( - *args: Callable[..., Any] | str, + name: str, + /, + *, + result_as_answer: bool = ..., + max_usage_count: int | None = ..., +) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ... + + +@overload +def tool( + *, + result_as_answer: bool = ..., + max_usage_count: int | None = ..., +) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ... + + +def tool( + *args: Callable[P2, R2] | str, result_as_answer: bool = False, max_usage_count: int | None = None, -) -> Callable[[Callable[..., Any]], Tool] | Tool: - """Decorator to create a tool from a function. +) -> Tool[P2, R2] | Callable[[Callable[P2, R2]], Tool[P2, R2]]: + """Decorator to create a Tool from a function. + + Can be used in three ways: + 1. @tool - decorator without arguments, uses function name + 2. @tool("name") - decorator with custom name + 3. @tool(result_as_answer=True) - decorator with options Args: - *args: Positional arguments, either the function to decorate or the tool name. - result_as_answer: Flag to indicate if the tool result should be used as the final agent answer. - max_usage_count: Maximum number of times this tool can be used. None means unlimited usage. + *args: Either the function to decorate or a custom tool name. + result_as_answer: If True, the tool result becomes the final agent answer. + max_usage_count: Maximum times this tool can be used. None means unlimited. Returns: - A Tool instance or a decorator that creates a Tool instance. + A Tool instance. + + Example: + @tool + def greet(name: str) -> str: + '''Greet someone.''' + return f"Hello, {name}!" + + result = greet.run("World") """ - def _make_with_name(tool_name: str) -> Callable[[Callable[..., Any]], Tool]: - def _make_tool(f: Callable[..., Any]) -> Tool: + def _make_with_name(tool_name: str) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: + def _make_tool(f: Callable[P2, R2]) -> Tool[P2, R2]: if f.__doc__ is None: raise ValueError("Function must have a docstring") - if f.__annotations__ is None: + + func_annotations = getattr(f, "__annotations__", None) + if func_annotations is None: raise ValueError("Function must have type annotations") class_name = "".join(tool_name.split()).title() - args_schema = cast( + tool_args_schema = cast( type[PydanticBaseModel], type( class_name, (PydanticBaseModel,), { "__annotations__": { - k: v for k, v in f.__annotations__.items() if k != "return" + k: v for k, v in func_annotations.items() if k != "return" }, }, ), @@ -409,10 +495,9 @@ def tool( name=tool_name, description=f.__doc__, func=f, - args_schema=args_schema, + args_schema=tool_args_schema, result_as_answer=result_as_answer, max_usage_count=max_usage_count, - current_usage_count=0, ) return _make_tool @@ -421,4 +506,10 @@ def tool( return _make_with_name(args[0].__name__)(args[0]) if len(args) == 1 and isinstance(args[0], str): return _make_with_name(args[0]) + if len(args) == 0: + + def decorator(f: Callable[P2, R2]) -> Tool[P2, R2]: + return _make_with_name(f.__name__)(f) + + return decorator raise ValueError("Invalid arguments")