mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-12 22:12:37 +00:00
103 lines
3.4 KiB
Python
103 lines
3.4 KiB
Python
from typing import Any
|
|
|
|
from crewai.tools import BaseTool
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
try:
|
|
from sqlalchemy import create_engine, text
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
SQLALCHEMY_AVAILABLE = True
|
|
except ImportError:
|
|
SQLALCHEMY_AVAILABLE = False
|
|
|
|
|
|
class NL2SQLToolInput(BaseModel):
    # Argument schema for NL2SQLTool: the agent supplies one SQL statement.
    # No default is given, so the field is required.
    sql_query: str = Field(
        description="The SQL query to execute.",
        title="SQL Query",
    )
|
|
|
|
|
|
class NL2SQLTool(BaseTool):
    """Tool that runs SQL against the database at ``db_uri``.

    On construction it queries ``information_schema`` to snapshot the tables in
    the ``public`` schema and their columns. That snapshot is returned to the
    agent, together with the error, whenever a query fails, so it can retry.
    """

    name: str = "NL2SQLTool"
    description: str = "Converts natural language to SQL queries and executes them."
    db_uri: str = Field(
        title="Database URI",
        description="The URI of the database to connect to.",
    )
    # Rows from information_schema.tables (each with a "table_name" key);
    # populated in model_post_init.
    tables: list[dict[str, Any]] = Field(default_factory=list)
    # Maps "<table_name>_columns" to that table's column_name/data_type rows;
    # populated in model_post_init.
    columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict)
    args_schema: type[BaseModel] = NL2SQLToolInput

    def model_post_init(self, __context: Any) -> None:
        """Check the optional dependency and snapshot the database schema.

        Raises:
            ImportError: If sqlalchemy is not installed.
            RuntimeError: If the table-listing query does not return rows.
        """
        if not SQLALCHEMY_AVAILABLE:
            raise ImportError(
                "sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"
            )

        # Keep cooperative inheritance intact so parent post-init hooks still run.
        super().model_post_init(__context)

        result = self._fetch_available_tables()
        if isinstance(result, str):
            raise RuntimeError(f"Failed to fetch tables: {result}")
        tables: list[dict[str, Any]] = result

        data: dict[str, list[dict[str, Any]] | str] = {
            f"{table['table_name']}_columns": self._fetch_all_available_columns(
                table["table_name"]
            )
            for table in tables
        }

        self.tables = tables
        self.columns = data

    def _fetch_available_tables(self) -> list[dict[str, Any]] | str:
        """Return one ``{"table_name": ...}`` row per table in the ``public`` schema."""
        return self.execute_sql(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
        )

    def _fetch_all_available_columns(
        self, table_name: str
    ) -> list[dict[str, Any]] | str:
        """Return ``column_name``/``data_type`` rows for ``public.<table_name>``.

        The table name is sent as a bound parameter instead of being spliced
        into the SQL text, so names containing quotes cannot break or inject
        into the query. The schema filter matches ``_fetch_available_tables``,
        so same-named tables in other schemas are not mixed in.
        """
        return self.execute_sql(
            "SELECT column_name, data_type FROM information_schema.columns "
            "WHERE table_schema = 'public' AND table_name = :table_name;",
            params={"table_name": table_name},
        )

    def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
        """Execute ``sql_query``; on failure, return guidance text for the agent.

        Errors are deliberately not raised. The returned message contains the
        schema snapshot, the failed query and the error so the agent can retry.
        """
        try:
            return self.execute_sql(sql_query)
        except Exception as exc:
            return (
                f"Based on these tables {self.tables} and columns {self.columns}, "
                "you can create SQL queries to retrieve data from the database. "
                f"Get the original request {sql_query} and the error {exc} and create the correct SQL query."
            )

    def execute_sql(
        self, sql_query: str, params: dict[str, Any] | None = None
    ) -> list[dict[str, Any]] | str:
        """Run ``sql_query`` in its own session against ``db_uri`` and commit.

        Args:
            sql_query: SQL text; it is wrapped in ``sqlalchemy.text``.
            params: Optional values for ``:name`` bound parameters in ``sql_query``.

        Returns:
            A list of ``{column: value}`` dicts when the statement returns rows,
            otherwise a success message.

        Raises:
            ImportError: If sqlalchemy is not installed.
            Exception: Any database error, re-raised after rolling back.
        """
        if not SQLALCHEMY_AVAILABLE:
            raise ImportError(
                "sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"
            )

        engine = create_engine(self.db_uri)
        Session = sessionmaker(bind=engine)  # noqa: N806
        session = Session()
        rows: list[dict[str, Any]] | None = None
        try:
            result = session.execute(text(sql_query), params)
            if result.returns_rows:  # type: ignore[attr-defined]
                # Read all rows while the transaction is still open, before committing.
                keys = list(result.keys())
                rows = [
                    dict(zip(keys, row, strict=False)) for row in result.fetchall()
                ]
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()
            # A new engine (with its own connection pool) is created per call;
            # dispose it so pooled connections are not left open.
            engine.dispose()

        if rows is not None:
            return rows
        return f"Query {sql_query} executed successfully"
|