revamping crewai tool

This commit is contained in:
João Moura
2024-02-25 21:11:09 -03:00
parent 7c99e9ab50
commit 50bae27948
21 changed files with 100 additions and 144 deletions

View File

@@ -1,4 +1,4 @@
from .tools.base_tool import BaseTool, Tool, as_tool, tool
from .tools.base_tool import BaseTool, Tool, tool
from .tools import (
CodeDocsSearchTool,
CSVSearchTool,

View File

@@ -1,18 +1,24 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, cast, Optional, Type
from langchain.agents import tools as langchain_tools
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from pydantic.v1 import BaseModel as V1BaseModel
from langchain_core.tools import StructuredTool
class BaseTool(BaseModel, ABC):
name: str
"""The unique name of the tool that clearly communicates its purpose."""
description: str
"""Used to tell the model how/when/why to use the tool."""
args_schema: Optional[Type[BaseModel]] = None
args_schema: Optional[Type[V1BaseModel]] = None
"""The schema for the arguments that the tool accepts."""
@model_validator(mode="after")
def _check_args_schema(self):
self._set_args_schema()
return self
def run(
self,
*args: Any,
@@ -29,14 +35,28 @@ class BaseTool(BaseModel, ABC):
) -> Any:
"""Here goes the actual implementation of the tool."""
def to_langchain(self) -> langchain_tools.Tool:
return langchain_tools.Tool(
def to_langchain(self) -> StructuredTool:
self._set_args_schema()
return StructuredTool(
name=self.name,
description=self.description,
args_schema=self.args_schema,
func=self._run,
)
def _set_args_schema(self):
if self.args_schema is None:
class_name = f"{self.__class__.__name__}Schema"
self.args_schema = type(
class_name,
(V1BaseModel,),
{
"__annotations__": {
k: v for k, v in self._run.__annotations__.items() if k != 'return'
},
},
)
class Tool(BaseTool):
func: Callable
@@ -47,8 +67,8 @@ class Tool(BaseTool):
def to_langchain(
tools: list[BaseTool | langchain_tools.BaseTool],
) -> list[langchain_tools.BaseTool]:
tools: list[BaseTool | StructuredTool],
) -> list[StructuredTool]:
return [t.to_langchain() if isinstance(t, BaseTool) else t for t in tools]
@@ -62,10 +82,24 @@ def tool(*args):
if f.__doc__ is None:
raise ValueError("Function must have a docstring")
args_schema = None
if f.__annotations__:
class_name = "".join(tool_name.split()).title()
args_schema = type(
class_name,
(V1BaseModel,),
{
"__annotations__": {
k: v for k, v in f.__annotations__.items() if k != 'return'
},
},
)
return Tool(
name=tool_name,
description=f.__doc__,
func=f,
args_schema=args_schema,
)
return _make_tool
@@ -74,13 +108,4 @@ def tool(*args):
return _make_with_name(args[0].__name__)(args[0])
if len(args) == 1 and isinstance(args[0], str):
return _make_with_name(args[0])
raise ValueError("Invalid arguments")
def as_tool(f: Any) -> BaseTool:
"""
Useful for when you create a tool using the @tool decorator and want to use it as a BaseTool.
It is a BaseTool, but type inference doesn't know that.
"""
assert isinstance(f, BaseTool)
return cast(BaseTool, f)
raise ValueError("Invalid arguments")

View File

@@ -29,5 +29,9 @@ class DirectoryReadTool(BaseTool):
**kwargs: Any,
) -> Any:
directory = kwargs.get('directory', self.directory)
return [(os.path.join(root, file).replace(directory, "").lstrip(os.path.sep)) for root, dirs, files in os.walk(directory) for file in files]
if directory[-1] == "/":
directory = directory[:-1]
files_list = [f"{directory}/{(os.path.join(root, filename).replace(directory, '').lstrip(os.path.sep))}" for root, dirs, files in os.walk(directory) for filename in files]
files = "\n- ".join(files_list)
return f"File paths: \n-{files}"