fix: smarter CTE write detection, fix commit logic for writable CTEs

- Replace naive token-set matching with positional AS() body inspection
  to avoid false positives on column names like 'comment', 'set', 'reset'
- Fix execute_sql commit logic to detect writable CTEs (WITH + DELETE/INSERT)
  not just top-level write commands
- Add tests for false positive cases and writable CTE commit behavior
- Format nl2sql_tool.py to pass ruff format check
This commit is contained in:
Alex
2026-04-07 09:14:58 -07:00
parent 0ca5e812cc
commit 993441f182
2 changed files with 90 additions and 8 deletions

View File

@@ -58,6 +58,48 @@ _WRITE_COMMANDS = {
}
# Subset of write commands that can realistically appear *inside* a CTE body.
# Narrower than _WRITE_COMMANDS to avoid false positives on identifiers like
# ``comment``, ``set``, or ``reset`` which are common column/table names.
_CTE_WRITE_INDICATORS = {
"INSERT",
"UPDATE",
"DELETE",
"DROP",
"ALTER",
"CREATE",
"TRUNCATE",
"MERGE",
}
def _detect_writable_cte(stmt: str) -> str | None:
"""Return the first write command inside a CTE body, or None.
Instead of tokenizing the whole statement (which falsely matches column
names like ``comment``), this walks through parenthesized CTE bodies and
checks only the *first keyword after* an opening ``AS (`` for a write
command.
"""
upper = stmt.upper()
search_start = 0
while True:
# Find the next "AS (" which introduces a CTE body.
idx = upper.find("AS (", search_start)
if idx == -1:
idx = upper.find("AS(", search_start)
if idx == -1:
break
# Jump past "AS (" or "AS("
paren_start = upper.index("(", idx) + 1
body = upper[paren_start:].lstrip()
first_word = body.split()[0].strip("()") if body.split() else ""
if first_word in _CTE_WRITE_INDICATORS:
return first_word
search_start = paren_start
return None
class NL2SQLToolInput(BaseModel):
sql_query: str = Field(
title="SQL Query",
@@ -181,7 +223,7 @@ class NL2SQLTool(BaseTool):
# Handles both space-separated ("EXPLAIN ANALYZE DELETE …") and
# parenthesized ("EXPLAIN (ANALYZE) DELETE …", "EXPLAIN (ANALYZE, VERBOSE) DELETE …").
if command == "EXPLAIN":
rest = stmt.strip()[len("EXPLAIN"):].strip()
rest = stmt.strip()[len("EXPLAIN") :].strip()
analyze_found = False
if rest.startswith("("):
@@ -193,13 +235,13 @@ class NL2SQLTool(BaseTool):
opt.strip() in ("ANALYZE", "ANALYSE")
for opt in options_str.split(",")
)
rest = rest[close + 1:].strip()
rest = rest[close + 1 :].strip()
else:
# Space-separated: EXPLAIN ANALYZE <stmt>
first_opt = rest.split()[0].upper().rstrip(";") if rest.split() else ""
if first_opt in ("ANALYZE", "ANALYSE"):
analyze_found = True
rest = rest[len(first_opt):].strip()
rest = rest[len(first_opt) :].strip()
if analyze_found and rest:
command = rest.split()[0].upper().rstrip(";")
@@ -207,10 +249,9 @@ class NL2SQLTool(BaseTool):
# WITH starts a CTE. Read-only CTEs are fine; writable CTEs
# (e.g. WITH d AS (DELETE …) SELECT …) must be blocked in read-only mode.
if command == "WITH":
tokens_upper = {t.upper().strip("();,") for t in stmt.split()}
write_found = tokens_upper & _WRITE_COMMANDS
write_found = _detect_writable_cte(stmt)
if write_found:
found = next(iter(write_found))
found = write_found
if not self.allow_dml:
raise ValueError(
f"NL2SQLTool is configured in read-only mode and blocked a "
@@ -319,7 +360,11 @@ class NL2SQLTool(BaseTool):
# Check ALL statements so that e.g. "SELECT 1; DROP TABLE t" triggers a
# commit when allow_dml=True, regardless of statement order.
_stmts = [s.strip() for s in sql_query.split(";") if s.strip()]
is_write = any(self._extract_command(s) in _WRITE_COMMANDS for s in _stmts)
is_write = any(
self._extract_command(s) in _WRITE_COMMANDS
or (self._extract_command(s) == "WITH" and _detect_writable_cte(s))
for s in _stmts
)
engine = create_engine(self.db_uri)
Session = sessionmaker(bind=engine) # noqa: N806

View File

@@ -358,6 +358,21 @@ class TestWritableCTE:
# No write commands in the CTE body — must pass
tool._validate_query("WITH cte AS (SELECT id FROM users) SELECT * FROM cte")
def test_cte_with_comment_column_not_false_positive(self):
    """A read-only CTE that selects a ``comment`` column must pass validation."""
    # ``comment`` is used purely as a column identifier in this query; the
    # call below is expected to complete without raising.
    query = "WITH cte AS (SELECT comment FROM posts) SELECT * FROM cte"
    read_only_tool = _make_tool(allow_dml=False)
    read_only_tool._validate_query(query)
def test_cte_with_set_column_not_false_positive(self):
    """A read-only CTE that selects ``set`` / ``reset`` columns must pass validation."""
    # Both names are plain column identifiers here; validation should not raise.
    query = "WITH cte AS (SELECT set, reset FROM config) SELECT * FROM cte"
    read_only_tool = _make_tool(allow_dml=False)
    read_only_tool._validate_query(query)
# ---------------------------------------------------------------------------
# EXPLAIN ANALYZE executes the underlying query
@@ -461,7 +476,29 @@ class TestMultiStatementCommit:
):
tool.execute_sql("SELECT 1; SELECT 2")
mock_session.commit.assert_not_called()
def test_writable_cte_triggers_commit(self):
    """A data-modifying CTE must be committed when ``allow_dml=True``."""
    query = "WITH d AS (DELETE FROM users RETURNING *) SELECT * FROM d"

    # Fake result object that reports a single ``id`` row.
    fake_result = MagicMock()
    fake_result.returns_rows = True
    fake_result.keys.return_value = ["id"]
    fake_result.fetchall.return_value = [(1,)]

    # Fake session returned by the patched ``sessionmaker`` factory.
    fake_session = MagicMock()
    fake_session.execute.return_value = fake_result
    session_factory = MagicMock(return_value=fake_session)

    tool = _make_tool(allow_dml=True)

    engine_patch = patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine")
    sessionmaker_patch = patch(
        "crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
        return_value=session_factory,
    )
    with engine_patch, sessionmaker_patch:
        tool.execute_sql(query)

    fake_session.commit.assert_called_once()
# ---------------------------------------------------------------------------