Move off v1

This commit is contained in:
Brandon Hancock
2024-09-03 15:57:29 -04:00
parent d19bba72b0
commit 35fe222ca1
39 changed files with 752 additions and 550 deletions

View File

@@ -3,11 +3,11 @@ from typing import Any, Callable, Optional, Type
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, ConfigDict, Field, validator
from pydantic.v1 import BaseModel as V1BaseModel
from pydantic import BaseModel as PydanticBaseModel
class BaseTool(BaseModel, ABC):
class _ArgsSchemaPlaceholder(V1BaseModel):
class _ArgsSchemaPlaceholder(PydanticBaseModel):
pass
model_config = ConfigDict()
@@ -16,7 +16,7 @@ class BaseTool(BaseModel, ABC):
"""The unique name of the tool that clearly communicates its purpose."""
description: str
"""Used to tell the model how/when/why to use the tool."""
args_schema: Type[V1BaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
args_schema: Type[PydanticBaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
"""The schema for the arguments that the tool accepts."""
description_updated: bool = False
"""Flag to check if the description has been updated."""
@@ -26,13 +26,15 @@ class BaseTool(BaseModel, ABC):
"""Flag to check if the tool should be the final agent answer."""
@validator("args_schema", always=True, pre=True)
def _default_args_schema(cls, v: Type[V1BaseModel]) -> Type[V1BaseModel]:
def _default_args_schema(
cls, v: Type[PydanticBaseModel]
) -> Type[PydanticBaseModel]:
if not isinstance(v, cls._ArgsSchemaPlaceholder):
return v
return type(
f"{cls.__name__}Schema",
(V1BaseModel,),
(PydanticBaseModel,),
{
"__annotations__": {
k: v for k, v in cls._run.__annotations__.items() if k != "return"
@@ -75,7 +77,7 @@ class BaseTool(BaseModel, ABC):
class_name = f"{self.__class__.__name__}Schema"
self.args_schema = type(
class_name,
(V1BaseModel,),
(PydanticBaseModel,),
{
"__annotations__": {
k: v
@@ -127,7 +129,7 @@ def tool(*args):
class_name = "".join(tool_name.split()).title()
args_schema = type(
class_name,
(V1BaseModel,),
(PydanticBaseModel,),
{
"__annotations__": {
k: v for k, v in f.__annotations__.items() if k != "return"