mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-11 05:22:41 +00:00
Some checks failed
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
* fix: harden NL2SQLTool — read-only by default, parameterized queries, query validation Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: address CI lint failures and remove unused import - Remove unused `sessionmaker` import from test_nl2sql_security.py - Use `Self` return type on `_apply_env_override` (fixes UP037/F821) - Fix ruff errors auto-fixed in lib/crewai (UP007, etc.) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: expand _WRITE_COMMANDS and block multi-statement semicolon injection - Add missing write commands: UPSERT, LOAD, COPY, VACUUM, ANALYZE, ANALYSE, REINDEX, CLUSTER, REFRESH, COMMENT, SET, RESET - _validate_query() now splits on ';' and validates each statement independently; multi-statement queries are rejected outright in read-only mode to prevent 'SELECT 1; DROP TABLE users' bypass - Extract single-statement logic into _validate_statement() helper - Add TestSemicolonInjection and TestExtendedWriteCommands test classes Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * ci: retrigger * fix: use typing_extensions.Self for Python 3.10 compat * chore: update tool specifications * docs: document NL2SQLTool read-only default and DML configuration Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: close three NL2SQLTool security gaps (writable CTEs, EXPLAIN ANALYZE, multi-stmt commit) - Remove WITH from _READ_ONLY_COMMANDS; scan CTE body for write keywords so writable CTEs like `WITH d AS (DELETE …) SELECT …` are blocked in read-only mode. - EXPLAIN ANALYZE/ANALYSE now resolves the underlying command; EXPLAIN ANALYZE DELETE is treated as a write and blocked in read-only mode. - execute_sql commit decision now checks ALL semicolon-separated statements so a SELECT-first batch like `SELECT 1; DROP TABLE t` still triggers a commit when allow_dml=True. 
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: handle parenthesized EXPLAIN options syntax; remove unused _seed_db _validate_statement now strips parenthesized options from EXPLAIN (e.g. EXPLAIN (ANALYZE) DELETE, EXPLAIN (ANALYZE, VERBOSE) DELETE) before checking whether ANALYZE/ANALYSE is present — closing the bypass where the options-list form was silently allowed in read-only mode. Adds three new tests: - EXPLAIN (ANALYZE) DELETE → blocked - EXPLAIN (ANALYZE, VERBOSE) DELETE → blocked - EXPLAIN (VERBOSE) SELECT → allowed Also removes the unused _seed_db helper from test_nl2sql_security.py. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * chore: update tool specifications * fix: smarter CTE write detection, fix commit logic for writable CTEs - Replace naive token-set matching with positional AS() body inspection to avoid false positives on column names like 'comment', 'set', 'reset' - Fix execute_sql commit logic to detect writable CTEs (WITH + DELETE/INSERT) not just top-level write commands - Add tests for false positive cases and writable CTE commit behavior - Format nl2sql_tool.py to pass ruff format check * fix: catch write commands in CTE main query + handle whitespace in AS() - WITH cte AS (SELECT 1) DELETE FROM users now correctly blocked - AS followed by newline/tab/multi-space before ( now detected - execute_sql commit logic updated for both cases - 4 new tests * fix: EXPLAIN ANALYZE VERBOSE handling, string literal paren bypass, commit logic for EXPLAIN ANALYZE - EXPLAIN handler now consumes all known options (ANALYZE, ANALYSE, VERBOSE) before extracting the real command, fixing 'EXPLAIN ANALYZE VERBOSE SELECT' being blocked - Paren walker in _extract_main_query_after_cte now skips string literals, preventing 'WITH cte AS (SELECT '\''('\'' FROM t) DELETE FROM users' from bypassing detection - _is_write_stmt in execute_sql now resolves EXPLAIN ANALYZE to underlying command via _resolve_explain_command, ensuring 
session.commit() fires for write operations - 10 new tests covering all three fixes * fix: deduplicate EXPLAIN parsing, fix AS( regex in strings, block unknown CTE commands, bump langchain-core - Refactor _validate_statement to use _resolve_explain_command (single source of truth) - _iter_as_paren_matches skips string literals so 'AS (' in data doesn't confuse CTE detection - Unknown commands after CTE definitions now blocked in read-only mode - Bump langchain-core override to >=1.2.28 (GHSA-926x-3r5x-gfhw) * fix: add return type annotation to _iter_as_paren_matches --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
517 lines
18 KiB
Python
517 lines
18 KiB
Python
from collections.abc import Iterator
|
|
import logging
|
|
import os
|
|
import re
|
|
from typing import Any
|
|
|
|
|
|
try:
|
|
from typing import Self
|
|
except ImportError:
|
|
from typing_extensions import Self
|
|
|
|
from crewai.tools import BaseTool
|
|
from pydantic import BaseModel, Field, model_validator
|
|
|
|
|
|
try:
|
|
from sqlalchemy import create_engine, text
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
SQLALCHEMY_AVAILABLE = True
|
|
except ImportError:
|
|
SQLALCHEMY_AVAILABLE = False
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Commands allowed in read-only mode.
# NOTE: WITH is intentionally excluded — writable CTEs start with WITH, so the
# CTE body must be inspected separately (see _validate_statement).
_READ_ONLY_COMMANDS = {"SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"}

# Leading keywords of statements that mutate state (data, schema, privileges,
# session settings). They are blocked in read-only mode and, when
# allow_dml=True, they make execute_sql commit the transaction — so a missing
# entry here means the statement runs but its effect is rolled back.
_WRITE_COMMANDS = {
    "INSERT",
    "UPDATE",
    "DELETE",
    "DROP",
    "ALTER",
    "CREATE",
    "TRUNCATE",
    "GRANT",
    "REVOKE",
    "EXEC",
    "EXECUTE",
    "CALL",
    "MERGE",
    "REPLACE",
    "UPSERT",
    "LOAD",
    "COPY",
    "VACUUM",
    "ANALYZE",
    "ANALYSE",
    "REINDEX",
    "CLUSTER",
    "REFRESH",
    "COMMENT",
    "SET",
    "RESET",
    "DO",  # PostgreSQL anonymous code block (can run arbitrary writes)
    "RENAME",  # MySQL RENAME TABLE / RENAME USER
    "IMPORT",  # PostgreSQL IMPORT FOREIGN SCHEMA
    "REASSIGN",  # PostgreSQL REASSIGN OWNED
}

# Subset of write commands that can realistically appear *inside* a CTE body.
# Narrower than _WRITE_COMMANDS to avoid false positives on identifiers like
# ``comment``, ``set``, or ``reset`` which are common column/table names.
_CTE_WRITE_INDICATORS = {
    "INSERT",
    "UPDATE",
    "DELETE",
    "DROP",
    "ALTER",
    "CREATE",
    "TRUNCATE",
    "MERGE",
}

# Opening of a CTE body: ``AS (``, ``AS MATERIALIZED (`` or
# ``AS NOT MATERIALIZED (`` (PostgreSQL 12+), with any whitespace in between.
# The MATERIALIZED forms must match too, otherwise
# ``WITH d AS MATERIALIZED (DELETE ...) SELECT ...`` would hide its body from
# the writable-CTE check.
_AS_PAREN_RE = re.compile(
    r"\bAS(?:\s+(?:NOT\s+)?MATERIALIZED)?\s*\(", re.IGNORECASE
)
|
|
|
|
|
|
def _iter_as_paren_matches(stmt: str) -> Iterator[re.Match[str]]:
    """Yield each ``_AS_PAREN_RE`` match whose start lies outside a string literal.

    The statement is scanned once to mark every character offset covered by
    a single-quoted literal (delimited with ``_skip_string_literal``); regex
    matches that begin on a marked offset are dropped.
    """
    quoted = bytearray(len(stmt))  # 1 = this offset is inside a '...' literal
    pos = 0
    while pos < len(stmt):
        if stmt[pos] != "'":
            pos += 1
            continue
        literal_end = _skip_string_literal(stmt, pos)
        quoted[pos:literal_end] = b"\x01" * (literal_end - pos)
        pos = literal_end

    yield from (m for m in _AS_PAREN_RE.finditer(stmt) if not quoted[m.start()])
|
|
|
|
|
|
def _detect_writable_cte(stmt: str) -> str | None:
|
|
"""Return the first write command inside a CTE body, or None.
|
|
|
|
Instead of tokenizing the whole statement (which falsely matches column
|
|
names like ``comment``), this walks through parenthesized CTE bodies and
|
|
checks only the *first keyword after* an opening ``AS (`` for a write
|
|
command. Uses a regex to handle any whitespace (spaces, tabs, newlines)
|
|
between ``AS`` and ``(``. Skips matches inside string literals.
|
|
"""
|
|
for m in _iter_as_paren_matches(stmt):
|
|
body = stmt[m.end() :].lstrip()
|
|
first_word = body.split()[0].upper().strip("()") if body.split() else ""
|
|
if first_word in _CTE_WRITE_INDICATORS:
|
|
return first_word
|
|
return None
|
|
|
|
|
|
def _skip_string_literal(stmt: str, pos: int) -> int:
    """Return the index just past the quoted literal that opens at *pos*.

    The character at *pos* is taken as the quote character. A doubled quote
    (e.g. ``''``) inside the literal is an escaped quote, not the terminator.
    An unterminated literal runs to the end of *stmt*, in which case
    ``len(stmt)`` is returned.

    NOTE(review): backslash-escaped quotes (MySQL strings, PostgreSQL E-strings)
    are not recognised, so such literals are delimited differently here than
    by those engines. Keep this in mind when relying on this scanner for
    security decisions.
    """
    quote = stmt[pos]
    end = len(stmt)
    cursor = pos + 1
    while cursor < end:
        if stmt[cursor] != quote:
            cursor += 1
        elif cursor + 1 < end and stmt[cursor + 1] == quote:
            cursor += 2  # doubled quote: escaped, keep scanning
        else:
            return cursor + 1
    return cursor
|
|
|
|
|
|
def _find_matching_close_paren(stmt: str, start: int) -> int:
    """Return the index just past the ``)`` that closes an already-open ``(``.

    Scanning begins at *start* with one parenthesis assumed open.
    Parentheses inside single-quoted literals are ignored (they are skipped
    with ``_skip_string_literal``). If the parenthesis is never closed, the
    scan runs to the end of *stmt* and that position is returned.
    """
    depth = 1
    idx = start
    limit = len(stmt)
    while depth and idx < limit:
        char = stmt[idx]
        if char == "'":
            idx = _skip_string_literal(stmt, idx)
            continue
        # +1 for "(", -1 for ")", 0 for anything else.
        depth += (char == "(") - (char == ")")
        idx += 1
    return idx
|
|
|
|
|
|
def _extract_main_query_after_cte(stmt: str) -> str | None:
    """Extract the main (outer) query that follows the CTE definitions.

    For ``WITH cte AS (SELECT 1) DELETE FROM users`` this returns
    ``DELETE FROM users``. Returns None if no CTE body is found or nothing
    follows the CTE list.

    CTE bodies are walked in order. A body that is not followed by a comma is
    the last definition, and whatever comes after it is the main query.
    ``AS (`` matches nested inside a body already walked (e.g. a
    column-definition list such as ``f() AS (a int)`` in a subquery) are
    skipped, so they cannot be mistaken for the end of the CTE list.
    Parentheses inside string literals are ignored.
    """
    body_end = 0
    for m in _iter_as_paren_matches(stmt):
        if m.start() < body_end:
            continue  # nested inside a CTE body that was already consumed
        body_end = _find_matching_close_paren(stmt, m.end())
        if not stmt[body_end:].lstrip().startswith(","):
            break  # no further CTE definitions: the main query follows

    if body_end == 0:
        return None
    remainder = stmt[body_end:].strip().lstrip(",").strip()
    return remainder or None
|
|
|
|
|
|
def _resolve_explain_command(stmt: str) -> str | None:
|
|
"""Resolve the underlying command from an EXPLAIN [ANALYZE] [VERBOSE] statement.
|
|
|
|
Returns the real command (e.g., 'DELETE') if ANALYZE is present, else None.
|
|
Handles both space-separated and parenthesized syntax.
|
|
"""
|
|
rest = stmt.strip()[len("EXPLAIN") :].strip()
|
|
if not rest:
|
|
return None
|
|
|
|
analyze_found = False
|
|
explain_opts = {"ANALYZE", "ANALYSE", "VERBOSE"}
|
|
|
|
if rest.startswith("("):
|
|
close = rest.find(")")
|
|
if close != -1:
|
|
options_str = rest[1:close].upper()
|
|
analyze_found = any(
|
|
opt.strip() in ("ANALYZE", "ANALYSE") for opt in options_str.split(",")
|
|
)
|
|
rest = rest[close + 1 :].strip()
|
|
else:
|
|
while rest:
|
|
first_opt = rest.split()[0].upper().rstrip(";") if rest.split() else ""
|
|
if first_opt in ("ANALYZE", "ANALYSE"):
|
|
analyze_found = True
|
|
if first_opt not in explain_opts:
|
|
break
|
|
rest = rest[len(first_opt) :].strip()
|
|
|
|
if analyze_found and rest:
|
|
return rest.split()[0].upper().rstrip(";")
|
|
return None
|
|
|
|
|
|
class NL2SQLToolInput(BaseModel):
    # Argument schema the agent fills in when it invokes NL2SQLTool; the single
    # required field carries the raw SQL text to validate and execute.
    sql_query: str = Field(
        ...,
        description="The SQL query to execute.",
        title="SQL Query",
    )
|
|
|
|
|
|
class NL2SQLTool(BaseTool):
    """Tool that converts natural language to SQL and executes it against a database.

    By default the tool operates in **read-only mode**: only SELECT, SHOW,
    DESCRIBE, EXPLAIN, and read-only CTEs (WITH … SELECT) are permitted. Write
    operations (INSERT, UPDATE, DELETE, DROP, ALTER, CREATE, TRUNCATE, …) are
    blocked unless ``allow_dml=True`` is set explicitly or the environment
    variable ``CREWAI_NL2SQL_ALLOW_DML=true`` is present.

    Writable CTEs (``WITH d AS (DELETE …) SELECT …``) and
    ``EXPLAIN ANALYZE <write-stmt>`` are treated as write operations and are
    blocked in read-only mode. Read-only mode also rejects multi-statement
    queries and queries containing SQL comments (``--`` / ``/* */`` outside
    single-quoted literals), because comments act as whitespace for the
    database but can hide keywords from the lexical checks below.

    Validation is lexical, not a full SQL parser. As a second line of defence,
    ``execute_sql`` only commits when ``allow_dml`` is enabled, so in read-only
    mode transactional changes are rolled back when the session closes. For
    untrusted input, also connect with a database role that has read-only
    privileges.

    The ``_fetch_all_available_columns`` helper uses parameterised queries so
    that table names coming from the database catalogue cannot be used as an
    injection vector.
    """

    name: str = "NL2SQLTool"
    description: str = (
        "Converts natural language to SQL queries and executes them against a "
        "database. Read-only by default — only SELECT/SHOW/DESCRIBE/EXPLAIN "
        "queries (and read-only CTEs) are allowed unless configured with "
        "allow_dml=True."
    )
    db_uri: str = Field(
        title="Database URI",
        description="The URI of the database to connect to.",
    )
    allow_dml: bool = Field(
        default=False,
        title="Allow DML",
        description=(
            "When False (default) only read statements are permitted. "
            "Set to True to allow INSERT/UPDATE/DELETE/DROP and other "
            "write operations."
        ),
    )
    tables: list[dict[str, Any]] = Field(default_factory=list)
    columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict)
    args_schema: type[BaseModel] = NL2SQLToolInput

    @model_validator(mode="after")
    def _apply_env_override(self) -> Self:
        """Allow CREWAI_NL2SQL_ALLOW_DML=true to override allow_dml at runtime."""
        if os.environ.get("CREWAI_NL2SQL_ALLOW_DML", "").strip().lower() == "true":
            if not self.allow_dml:
                logger.warning(
                    "NL2SQLTool: CREWAI_NL2SQL_ALLOW_DML env var is set — "
                    "DML/DDL operations are enabled. Ensure this is intentional."
                )
            self.allow_dml = True
        return self

    def model_post_init(self, __context: Any) -> None:
        """Check the SQLAlchemy dependency and cache the database schema.

        Populates ``self.tables`` with the tables listed in
        ``information_schema.tables`` for the ``public`` schema and
        ``self.columns`` with ``"<table>_columns"`` entries for each of them.

        Raises:
            ImportError: if sqlalchemy is not installed.
            RuntimeError: if the table list could not be fetched.
        """
        if not SQLALCHEMY_AVAILABLE:
            raise ImportError(
                "sqlalchemy is not installed. Please install it with "
                "`pip install crewai-tools[sqlalchemy]`"
            )

        if self.allow_dml:
            logger.warning(
                "NL2SQLTool: allow_dml=True — write operations (INSERT/UPDATE/"
                "DELETE/DROP/…) are permitted. Use with caution."
            )

        data: dict[str, list[dict[str, Any]] | str] = {}
        result = self._fetch_available_tables()
        if isinstance(result, str):
            raise RuntimeError(f"Failed to fetch tables: {result}")
        tables: list[dict[str, Any]] = result

        for table in tables:
            table_columns = self._fetch_all_available_columns(table["table_name"])
            data[f"{table['table_name']}_columns"] = table_columns

        self.tables = tables
        self.columns = data

    # ------------------------------------------------------------------
    # Query validation
    # ------------------------------------------------------------------

    def _validate_query(self, sql_query: str) -> None:
        """Raise ValueError if *sql_query* is not permitted under the current config.

        Splits the query on semicolons and validates each statement
        independently. When ``allow_dml=False`` (the default), multi-statement
        queries are rejected outright to prevent ``SELECT 1; DROP TABLE users``
        style bypasses, and queries containing SQL comments are rejected so
        that comments cannot hide keywords from the checks. When
        ``allow_dml=True`` every statement is checked and a warning is emitted
        for write operations.
        """
        # Split on every ';', even inside literals: over-splitting only causes
        # false rejections, whereas a quote-aware splitter could be fooled by
        # dialect-specific quoting into missing a real statement separator.
        statements = [s.strip() for s in sql_query.split(";") if s.strip()]

        if not statements:
            raise ValueError("NL2SQLTool received an empty SQL query.")

        if not self.allow_dml and len(statements) > 1:
            raise ValueError(
                "NL2SQLTool blocked a multi-statement query in read-only mode. "
                "Semicolons are not permitted when allow_dml=False."
            )

        if not self.allow_dml and self._contains_sql_comment(sql_query):
            raise ValueError(
                "NL2SQLTool blocked a query containing SQL comments in read-only "
                "mode. Remove '--' and '/* */' comments and retry."
            )

        for stmt in statements:
            self._validate_statement(stmt)

    @staticmethod
    def _contains_sql_comment(sql_query: str) -> bool:
        """Return True if *sql_query* has ``--`` or ``/*`` outside '...' literals."""
        i = 0
        n = len(sql_query)
        while i < n:
            if sql_query[i] == "'":
                i = _skip_string_literal(sql_query, i)
                continue
            if sql_query.startswith(("--", "/*"), i):
                return True
            i += 1
        return False

    def _validate_statement(self, stmt: str) -> None:
        """Validate a single SQL statement (no semicolons).

        Raises ValueError when the statement is not allowed in read-only mode;
        logs a warning for write statements when ``allow_dml=True``.
        """
        command = self._extract_command(stmt)

        # EXPLAIN ANALYZE / EXPLAIN ANALYSE actually *executes* the underlying
        # query, in both the space-separated ("EXPLAIN ANALYZE DELETE …") and
        # parenthesized ("EXPLAIN (ANALYZE, VERBOSE) DELETE …") forms, so
        # classify it by the real command.
        if command == "EXPLAIN":
            resolved = _resolve_explain_command(stmt)
            if resolved:
                command = resolved

        if command == "WITH":
            self._validate_cte_statement(stmt)
            return

        if command in _WRITE_COMMANDS:
            if not self.allow_dml:
                raise ValueError(
                    f"NL2SQLTool is configured in read-only mode and blocked a "
                    f"'{command}' statement. To allow write operations set "
                    f"allow_dml=True or CREWAI_NL2SQL_ALLOW_DML=true."
                )
            logger.warning(
                "NL2SQLTool: executing write statement '%s' because allow_dml=True.",
                command,
            )
        elif command not in _READ_ONLY_COMMANDS:
            # Unknown command — block by default unless DML is explicitly enabled
            if not self.allow_dml:
                raise ValueError(
                    f"NL2SQLTool blocked an unrecognised SQL command '{command}'. "
                    f"Only {sorted(_READ_ONLY_COMMANDS)} are allowed in read-only "
                    f"mode."
                )

    def _validate_cte_statement(self, stmt: str) -> None:
        """Validate a statement whose leading keyword is WITH (a CTE query).

        Read-only CTEs are fine. Writable CTEs (``WITH d AS (DELETE …) SELECT
        …``), write commands in the main query after the CTE list, unknown
        main-query commands, and CTE queries whose main query cannot be
        located are blocked in read-only mode.
        """
        write_found = _detect_writable_cte(stmt)
        if write_found:
            if not self.allow_dml:
                raise ValueError(
                    f"NL2SQLTool is configured in read-only mode and blocked a "
                    f"writable CTE containing a '{write_found}' statement. To allow "
                    f"write operations set allow_dml=True or "
                    f"CREWAI_NL2SQL_ALLOW_DML=true."
                )
            logger.warning(
                "NL2SQLTool: executing writable CTE with '%s' because allow_dml=True.",
                write_found,
            )
            return

        main_query = _extract_main_query_after_cte(stmt)
        if not main_query:
            # Fail closed: if the CTE structure was not understood we cannot
            # tell what the statement executes.
            if not self.allow_dml:
                raise ValueError(
                    "NL2SQLTool could not identify the main query after the CTE "
                    "definitions and blocked it in read-only mode."
                )
            return

        main_cmd = main_query.split()[0].upper().rstrip(";")
        if main_cmd in _WRITE_COMMANDS:
            if not self.allow_dml:
                raise ValueError(
                    f"NL2SQLTool is configured in read-only mode and blocked a "
                    f"'{main_cmd}' statement after a CTE. To allow write "
                    f"operations set allow_dml=True or "
                    f"CREWAI_NL2SQL_ALLOW_DML=true."
                )
            logger.warning(
                "NL2SQLTool: executing '%s' after CTE because allow_dml=True.",
                main_cmd,
            )
        elif main_cmd not in _READ_ONLY_COMMANDS:
            if not self.allow_dml:
                raise ValueError(
                    f"NL2SQLTool blocked an unrecognised SQL command '{main_cmd}' "
                    f"after a CTE. Only {sorted(_READ_ONLY_COMMANDS)} are allowed "
                    f"in read-only mode."
                )

    @staticmethod
    def _extract_command(sql_query: str) -> str:
        """Return the uppercased first keyword of *sql_query*."""
        stripped = sql_query.strip().lstrip("(")
        first_token = stripped.split()[0] if stripped.split() else ""
        return first_token.upper().rstrip(";")

    def _is_write_statement(self, stmt: str) -> bool:
        """Return True if *stmt* (one statement, no semicolons) mutates state.

        Used by ``execute_sql`` to decide whether to commit. Mirrors the
        classification in ``_validate_statement``: direct write commands,
        ``EXPLAIN ANALYZE <write>``, writable CTEs, and write commands in the
        main query after a CTE list.
        """
        command = self._extract_command(stmt)
        if command == "EXPLAIN":
            # EXPLAIN ANALYZE executes the underlying statement.
            command = _resolve_explain_command(stmt) or command
        if command in _WRITE_COMMANDS:
            return True
        if command == "WITH":
            if _detect_writable_cte(stmt):
                return True
            main_query = _extract_main_query_after_cte(stmt)
            if main_query:
                return main_query.split()[0].upper().rstrip(";") in _WRITE_COMMANDS
        return False

    # ------------------------------------------------------------------
    # Schema introspection helpers
    # ------------------------------------------------------------------

    def _fetch_available_tables(self) -> list[dict[str, Any]] | str:
        """List the tables of the ``public`` schema via information_schema."""
        return self.execute_sql(
            "SELECT table_name FROM information_schema.tables "
            "WHERE table_schema = 'public';"
        )

    def _fetch_all_available_columns(
        self, table_name: str
    ) -> list[dict[str, Any]] | str:
        """Fetch columns for *table_name* using a parameterised query.

        The table name is bound via SQLAlchemy's ``:param`` syntax to prevent
        SQL injection from catalogue values.
        """
        return self.execute_sql(
            "SELECT column_name, data_type FROM information_schema.columns "
            "WHERE table_name = :table_name",
            params={"table_name": table_name},
        )

    # ------------------------------------------------------------------
    # Core execution
    # ------------------------------------------------------------------

    def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
        """Validate and execute *sql_query* on behalf of the agent.

        Validation failures (ValueError) propagate to the caller. Any other
        error raised while executing is turned into a hint string that lists
        the known tables/columns so the agent can retry with a corrected query.
        """
        try:
            self._validate_query(sql_query)
            data = self.execute_sql(sql_query)
        except ValueError:
            raise
        except Exception as exc:
            data = (
                f"Based on these tables {self.tables} and columns {self.columns}, "
                "you can create SQL queries to retrieve data from the database. "
                f"Get the original request {sql_query} and the error {exc} and "
                "create the correct SQL query."
            )

        return data

    def execute_sql(
        self,
        sql_query: str,
        params: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]] | str:
        """Execute *sql_query* and return the results as a list of dicts.

        Parameters
        ----------
        sql_query:
            The SQL statement to run.
        params:
            Optional mapping of bind parameters (e.g. ``{"table_name": "users"}``).

        Returns
        -------
        A list of row dicts when the statement returns rows, otherwise a
        success message string.

        The transaction is committed only when ``allow_dml`` is True and at
        least one statement is classified as a write; otherwise it is left
        uncommitted and discarded when the session closes. Any exception rolls
        the transaction back and is re-raised.
        """
        if not SQLALCHEMY_AVAILABLE:
            raise ImportError(
                "sqlalchemy is not installed. Please install it with "
                "`pip install crewai-tools[sqlalchemy]`"
            )

        # Check ALL statements so that e.g. "SELECT 1; DROP TABLE t" triggers a
        # commit when allow_dml=True, regardless of statement order.
        statements = [s.strip() for s in sql_query.split(";") if s.strip()]
        is_write = any(self._is_write_statement(s) for s in statements)

        engine = create_engine(self.db_uri)
        session = sessionmaker(bind=engine)()
        try:
            result = session.execute(text(sql_query), params or {})

            # Consume rows before committing: commit releases the connection,
            # so fetching afterwards would rely on driver-side buffering.
            rows: list[dict[str, Any]] | None = None
            if result.returns_rows:  # type: ignore[attr-defined]
                columns = list(result.keys())
                rows = [
                    dict(zip(columns, row, strict=False)) for row in result.fetchall()
                ]

            # Only commit when the operation actually mutates state
            if self.allow_dml and is_write:
                session.commit()

            if rows is not None:
                return rows
            return f"Query {sql_query} executed successfully"

        except Exception:
            session.rollback()
            raise

        finally:
            session.close()
            # A fresh engine (and connection pool) is created per call; dispose
            # of it so pooled connections are not leaked.
            engine.dispose()
|