Custom model config for RAG tools

This commit is contained in:
Gui Vieira
2024-03-19 18:47:13 -03:00
parent 73cae1997d
commit 1c8d010601
20 changed files with 704 additions and 452 deletions

View File

@@ -1,28 +1,47 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Type
from pydantic import BaseModel, model_validator
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, ConfigDict, Field, validator
from pydantic.v1 import BaseModel as V1BaseModel
from langchain_core.tools import StructuredTool
class BaseTool(BaseModel, ABC):
class _ArgsSchemaPlaceholder(V1BaseModel):
pass
model_config = ConfigDict()
name: str
"""The unique name of the tool that clearly communicates its purpose."""
description: str
"""Used to tell the model how/when/why to use the tool."""
args_schema: Optional[Type[V1BaseModel]] = None
args_schema: Type[V1BaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
"""The schema for the arguments that the tool accepts."""
description_updated: bool = False
"""Flag to check if the description has been updated."""
cache_function: Optional[Callable] = lambda: True
"""Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
@model_validator(mode="after")
def _check_args_schema(self):
self._set_args_schema()
@validator("args_schema", always=True, pre=True)
def _default_args_schema(cls, v: Type[V1BaseModel]) -> Type[V1BaseModel]:
if not isinstance(v, cls._ArgsSchemaPlaceholder):
return v
return type(
f"{cls.__name__}Schema",
(V1BaseModel,),
{
"__annotations__": {
k: v for k, v in cls._run.__annotations__.items() if k != "return"
},
},
)
def model_post_init(self, __context: Any) -> None:
self._generate_description()
return self
super().model_post_init(__context)
def run(
self,
@@ -57,16 +76,20 @@ class BaseTool(BaseModel, ABC):
(V1BaseModel,),
{
"__annotations__": {
k: v for k, v in self._run.__annotations__.items() if k != 'return'
k: v
for k, v in self._run.__annotations__.items()
if k != "return"
},
},
)
def _generate_description(self):
args = []
for arg, attribute in self.args_schema.schema()['properties'].items():
args.append(f"{arg}: '{attribute['type']}'")
for arg, attribute in self.args_schema.schema()["properties"].items():
if "type" in attribute:
args.append(f"{arg}: '{attribute['type']}'")
description = self.description.replace('\n', ' ')
description = self.description.replace("\n", " ")
self.description = f"{self.name}({', '.join(args)}) - {description}"
@@ -93,19 +116,19 @@ def tool(*args):
def _make_tool(f: Callable) -> BaseTool:
if f.__doc__ is None:
raise ValueError("Function must have a docstring")
if f.__annotations__ is None:
raise ValueError("Function must have type annotations")
args_schema = None
if f.__annotations__:
class_name = "".join(tool_name.split()).title()
args_schema = type(
class_name,
(V1BaseModel,),
{
"__annotations__": {
k: v for k, v in f.__annotations__.items() if k != 'return'
},
class_name = "".join(tool_name.split()).title()
args_schema = type(
class_name,
(V1BaseModel,),
{
"__annotations__": {
k: v for k, v in f.__annotations__.items() if k != "return"
},
)
},
)
return Tool(
name=tool_name,
@@ -120,4 +143,4 @@ def tool(*args):
return _make_with_name(args[0].__name__)(args[0])
if len(args) == 1 and isinstance(args[0], str):
return _make_with_name(args[0])
raise ValueError("Invalid arguments")
raise ValueError("Invalid arguments")