mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-20 13:28:13 +00:00
revamping crewai tool
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from .tools.base_tool import BaseTool, Tool, as_tool, tool
|
||||
from .tools.base_tool import BaseTool, Tool, tool
|
||||
from .tools import (
|
||||
CodeDocsSearchTool,
|
||||
CSVSearchTool,
|
||||
|
||||
@@ -1,18 +1,24 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, cast, Optional, Type
|
||||
|
||||
from langchain.agents import tools as langchain_tools
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
from pydantic.v1 import BaseModel as V1BaseModel
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
class BaseTool(BaseModel, ABC):
|
||||
name: str
|
||||
"""The unique name of the tool that clearly communicates its purpose."""
|
||||
description: str
|
||||
"""Used to tell the model how/when/why to use the tool."""
|
||||
args_schema: Optional[Type[BaseModel]] = None
|
||||
args_schema: Optional[Type[V1BaseModel]] = None
|
||||
"""The schema for the arguments that the tool accepts."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_args_schema(self):
|
||||
self._set_args_schema()
|
||||
return self
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args: Any,
|
||||
@@ -29,14 +35,28 @@ class BaseTool(BaseModel, ABC):
|
||||
) -> Any:
|
||||
"""Here goes the actual implementation of the tool."""
|
||||
|
||||
def to_langchain(self) -> langchain_tools.Tool:
|
||||
return langchain_tools.Tool(
|
||||
def to_langchain(self) -> StructuredTool:
|
||||
self._set_args_schema()
|
||||
return StructuredTool(
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
args_schema=self.args_schema,
|
||||
func=self._run,
|
||||
)
|
||||
|
||||
def _set_args_schema(self):
|
||||
if self.args_schema is None:
|
||||
class_name = f"{self.__class__.__name__}Schema"
|
||||
self.args_schema = type(
|
||||
class_name,
|
||||
(V1BaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in self._run.__annotations__.items() if k != 'return'
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
func: Callable
|
||||
@@ -47,8 +67,8 @@ class Tool(BaseTool):
|
||||
|
||||
|
||||
def to_langchain(
|
||||
tools: list[BaseTool | langchain_tools.BaseTool],
|
||||
) -> list[langchain_tools.BaseTool]:
|
||||
tools: list[BaseTool | StructuredTool],
|
||||
) -> list[StructuredTool]:
|
||||
return [t.to_langchain() if isinstance(t, BaseTool) else t for t in tools]
|
||||
|
||||
|
||||
@@ -62,10 +82,24 @@ def tool(*args):
|
||||
if f.__doc__ is None:
|
||||
raise ValueError("Function must have a docstring")
|
||||
|
||||
args_schema = None
|
||||
if f.__annotations__:
|
||||
class_name = "".join(tool_name.split()).title()
|
||||
args_schema = type(
|
||||
class_name,
|
||||
(V1BaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in f.__annotations__.items() if k != 'return'
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
return Tool(
|
||||
name=tool_name,
|
||||
description=f.__doc__,
|
||||
func=f,
|
||||
args_schema=args_schema,
|
||||
)
|
||||
|
||||
return _make_tool
|
||||
@@ -74,13 +108,4 @@ def tool(*args):
|
||||
return _make_with_name(args[0].__name__)(args[0])
|
||||
if len(args) == 1 and isinstance(args[0], str):
|
||||
return _make_with_name(args[0])
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
|
||||
def as_tool(f: Any) -> BaseTool:
|
||||
"""
|
||||
Useful for when you create a tool using the @tool decorator and want to use it as a BaseTool.
|
||||
It is a BaseTool, but type inference doesn't know that.
|
||||
"""
|
||||
assert isinstance(f, BaseTool)
|
||||
return cast(BaseTool, f)
|
||||
raise ValueError("Invalid arguments")
|
||||
@@ -29,5 +29,9 @@ class DirectoryReadTool(BaseTool):
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
directory = kwargs.get('directory', self.directory)
|
||||
return [(os.path.join(root, file).replace(directory, "").lstrip(os.path.sep)) for root, dirs, files in os.walk(directory) for file in files]
|
||||
if directory[-1] == "/":
|
||||
directory = directory[:-1]
|
||||
files_list = [f"{directory}/{(os.path.join(root, filename).replace(directory, '').lstrip(os.path.sep))}" for root, dirs, files in os.walk(directory) for filename in files]
|
||||
files = "\n- ".join(files_list)
|
||||
return f"File paths: \n-{files}"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user