mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 09:38:17 +00:00
Custom model config for RAG tools
This commit is contained in:
@@ -1,28 +1,47 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, ConfigDict, Field, validator
|
||||
from pydantic.v1 import BaseModel as V1BaseModel
|
||||
|
||||
from langchain_core.tools import StructuredTool
|
||||
|
||||
class BaseTool(BaseModel, ABC):
|
||||
class _ArgsSchemaPlaceholder(V1BaseModel):
|
||||
pass
|
||||
|
||||
model_config = ConfigDict()
|
||||
|
||||
name: str
|
||||
"""The unique name of the tool that clearly communicates its purpose."""
|
||||
description: str
|
||||
"""Used to tell the model how/when/why to use the tool."""
|
||||
args_schema: Optional[Type[V1BaseModel]] = None
|
||||
args_schema: Type[V1BaseModel] = Field(default_factory=_ArgsSchemaPlaceholder)
|
||||
"""The schema for the arguments that the tool accepts."""
|
||||
description_updated: bool = False
|
||||
"""Flag to check if the description has been updated."""
|
||||
cache_function: Optional[Callable] = lambda: True
|
||||
"""Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_args_schema(self):
|
||||
self._set_args_schema()
|
||||
@validator("args_schema", always=True, pre=True)
|
||||
def _default_args_schema(cls, v: Type[V1BaseModel]) -> Type[V1BaseModel]:
|
||||
if not isinstance(v, cls._ArgsSchemaPlaceholder):
|
||||
return v
|
||||
|
||||
return type(
|
||||
f"{cls.__name__}Schema",
|
||||
(V1BaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in cls._run.__annotations__.items() if k != "return"
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self._generate_description()
|
||||
return self
|
||||
|
||||
super().model_post_init(__context)
|
||||
|
||||
def run(
|
||||
self,
|
||||
@@ -57,16 +76,20 @@ class BaseTool(BaseModel, ABC):
|
||||
(V1BaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in self._run.__annotations__.items() if k != 'return'
|
||||
k: v
|
||||
for k, v in self._run.__annotations__.items()
|
||||
if k != "return"
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def _generate_description(self):
|
||||
args = []
|
||||
for arg, attribute in self.args_schema.schema()['properties'].items():
|
||||
args.append(f"{arg}: '{attribute['type']}'")
|
||||
for arg, attribute in self.args_schema.schema()["properties"].items():
|
||||
if "type" in attribute:
|
||||
args.append(f"{arg}: '{attribute['type']}'")
|
||||
|
||||
description = self.description.replace('\n', ' ')
|
||||
description = self.description.replace("\n", " ")
|
||||
self.description = f"{self.name}({', '.join(args)}) - {description}"
|
||||
|
||||
|
||||
@@ -93,19 +116,19 @@ def tool(*args):
|
||||
def _make_tool(f: Callable) -> BaseTool:
|
||||
if f.__doc__ is None:
|
||||
raise ValueError("Function must have a docstring")
|
||||
if f.__annotations__ is None:
|
||||
raise ValueError("Function must have type annotations")
|
||||
|
||||
args_schema = None
|
||||
if f.__annotations__:
|
||||
class_name = "".join(tool_name.split()).title()
|
||||
args_schema = type(
|
||||
class_name,
|
||||
(V1BaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in f.__annotations__.items() if k != 'return'
|
||||
},
|
||||
class_name = "".join(tool_name.split()).title()
|
||||
args_schema = type(
|
||||
class_name,
|
||||
(V1BaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in f.__annotations__.items() if k != "return"
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
return Tool(
|
||||
name=tool_name,
|
||||
@@ -120,4 +143,4 @@ def tool(*args):
|
||||
return _make_with_name(args[0].__name__)(args[0])
|
||||
if len(args) == 1 and isinstance(args[0], str):
|
||||
return _make_with_name(args[0])
|
||||
raise ValueError("Invalid arguments")
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
Reference in New Issue
Block a user