Custom model config for RAG tools

This commit is contained in:
Gui Vieira
2024-03-19 18:47:13 -03:00
parent 73cae1997d
commit 1c8d010601
20 changed files with 704 additions and 452 deletions

View File

@@ -1,45 +1,37 @@
from typing import Optional, Type, Any
from pydantic.v1 import BaseModel, Field
from typing import Any, Type
from embedchain import App
from embedchain.loaders.postgres import PostgresLoader
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class PGSearchToolSchema(BaseModel):
"""Input for PGSearchTool."""
search_query: str = Field(..., description="Mandatory semantic search query you want to use to search the database's content")
"""Input for PGSearchTool."""
search_query: str = Field(
...,
description="Mandatory semantic search query you want to use to search the database's content",
)
class PGSearchTool(RagTool):
name: str = "Search a database's table content"
description: str = "A tool that can be used to semantic search a query from a database table's content."
summarize: bool = False
args_schema: Type[BaseModel] = PGSearchToolSchema
db_uri: str = Field(..., description="Mandatory database URI")
table_name: str = Field(..., description="Mandatory table name")
search_query: str = Field(..., description="Mandatory semantic search query you want to use to search the database's content")
name: str = "Search a database's table content"
description: str = "A tool that can be used to semantic search a query from a database table's content."
args_schema: Type[BaseModel] = PGSearchToolSchema
db_uri: str = Field(..., description="Mandatory database URI")
def __init__(self, table_name: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if table_name is not None:
self.table_name = table_name
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
self._generate_description()
else:
raise('To use PGSearchTool, you must provide a `table_name` argument')
def __init__(self, table_name: str, **kwargs):
super().__init__(**kwargs)
self.add(table_name)
self.description = f"A tool that can be used to semantic search a query the {table_name} database table's content."
self._generate_description()
def _run(
self,
search_query: str,
**kwargs: Any,
) -> Any:
config = { "url": self.db_uri }
postgres_loader = PostgresLoader(config=config)
app = App()
app.add(
f"SELECT * FROM {self.table_name};",
data_type='postgres',
loader=postgres_loader
)
return super()._run(query=search_query)
def add(
self,
table_name: str,
**kwargs: Any,
) -> None:
kwargs["data_type"] = "postgres"
kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri))
super().add(f"SELECT * FROM {table_name};", **kwargs)