mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-24 23:58:15 +00:00
Move off v1
This commit is contained in:
@@ -3,11 +3,11 @@ from typing import Any, Callable, Optional, Type
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, ConfigDict, Field, validator
|
||||
from pydantic.v1 import BaseModel as V1BaseModel
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
|
||||
|
||||
class BaseTool(BaseModel, ABC):
|
||||
class _ArgsSchemaPlaceholder(V1BaseModel):
|
||||
class _ArgsSchemaPlaceholder(PydanticBaseModel):
|
||||
pass
|
||||
|
||||
model_config = ConfigDict()
|
||||
@@ -16,7 +16,7 @@ class BaseTool(BaseModel, ABC):
|
||||
"""The unique name of the tool that clearly communicates its purpose."""
|
||||
description: str
|
||||
"""Used to tell the model how/when/why to use the tool."""
|
||||
args_schema: Type[V1BaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
|
||||
args_schema: Type[PydanticBaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
|
||||
"""The schema for the arguments that the tool accepts."""
|
||||
description_updated: bool = False
|
||||
"""Flag to check if the description has been updated."""
|
||||
@@ -26,13 +26,15 @@ class BaseTool(BaseModel, ABC):
|
||||
"""Flag to check if the tool should be the final agent answer."""
|
||||
|
||||
@validator("args_schema", always=True, pre=True)
|
||||
def _default_args_schema(cls, v: Type[V1BaseModel]) -> Type[V1BaseModel]:
|
||||
def _default_args_schema(
|
||||
cls, v: Type[PydanticBaseModel]
|
||||
) -> Type[PydanticBaseModel]:
|
||||
if not isinstance(v, cls._ArgsSchemaPlaceholder):
|
||||
return v
|
||||
|
||||
return type(
|
||||
f"{cls.__name__}Schema",
|
||||
(V1BaseModel,),
|
||||
(PydanticBaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in cls._run.__annotations__.items() if k != "return"
|
||||
@@ -75,7 +77,7 @@ class BaseTool(BaseModel, ABC):
|
||||
class_name = f"{self.__class__.__name__}Schema"
|
||||
self.args_schema = type(
|
||||
class_name,
|
||||
(V1BaseModel,),
|
||||
(PydanticBaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v
|
||||
@@ -127,7 +129,7 @@ def tool(*args):
|
||||
class_name = "".join(tool_name.split()).title()
|
||||
args_schema = type(
|
||||
class_name,
|
||||
(V1BaseModel,),
|
||||
(PydanticBaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in f.__annotations__.items() if k != "return"
|
||||
|
||||
Reference in New Issue
Block a user