This commit is contained in:
Brandon Hancock
2025-03-13 15:45:11 -04:00
parent cb86594f92
commit 358befe2c1
2 changed files with 18 additions and 6 deletions

View File

@@ -248,13 +248,18 @@ def to_langchain(
def tool(*args):
"""
Decorator to create a tool from a function.
Ensures the decorated function is always wrapped as a BaseTool.
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_with_name(tool_name: str) -> Callable[[Callable], BaseTool]:
def _make_tool(f: Callable) -> BaseTool:
# If f is already a BaseTool, return it
if isinstance(f, BaseTool):
return f
if f.__doc__ is None:
raise ValueError("Function must have a docstring")
if f.__annotations__ is None:
if not f.__annotations__:
raise ValueError("Function must have type annotations")
class_name = "".join(tool_name.split()).title()
@@ -278,7 +283,12 @@ def tool(*args):
return _make_tool
if len(args) == 1 and callable(args[0]):
# Direct function decoration
if isinstance(args[0], BaseTool):
return args[0] # Already a BaseTool, return as-is
return _make_with_name(args[0].__name__)(args[0])
if len(args) == 1 and isinstance(args[0], str):
elif len(args) == 1 and isinstance(args[0], str):
# Name provided, return a decorator
return _make_with_name(args[0])
raise ValueError("Invalid arguments")
else:
raise ValueError("Invalid arguments")