mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
fix: smarter CTE write detection, fix commit logic for writable CTEs
- Replace naive token-set matching with positional AS() body inspection to avoid false positives on column names like 'comment', 'set', 'reset'
- Fix execute_sql commit logic to detect writable CTEs (WITH + DELETE/INSERT), not just top-level write commands
- Add tests for false positive cases and writable CTE commit behavior
- Format nl2sql_tool.py to pass ruff format check
This commit is contained in:
@@ -58,6 +58,48 @@ _WRITE_COMMANDS = {
|
||||
}
|
||||
|
||||
|
||||
# Subset of write commands that can realistically appear *inside* a CTE body.
|
||||
# Narrower than _WRITE_COMMANDS to avoid false positives on identifiers like
|
||||
# ``comment``, ``set``, or ``reset`` which are common column/table names.
|
||||
_CTE_WRITE_INDICATORS = {
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"DELETE",
|
||||
"DROP",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"TRUNCATE",
|
||||
"MERGE",
|
||||
}
|
||||
|
||||
|
||||
def _detect_writable_cte(stmt: str) -> str | None:
|
||||
"""Return the first write command inside a CTE body, or None.
|
||||
|
||||
Instead of tokenizing the whole statement (which falsely matches column
|
||||
names like ``comment``), this walks through parenthesized CTE bodies and
|
||||
checks only the *first keyword after* an opening ``AS (`` for a write
|
||||
command.
|
||||
"""
|
||||
upper = stmt.upper()
|
||||
search_start = 0
|
||||
while True:
|
||||
# Find the next "AS (" which introduces a CTE body.
|
||||
idx = upper.find("AS (", search_start)
|
||||
if idx == -1:
|
||||
idx = upper.find("AS(", search_start)
|
||||
if idx == -1:
|
||||
break
|
||||
# Jump past "AS (" or "AS("
|
||||
paren_start = upper.index("(", idx) + 1
|
||||
body = upper[paren_start:].lstrip()
|
||||
first_word = body.split()[0].strip("()") if body.split() else ""
|
||||
if first_word in _CTE_WRITE_INDICATORS:
|
||||
return first_word
|
||||
search_start = paren_start
|
||||
return None
|
||||
|
||||
|
||||
class NL2SQLToolInput(BaseModel):
|
||||
sql_query: str = Field(
|
||||
title="SQL Query",
|
||||
@@ -181,7 +223,7 @@ class NL2SQLTool(BaseTool):
|
||||
# Handles both space-separated ("EXPLAIN ANALYZE DELETE …") and
|
||||
# parenthesized ("EXPLAIN (ANALYZE) DELETE …", "EXPLAIN (ANALYZE, VERBOSE) DELETE …").
|
||||
if command == "EXPLAIN":
|
||||
rest = stmt.strip()[len("EXPLAIN"):].strip()
|
||||
rest = stmt.strip()[len("EXPLAIN") :].strip()
|
||||
analyze_found = False
|
||||
|
||||
if rest.startswith("("):
|
||||
@@ -193,13 +235,13 @@ class NL2SQLTool(BaseTool):
|
||||
opt.strip() in ("ANALYZE", "ANALYSE")
|
||||
for opt in options_str.split(",")
|
||||
)
|
||||
rest = rest[close + 1:].strip()
|
||||
rest = rest[close + 1 :].strip()
|
||||
else:
|
||||
# Space-separated: EXPLAIN ANALYZE <stmt>
|
||||
first_opt = rest.split()[0].upper().rstrip(";") if rest.split() else ""
|
||||
if first_opt in ("ANALYZE", "ANALYSE"):
|
||||
analyze_found = True
|
||||
rest = rest[len(first_opt):].strip()
|
||||
rest = rest[len(first_opt) :].strip()
|
||||
|
||||
if analyze_found and rest:
|
||||
command = rest.split()[0].upper().rstrip(";")
|
||||
@@ -207,10 +249,9 @@ class NL2SQLTool(BaseTool):
|
||||
# WITH starts a CTE. Read-only CTEs are fine; writable CTEs
|
||||
# (e.g. WITH d AS (DELETE …) SELECT …) must be blocked in read-only mode.
|
||||
if command == "WITH":
|
||||
tokens_upper = {t.upper().strip("();,") for t in stmt.split()}
|
||||
write_found = tokens_upper & _WRITE_COMMANDS
|
||||
write_found = _detect_writable_cte(stmt)
|
||||
if write_found:
|
||||
found = next(iter(write_found))
|
||||
found = write_found
|
||||
if not self.allow_dml:
|
||||
raise ValueError(
|
||||
f"NL2SQLTool is configured in read-only mode and blocked a "
|
||||
@@ -319,7 +360,11 @@ class NL2SQLTool(BaseTool):
|
||||
# Check ALL statements so that e.g. "SELECT 1; DROP TABLE t" triggers a
|
||||
# commit when allow_dml=True, regardless of statement order.
|
||||
_stmts = [s.strip() for s in sql_query.split(";") if s.strip()]
|
||||
is_write = any(self._extract_command(s) in _WRITE_COMMANDS for s in _stmts)
|
||||
is_write = any(
|
||||
self._extract_command(s) in _WRITE_COMMANDS
|
||||
or (self._extract_command(s) == "WITH" and _detect_writable_cte(s))
|
||||
for s in _stmts
|
||||
)
|
||||
|
||||
engine = create_engine(self.db_uri)
|
||||
Session = sessionmaker(bind=engine) # noqa: N806
|
||||
|
||||
@@ -358,6 +358,21 @@ class TestWritableCTE:
|
||||
# No write commands in the CTE body — must pass
|
||||
tool._validate_query("WITH cte AS (SELECT id FROM users) SELECT * FROM cte")
|
||||
|
||||
def test_cte_with_comment_column_not_false_positive(self):
    """A read-only tool must accept a CTE that selects a ``comment`` column."""
    # Here 'comment' is only a column identifier, never a SQL command.
    query = "WITH cte AS (SELECT comment FROM posts) SELECT * FROM cte"
    read_only_tool = _make_tool(allow_dml=False)
    read_only_tool._validate_query(query)
|
||||
|
||||
def test_cte_with_set_column_not_false_positive(self):
    """A read-only tool must accept a CTE selecting ``set``/``reset`` columns."""
    query = "WITH cte AS (SELECT set, reset FROM config) SELECT * FROM cte"
    read_only_tool = _make_tool(allow_dml=False)
    read_only_tool._validate_query(query)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EXPLAIN ANALYZE executes the underlying query
|
||||
@@ -461,7 +476,29 @@ class TestMultiStatementCommit:
|
||||
):
|
||||
tool.execute_sql("SELECT 1; SELECT 2")
|
||||
|
||||
mock_session.commit.assert_not_called()
|
||||
def test_writable_cte_triggers_commit(self):
    """With allow_dml=True, a ``WITH d AS (DELETE ...)`` query is committed."""
    # Fake result set handed back by the mocked session's execute().
    fake_result = MagicMock()
    fake_result.returns_rows = True
    fake_result.keys.return_value = ["id"]
    fake_result.fetchall.return_value = [(1,)]

    # Fake session, plus the factory object that the patched sessionmaker returns.
    fake_session = MagicMock()
    fake_session.execute.return_value = fake_result
    session_factory = MagicMock(return_value=fake_session)

    dml_tool = _make_tool(allow_dml=True)
    query = "WITH d AS (DELETE FROM users RETURNING *) SELECT * FROM d"

    with (
        patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"),
        patch(
            "crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
            return_value=session_factory,
        ),
    ):
        dml_tool.execute_sql(query)

    fake_session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user