diff --git a/src/crewai_tools/tools/mysql_seach_tool/mysql_search_tool.py b/src/crewai_tools/tools/mysql_seach_tool/mysql_search_tool.py index 226fb1ddd..372a02f38 100644 --- a/src/crewai_tools/tools/mysql_seach_tool/mysql_search_tool.py +++ b/src/crewai_tools/tools/mysql_seach_tool/mysql_search_tool.py @@ -1,13 +1,13 @@ from typing import Any, Type -from embedchain.loaders.postgres import PostgresLoader +from embedchain.loaders.mysql import MySQLLoader from pydantic.v1 import BaseModel, Field from ..rag.rag_tool import RagTool -class PGSearchToolSchema(BaseModel): - """Input for PGSearchTool.""" +class MySQLSearchToolSchema(BaseModel): + """Input for MySQLSearchTool.""" search_query: str = Field( ..., @@ -15,10 +15,10 @@ class PGSearchToolSchema(BaseModel): ) -class PGSearchTool(RagTool): +class MySQLSearchTool(RagTool): name: str = "Search a database's table content" description: str = "A tool that can be used to semantic search a query from a database table's content." - args_schema: Type[BaseModel] = PGSearchToolSchema + args_schema: Type[BaseModel] = MySQLSearchToolSchema db_uri: str = Field(..., description="Mandatory database URI") def __init__(self, table_name: str, **kwargs): @@ -32,8 +32,8 @@ class PGSearchTool(RagTool): table_name: str, **kwargs: Any, ) -> None: - kwargs["data_type"] = "postgres" - kwargs["loader"] = PostgresLoader(config=dict(url=self.db_uri)) + kwargs["data_type"] = "mysql" + kwargs["loader"] = MySQLLoader(config=dict(url=self.db_uri)) super().add(f"SELECT * FROM {table_name};", **kwargs) def _run(