mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
fix: catch write commands in CTE main query + handle whitespace in AS()
- `WITH cte AS (SELECT 1) DELETE FROM users` is now correctly blocked
- `AS` followed by a newline, tab, or multiple spaces before `(` is now detected
- `execute_sql` commit logic updated for both cases
- 4 new tests
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
@@ -73,30 +74,52 @@ _CTE_WRITE_INDICATORS = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
_AS_PAREN_RE = re.compile(r"\bAS\s*\(", re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
def _detect_writable_cte(stmt: str) -> str | None:
|
def _detect_writable_cte(stmt: str) -> str | None:
|
||||||
"""Return the first write command inside a CTE body, or None.
|
"""Return the first write command inside a CTE body, or None.
|
||||||
|
|
||||||
Instead of tokenizing the whole statement (which falsely matches column
|
Instead of tokenizing the whole statement (which falsely matches column
|
||||||
names like ``comment``), this walks through parenthesized CTE bodies and
|
names like ``comment``), this walks through parenthesized CTE bodies and
|
||||||
checks only the *first keyword after* an opening ``AS (`` for a write
|
checks only the *first keyword after* an opening ``AS (`` for a write
|
||||||
command.
|
command. Uses a regex to handle any whitespace (spaces, tabs, newlines)
|
||||||
|
between ``AS`` and ``(``.
|
||||||
"""
|
"""
|
||||||
upper = stmt.upper()
|
for m in _AS_PAREN_RE.finditer(stmt):
|
||||||
search_start = 0
|
body = stmt[m.end() :].lstrip()
|
||||||
while True:
|
first_word = body.split()[0].upper().strip("()") if body.split() else ""
|
||||||
# Find the next "AS (" which introduces a CTE body.
|
|
||||||
idx = upper.find("AS (", search_start)
|
|
||||||
if idx == -1:
|
|
||||||
idx = upper.find("AS(", search_start)
|
|
||||||
if idx == -1:
|
|
||||||
break
|
|
||||||
# Jump past "AS (" or "AS("
|
|
||||||
paren_start = upper.index("(", idx) + 1
|
|
||||||
body = upper[paren_start:].lstrip()
|
|
||||||
first_word = body.split()[0].strip("()") if body.split() else ""
|
|
||||||
if first_word in _CTE_WRITE_INDICATORS:
|
if first_word in _CTE_WRITE_INDICATORS:
|
||||||
return first_word
|
return first_word
|
||||||
search_start = paren_start
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_main_query_after_cte(stmt: str) -> str | None:
|
||||||
|
"""Extract the main (outer) query that follows all CTE definitions.
|
||||||
|
|
||||||
|
For ``WITH cte AS (SELECT 1) DELETE FROM users``, returns ``DELETE FROM users``.
|
||||||
|
Returns None if no main query is found after the last CTE body.
|
||||||
|
"""
|
||||||
|
# Walk through balanced parens after each AS( to find the end of CTE bodies.
|
||||||
|
last_cte_end = 0
|
||||||
|
for m in _AS_PAREN_RE.finditer(stmt):
|
||||||
|
# Find the matching closing paren for this CTE body.
|
||||||
|
depth = 1
|
||||||
|
i = m.end()
|
||||||
|
while i < len(stmt) and depth > 0:
|
||||||
|
if stmt[i] == "(":
|
||||||
|
depth += 1
|
||||||
|
elif stmt[i] == ")":
|
||||||
|
depth -= 1
|
||||||
|
i += 1
|
||||||
|
last_cte_end = i
|
||||||
|
|
||||||
|
if last_cte_end > 0:
|
||||||
|
remainder = stmt[last_cte_end:].strip().lstrip(",").strip()
|
||||||
|
# Skip additional CTE definitions (name AS (...))
|
||||||
|
# The remainder after the last CTE closing paren is the main query
|
||||||
|
if remainder:
|
||||||
|
return remainder
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -249,6 +272,7 @@ class NL2SQLTool(BaseTool):
|
|||||||
# WITH starts a CTE. Read-only CTEs are fine; writable CTEs
|
# WITH starts a CTE. Read-only CTEs are fine; writable CTEs
|
||||||
# (e.g. WITH d AS (DELETE …) SELECT …) must be blocked in read-only mode.
|
# (e.g. WITH d AS (DELETE …) SELECT …) must be blocked in read-only mode.
|
||||||
if command == "WITH":
|
if command == "WITH":
|
||||||
|
# Check for write commands inside CTE bodies.
|
||||||
write_found = _detect_writable_cte(stmt)
|
write_found = _detect_writable_cte(stmt)
|
||||||
if write_found:
|
if write_found:
|
||||||
found = write_found
|
found = write_found
|
||||||
@@ -263,7 +287,24 @@ class NL2SQLTool(BaseTool):
|
|||||||
"NL2SQLTool: executing writable CTE with '%s' because allow_dml=True.",
|
"NL2SQLTool: executing writable CTE with '%s' because allow_dml=True.",
|
||||||
found,
|
found,
|
||||||
)
|
)
|
||||||
# Both read-only and writable-but-permitted CTEs need no further checks.
|
return
|
||||||
|
|
||||||
|
# Check the main query after the CTE definitions.
|
||||||
|
main_query = _extract_main_query_after_cte(stmt)
|
||||||
|
if main_query:
|
||||||
|
main_cmd = main_query.split()[0].upper().rstrip(";")
|
||||||
|
if main_cmd in _WRITE_COMMANDS:
|
||||||
|
if not self.allow_dml:
|
||||||
|
raise ValueError(
|
||||||
|
f"NL2SQLTool is configured in read-only mode and blocked a "
|
||||||
|
f"'{main_cmd}' statement after a CTE. To allow write "
|
||||||
|
f"operations set allow_dml=True or "
|
||||||
|
f"CREWAI_NL2SQL_ALLOW_DML=true."
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"NL2SQLTool: executing '%s' after CTE because allow_dml=True.",
|
||||||
|
main_cmd,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if command in _WRITE_COMMANDS:
|
if command in _WRITE_COMMANDS:
|
||||||
@@ -360,11 +401,20 @@ class NL2SQLTool(BaseTool):
|
|||||||
# Check ALL statements so that e.g. "SELECT 1; DROP TABLE t" triggers a
|
# Check ALL statements so that e.g. "SELECT 1; DROP TABLE t" triggers a
|
||||||
# commit when allow_dml=True, regardless of statement order.
|
# commit when allow_dml=True, regardless of statement order.
|
||||||
_stmts = [s.strip() for s in sql_query.split(";") if s.strip()]
|
_stmts = [s.strip() for s in sql_query.split(";") if s.strip()]
|
||||||
is_write = any(
|
|
||||||
self._extract_command(s) in _WRITE_COMMANDS
|
def _is_write_stmt(s: str) -> bool:
|
||||||
or (self._extract_command(s) == "WITH" and _detect_writable_cte(s))
|
cmd = self._extract_command(s)
|
||||||
for s in _stmts
|
if cmd in _WRITE_COMMANDS:
|
||||||
)
|
return True
|
||||||
|
if cmd == "WITH":
|
||||||
|
if _detect_writable_cte(s):
|
||||||
|
return True
|
||||||
|
main_q = _extract_main_query_after_cte(s)
|
||||||
|
if main_q:
|
||||||
|
return main_q.split()[0].upper().rstrip(";") in _WRITE_COMMANDS
|
||||||
|
return False
|
||||||
|
|
||||||
|
is_write = any(_is_write_stmt(s) for s in _stmts)
|
||||||
|
|
||||||
engine = create_engine(self.db_uri)
|
engine = create_engine(self.db_uri)
|
||||||
Session = sessionmaker(bind=engine) # noqa: N806
|
Session = sessionmaker(bind=engine) # noqa: N806
|
||||||
|
|||||||
@@ -379,6 +379,38 @@ class TestWritableCTE:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_cte_with_write_main_query_blocked(self):
    """A DELETE that follows a read-only CTE must be rejected in read-only mode."""
    sql = "WITH cte AS (SELECT 1) DELETE FROM users"
    read_only_tool = _make_tool(allow_dml=False)
    with pytest.raises(ValueError, match="read-only mode"):
        read_only_tool._validate_query(sql)
|
||||||
|
|
||||||
|
def test_cte_with_write_main_query_allowed_with_dml(self):
    """With allow_dml=True, a write main query after a CTE validates without raising."""
    sql = "WITH cte AS (SELECT id FROM users) INSERT INTO archive SELECT * FROM cte"
    dml_tool = _make_tool(allow_dml=True)
    # Passing means no exception is raised.
    dml_tool._validate_query(sql)
|
||||||
|
|
||||||
|
def test_cte_with_newline_before_paren_blocked(self):
    """A newline between AS and ( must not hide a writable CTE."""
    sql = "WITH cte AS\n(DELETE FROM users RETURNING *) SELECT * FROM cte"
    read_only_tool = _make_tool(allow_dml=False)
    with pytest.raises(ValueError, match="read-only mode"):
        read_only_tool._validate_query(sql)
|
||||||
|
|
||||||
|
def test_cte_with_tab_before_paren_blocked(self):
    """A tab between AS and ( must not hide a writable CTE."""
    sql = "WITH cte AS\t(DELETE FROM users RETURNING *) SELECT * FROM cte"
    read_only_tool = _make_tool(allow_dml=False)
    with pytest.raises(ValueError, match="read-only mode"):
        read_only_tool._validate_query(sql)
|
||||||
|
|
||||||
|
|
||||||
class TestExplainAnalyze:
|
class TestExplainAnalyze:
|
||||||
def test_explain_analyze_delete_blocked_in_read_only(self):
|
def test_explain_analyze_delete_blocked_in_read_only(self):
|
||||||
"""EXPLAIN ANALYZE DELETE actually runs the delete — block it."""
|
"""EXPLAIN ANALYZE DELETE actually runs the delete — block it."""
|
||||||
|
|||||||
Reference in New Issue
Block a user