mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-08 03:58:23 +00:00
fix: catch write commands in CTE main query + handle whitespace in AS()
- WITH cte AS (SELECT 1) DELETE FROM users now correctly blocked - AS followed by newline/tab/multi-space before ( now detected - execute_sql commit logic updated for both cases - 4 new tests
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -73,30 +74,52 @@ _CTE_WRITE_INDICATORS = {
|
||||
}
|
||||
|
||||
|
||||
_AS_PAREN_RE = re.compile(r"\bAS\s*\(", re.IGNORECASE)
|
||||
|
||||
|
||||
def _detect_writable_cte(stmt: str) -> str | None:
|
||||
"""Return the first write command inside a CTE body, or None.
|
||||
|
||||
Instead of tokenizing the whole statement (which falsely matches column
|
||||
names like ``comment``), this walks through parenthesized CTE bodies and
|
||||
checks only the *first keyword after* an opening ``AS (`` for a write
|
||||
command.
|
||||
command. Uses a regex to handle any whitespace (spaces, tabs, newlines)
|
||||
between ``AS`` and ``(``.
|
||||
"""
|
||||
upper = stmt.upper()
|
||||
search_start = 0
|
||||
while True:
|
||||
# Find the next "AS (" which introduces a CTE body.
|
||||
idx = upper.find("AS (", search_start)
|
||||
if idx == -1:
|
||||
idx = upper.find("AS(", search_start)
|
||||
if idx == -1:
|
||||
break
|
||||
# Jump past "AS (" or "AS("
|
||||
paren_start = upper.index("(", idx) + 1
|
||||
body = upper[paren_start:].lstrip()
|
||||
first_word = body.split()[0].strip("()") if body.split() else ""
|
||||
for m in _AS_PAREN_RE.finditer(stmt):
|
||||
body = stmt[m.end() :].lstrip()
|
||||
first_word = body.split()[0].upper().strip("()") if body.split() else ""
|
||||
if first_word in _CTE_WRITE_INDICATORS:
|
||||
return first_word
|
||||
search_start = paren_start
|
||||
return None
|
||||
|
||||
|
||||
def _extract_main_query_after_cte(stmt: str) -> str | None:
|
||||
"""Extract the main (outer) query that follows all CTE definitions.
|
||||
|
||||
For ``WITH cte AS (SELECT 1) DELETE FROM users``, returns ``DELETE FROM users``.
|
||||
Returns None if no main query is found after the last CTE body.
|
||||
"""
|
||||
# Walk through balanced parens after each AS( to find the end of CTE bodies.
|
||||
last_cte_end = 0
|
||||
for m in _AS_PAREN_RE.finditer(stmt):
|
||||
# Find the matching closing paren for this CTE body.
|
||||
depth = 1
|
||||
i = m.end()
|
||||
while i < len(stmt) and depth > 0:
|
||||
if stmt[i] == "(":
|
||||
depth += 1
|
||||
elif stmt[i] == ")":
|
||||
depth -= 1
|
||||
i += 1
|
||||
last_cte_end = i
|
||||
|
||||
if last_cte_end > 0:
|
||||
remainder = stmt[last_cte_end:].strip().lstrip(",").strip()
|
||||
# Skip additional CTE definitions (name AS (...))
|
||||
# The remainder after the last CTE closing paren is the main query
|
||||
if remainder:
|
||||
return remainder
|
||||
return None
|
||||
|
||||
|
||||
@@ -249,6 +272,7 @@ class NL2SQLTool(BaseTool):
|
||||
# WITH starts a CTE. Read-only CTEs are fine; writable CTEs
|
||||
# (e.g. WITH d AS (DELETE …) SELECT …) must be blocked in read-only mode.
|
||||
if command == "WITH":
|
||||
# Check for write commands inside CTE bodies.
|
||||
write_found = _detect_writable_cte(stmt)
|
||||
if write_found:
|
||||
found = write_found
|
||||
@@ -263,7 +287,24 @@ class NL2SQLTool(BaseTool):
|
||||
"NL2SQLTool: executing writable CTE with '%s' because allow_dml=True.",
|
||||
found,
|
||||
)
|
||||
# Both read-only and writable-but-permitted CTEs need no further checks.
|
||||
return
|
||||
|
||||
# Check the main query after the CTE definitions.
|
||||
main_query = _extract_main_query_after_cte(stmt)
|
||||
if main_query:
|
||||
main_cmd = main_query.split()[0].upper().rstrip(";")
|
||||
if main_cmd in _WRITE_COMMANDS:
|
||||
if not self.allow_dml:
|
||||
raise ValueError(
|
||||
f"NL2SQLTool is configured in read-only mode and blocked a "
|
||||
f"'{main_cmd}' statement after a CTE. To allow write "
|
||||
f"operations set allow_dml=True or "
|
||||
f"CREWAI_NL2SQL_ALLOW_DML=true."
|
||||
)
|
||||
logger.warning(
|
||||
"NL2SQLTool: executing '%s' after CTE because allow_dml=True.",
|
||||
main_cmd,
|
||||
)
|
||||
return
|
||||
|
||||
if command in _WRITE_COMMANDS:
|
||||
@@ -360,11 +401,20 @@ class NL2SQLTool(BaseTool):
|
||||
# Check ALL statements so that e.g. "SELECT 1; DROP TABLE t" triggers a
|
||||
# commit when allow_dml=True, regardless of statement order.
|
||||
_stmts = [s.strip() for s in sql_query.split(";") if s.strip()]
|
||||
is_write = any(
|
||||
self._extract_command(s) in _WRITE_COMMANDS
|
||||
or (self._extract_command(s) == "WITH" and _detect_writable_cte(s))
|
||||
for s in _stmts
|
||||
)
|
||||
|
||||
def _is_write_stmt(s: str) -> bool:
|
||||
cmd = self._extract_command(s)
|
||||
if cmd in _WRITE_COMMANDS:
|
||||
return True
|
||||
if cmd == "WITH":
|
||||
if _detect_writable_cte(s):
|
||||
return True
|
||||
main_q = _extract_main_query_after_cte(s)
|
||||
if main_q:
|
||||
return main_q.split()[0].upper().rstrip(";") in _WRITE_COMMANDS
|
||||
return False
|
||||
|
||||
is_write = any(_is_write_stmt(s) for s in _stmts)
|
||||
|
||||
engine = create_engine(self.db_uri)
|
||||
Session = sessionmaker(bind=engine) # noqa: N806
|
||||
|
||||
@@ -379,6 +379,38 @@ class TestWritableCTE:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_cte_with_write_main_query_blocked(self):
|
||||
"""WITH cte AS (SELECT 1) DELETE FROM users — main query must be caught."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(
|
||||
"WITH cte AS (SELECT 1) DELETE FROM users"
|
||||
)
|
||||
|
||||
def test_cte_with_write_main_query_allowed_with_dml(self):
|
||||
"""Main query write after CTE should pass when allow_dml=True."""
|
||||
tool = _make_tool(allow_dml=True)
|
||||
tool._validate_query(
|
||||
"WITH cte AS (SELECT id FROM users) INSERT INTO archive SELECT * FROM cte"
|
||||
)
|
||||
|
||||
def test_cte_with_newline_before_paren_blocked(self):
|
||||
"""AS followed by newline then ( should still detect writable CTE."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(
|
||||
"WITH cte AS\n(DELETE FROM users RETURNING *) SELECT * FROM cte"
|
||||
)
|
||||
|
||||
def test_cte_with_tab_before_paren_blocked(self):
|
||||
"""AS followed by tab then ( should still detect writable CTE."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(
|
||||
"WITH cte AS\t(DELETE FROM users RETURNING *) SELECT * FROM cte"
|
||||
)
|
||||
|
||||
|
||||
class TestExplainAnalyze:
|
||||
def test_explain_analyze_delete_blocked_in_read_only(self):
|
||||
"""EXPLAIN ANALYZE DELETE actually runs the delete — block it."""
|
||||
|
||||
Reference in New Issue
Block a user