mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
feat: Add nl2sql tool to run and execute sql queries in databases
This commit is contained in:
0
src/crewai_tools/tools/nl2sql/README.md
Normal file
0
src/crewai_tools/tools/nl2sql/README.md
Normal file
74
src/crewai_tools/tools/nl2sql/nl2sql_tool.py
Normal file
74
src/crewai_tools/tools/nl2sql/nl2sql_tool.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from crewai_tools import BaseTool
|
||||||
|
from pydantic import Field
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
|
||||||
|
class NL2SQL(BaseTool):
|
||||||
|
name: str = "NL2SQL"
|
||||||
|
description: str = "Converts natural language to SQL queries and executes them."
|
||||||
|
db_uri: str = Field(
|
||||||
|
title="Database URI",
|
||||||
|
description="The URI of the database to connect to.",
|
||||||
|
)
|
||||||
|
tables: list = []
|
||||||
|
columns: dict = {}
|
||||||
|
|
||||||
|
def model_post_init(self, __context: Any) -> None:
|
||||||
|
data = {}
|
||||||
|
tables = self._fetch_available_tables()
|
||||||
|
|
||||||
|
for table in tables:
|
||||||
|
table_columns = self._fetch_all_available_columns(table["table_name"])
|
||||||
|
data[f'{table["table_name"]}_columns'] = table_columns
|
||||||
|
|
||||||
|
self.tables = tables
|
||||||
|
self.columns = data
|
||||||
|
|
||||||
|
def _fetch_available_tables(self):
|
||||||
|
return self.execute_sql(
|
||||||
|
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _fetch_all_available_columns(self, table_name: str):
|
||||||
|
return self.execute_sql(
|
||||||
|
f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run(self, sql_query: str):
|
||||||
|
try:
|
||||||
|
data = self.execute_sql(sql_query)
|
||||||
|
except Exception:
|
||||||
|
data = (
|
||||||
|
f"Based on these tables {self.tables} and columns {self.columns}, "
|
||||||
|
"you can create SQL queries to retrieve data from the database."
|
||||||
|
f"Get the original request {sql_query} and try to create a SQL query that retrieves the requested data."
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def execute_sql(self, sql_query: str) -> Union[list, str]:
|
||||||
|
engine = create_engine(self.db_uri)
|
||||||
|
Session = sessionmaker(bind=engine)
|
||||||
|
session = Session()
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = session.execute(text(sql_query))
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
if result.returns_rows:
|
||||||
|
columns = result.keys()
|
||||||
|
data = [dict(zip(columns, row)) for row in result.fetchall()]
|
||||||
|
return data
|
||||||
|
else:
|
||||||
|
return f"Query {sql_query} executed successfully"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
print(f"SQL execution error: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
Reference in New Issue
Block a user