fix: catch write commands in CTE main query + handle whitespace in AS()

- WITH cte AS (SELECT 1) DELETE FROM users now correctly blocked
- AS followed by newline/tab/multi-space before ( now detected
- execute_sql commit logic updated for both cases
- 4 new tests
This commit is contained in:
Alex
2026-04-07 09:27:48 -07:00
parent 993441f182
commit 88ffeb327b
2 changed files with 103 additions and 21 deletions

View File

@@ -1,5 +1,6 @@
import logging
import os
import re
from typing import Any
@@ -73,30 +74,52 @@ _CTE_WRITE_INDICATORS = {
}
_AS_PAREN_RE = re.compile(r"\bAS\s*\(", re.IGNORECASE)
def _detect_writable_cte(stmt: str) -> str | None:
"""Return the first write command inside a CTE body, or None.
Instead of tokenizing the whole statement (which falsely matches column
names like ``comment``), this walks through parenthesized CTE bodies and
checks only the *first keyword after* an opening ``AS (`` for a write
command.
command. Uses a regex to handle any whitespace (spaces, tabs, newlines)
between ``AS`` and ``(``.
"""
upper = stmt.upper()
search_start = 0
while True:
# Find the next "AS (" which introduces a CTE body.
idx = upper.find("AS (", search_start)
if idx == -1:
idx = upper.find("AS(", search_start)
if idx == -1:
break
# Jump past "AS (" or "AS("
paren_start = upper.index("(", idx) + 1
body = upper[paren_start:].lstrip()
first_word = body.split()[0].strip("()") if body.split() else ""
for m in _AS_PAREN_RE.finditer(stmt):
body = stmt[m.end() :].lstrip()
first_word = body.split()[0].upper().strip("()") if body.split() else ""
if first_word in _CTE_WRITE_INDICATORS:
return first_word
search_start = paren_start
return None
def _extract_main_query_after_cte(stmt: str) -> str | None:
"""Extract the main (outer) query that follows all CTE definitions.
For ``WITH cte AS (SELECT 1) DELETE FROM users``, returns ``DELETE FROM users``.
Returns None if no main query is found after the last CTE body.
"""
# Walk through balanced parens after each AS( to find the end of CTE bodies.
last_cte_end = 0
for m in _AS_PAREN_RE.finditer(stmt):
# Find the matching closing paren for this CTE body.
depth = 1
i = m.end()
while i < len(stmt) and depth > 0:
if stmt[i] == "(":
depth += 1
elif stmt[i] == ")":
depth -= 1
i += 1
last_cte_end = i
if last_cte_end > 0:
remainder = stmt[last_cte_end:].strip().lstrip(",").strip()
# Skip additional CTE definitions (name AS (...))
# The remainder after the last CTE closing paren is the main query
if remainder:
return remainder
return None
@@ -249,6 +272,7 @@ class NL2SQLTool(BaseTool):
# WITH starts a CTE. Read-only CTEs are fine; writable CTEs
# (e.g. WITH d AS (DELETE …) SELECT …) must be blocked in read-only mode.
if command == "WITH":
# Check for write commands inside CTE bodies.
write_found = _detect_writable_cte(stmt)
if write_found:
found = write_found
@@ -263,7 +287,24 @@ class NL2SQLTool(BaseTool):
"NL2SQLTool: executing writable CTE with '%s' because allow_dml=True.",
found,
)
# Both read-only and writable-but-permitted CTEs need no further checks.
return
# Check the main query after the CTE definitions.
main_query = _extract_main_query_after_cte(stmt)
if main_query:
main_cmd = main_query.split()[0].upper().rstrip(";")
if main_cmd in _WRITE_COMMANDS:
if not self.allow_dml:
raise ValueError(
f"NL2SQLTool is configured in read-only mode and blocked a "
f"'{main_cmd}' statement after a CTE. To allow write "
f"operations set allow_dml=True or "
f"CREWAI_NL2SQL_ALLOW_DML=true."
)
logger.warning(
"NL2SQLTool: executing '%s' after CTE because allow_dml=True.",
main_cmd,
)
return
if command in _WRITE_COMMANDS:
@@ -360,11 +401,20 @@ class NL2SQLTool(BaseTool):
# Check ALL statements so that e.g. "SELECT 1; DROP TABLE t" triggers a
# commit when allow_dml=True, regardless of statement order.
_stmts = [s.strip() for s in sql_query.split(";") if s.strip()]
is_write = any(
self._extract_command(s) in _WRITE_COMMANDS
or (self._extract_command(s) == "WITH" and _detect_writable_cte(s))
for s in _stmts
)
def _is_write_stmt(s: str) -> bool:
cmd = self._extract_command(s)
if cmd in _WRITE_COMMANDS:
return True
if cmd == "WITH":
if _detect_writable_cte(s):
return True
main_q = _extract_main_query_after_cte(s)
if main_q:
return main_q.split()[0].upper().rstrip(";") in _WRITE_COMMANDS
return False
is_write = any(_is_write_stmt(s) for s in _stmts)
engine = create_engine(self.db_uri)
Session = sessionmaker(bind=engine) # noqa: N806

View File

@@ -379,6 +379,38 @@ class TestWritableCTE:
# ---------------------------------------------------------------------------
def test_cte_with_write_main_query_blocked(self):
"""WITH cte AS (SELECT 1) DELETE FROM users — main query must be caught."""
tool = _make_tool(allow_dml=False)
with pytest.raises(ValueError, match="read-only mode"):
tool._validate_query(
"WITH cte AS (SELECT 1) DELETE FROM users"
)
def test_cte_with_write_main_query_allowed_with_dml(self):
"""Main query write after CTE should pass when allow_dml=True."""
tool = _make_tool(allow_dml=True)
tool._validate_query(
"WITH cte AS (SELECT id FROM users) INSERT INTO archive SELECT * FROM cte"
)
def test_cte_with_newline_before_paren_blocked(self):
"""AS followed by newline then ( should still detect writable CTE."""
tool = _make_tool(allow_dml=False)
with pytest.raises(ValueError, match="read-only mode"):
tool._validate_query(
"WITH cte AS\n(DELETE FROM users RETURNING *) SELECT * FROM cte"
)
def test_cte_with_tab_before_paren_blocked(self):
"""AS followed by tab then ( should still detect writable CTE."""
tool = _make_tool(allow_dml=False)
with pytest.raises(ValueError, match="read-only mode"):
tool._validate_query(
"WITH cte AS\t(DELETE FROM users RETURNING *) SELECT * FROM cte"
)
class TestExplainAnalyze:
def test_explain_analyze_delete_blocked_in_read_only(self):
"""EXPLAIN ANALYZE DELETE actually runs the delete — block it."""