mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 15:22:37 +00:00
fix: resolve all strict mypy errors across crewai-tools package
This commit is contained in:
@@ -27,8 +27,8 @@ class NL2SQLTool(BaseTool):
|
||||
title="Database URI",
|
||||
description="The URI of the database to connect to.",
|
||||
)
|
||||
tables: list = Field(default_factory=list)
|
||||
columns: dict = Field(default_factory=dict)
|
||||
tables: list[dict[str, Any]] = Field(default_factory=list)
|
||||
columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict)
|
||||
args_schema: type[BaseModel] = NL2SQLToolInput
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
@@ -37,8 +37,11 @@ class NL2SQLTool(BaseTool):
|
||||
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"
|
||||
)
|
||||
|
||||
data = {}
|
||||
tables = self._fetch_available_tables()
|
||||
data: dict[str, list[dict[str, Any]] | str] = {}
|
||||
result = self._fetch_available_tables()
|
||||
if isinstance(result, str):
|
||||
raise RuntimeError(f"Failed to fetch tables: {result}")
|
||||
tables: list[dict[str, Any]] = result
|
||||
|
||||
for table in tables:
|
||||
table_columns = self._fetch_all_available_columns(table["table_name"])
|
||||
@@ -47,17 +50,19 @@ class NL2SQLTool(BaseTool):
|
||||
self.tables = tables
|
||||
self.columns = data
|
||||
|
||||
def _fetch_available_tables(self):
|
||||
def _fetch_available_tables(self) -> list[dict[str, Any]] | str:
|
||||
return self.execute_sql(
|
||||
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
|
||||
)
|
||||
|
||||
def _fetch_all_available_columns(self, table_name: str):
|
||||
def _fetch_all_available_columns(
|
||||
self, table_name: str
|
||||
) -> list[dict[str, Any]] | str:
|
||||
return self.execute_sql(
|
||||
f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';" # noqa: S608
|
||||
)
|
||||
|
||||
def _run(self, sql_query: str):
|
||||
def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
|
||||
try:
|
||||
data = self.execute_sql(sql_query)
|
||||
except Exception as exc:
|
||||
@@ -69,7 +74,7 @@ class NL2SQLTool(BaseTool):
|
||||
|
||||
return data
|
||||
|
||||
def execute_sql(self, sql_query: str) -> list | str:
|
||||
def execute_sql(self, sql_query: str) -> list[dict[str, Any]] | str:
|
||||
if not SQLALCHEMY_AVAILABLE:
|
||||
raise ImportError(
|
||||
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"
|
||||
|
||||
Reference in New Issue
Block a user