adding MySQLSearcherTool

This commit is contained in:
Carlos Antunes
2024-05-18 16:58:40 -03:00
parent e36af697cd
commit a11cc57345

View File

@@ -1,13 +1,13 @@
from typing import Any, Type
from embedchain.loaders.postgres import PostgresLoader
from embedchain.loaders.mysql import MySQLLoader
from pydantic.v1 import BaseModel, Field
from ..rag.rag_tool import RagTool
class PGSearchToolSchema(BaseModel):
"""Input for PGSearchTool."""
class MySQLSearchToolSchema(BaseModel):
"""Input for MySQLSearchTool."""
search_query: str = Field(
...,
@@ -15,10 +15,10 @@ class PGSearchToolSchema(BaseModel):
)
class PGSearchTool(RagTool):
class MySQLSearchTool(RagTool):
name: str = "Search a database's table content"
description: str = "A tool that can be used to semantic search a query from a database table's content."
args_schema: Type[BaseModel] = PGSearchToolSchema
args_schema: Type[BaseModel] = MySQLSearchToolSchema
db_uri: str = Field(..., description="Mandatory database URI")
def __init__(self, table_name: str, **kwargs):
@@ -32,8 +32,8 @@ class PGSearchTool(RagTool):
table_name: str,
**kwargs: Any,
) -> None:
kwargs["data_type"] = "postgres"
kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri))
kwargs["data_type"] = "mysql"
kwargs["loader"] = MySQLLoader(config=dict(url=self.db_uri))
super().add(f"SELECT * FROM {table_name};", **kwargs)
def _run(