diff --git a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py
index 3d6cc7359..561042461 100644
--- a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py
+++ b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py
@@ -1,5 +1,6 @@
 import logging
 import os
+import re
 from typing import Any
 
 
@@ -73,30 +74,83 @@ _CTE_WRITE_INDICATORS = {
 }
 
 
+# Matches the ``AS (`` that opens a CTE body, allowing any whitespace and an
+# optional PostgreSQL ``[NOT] MATERIALIZED`` hint between ``AS`` and ``(``.
+_AS_PAREN_RE = re.compile(r"\bAS(?:\s+(?:NOT\s+)?MATERIALIZED)?\s*\(", re.IGNORECASE)
+
+# Characters that delimit quoted SQL string literals and identifiers.
+_SQL_QUOTE_CHARS = "'\"`"
+
+
 def _detect_writable_cte(stmt: str) -> str | None:
     """Return the first write command inside a CTE body, or None.
 
     Instead of tokenizing the whole statement (which falsely matches column
     names like ``comment``), this walks through parenthesized CTE bodies and
     checks only the *first keyword after* an opening ``AS (`` for a write
-    command.
+    command. Uses a regex to handle any whitespace (spaces, tabs, newlines)
+    between ``AS`` and ``(`` as well as an optional ``[NOT] MATERIALIZED``.
     """
-    upper = stmt.upper()
-    search_start = 0
-    while True:
-        # Find the next "AS (" which introduces a CTE body.
-        idx = upper.find("AS (", search_start)
-        if idx == -1:
-            idx = upper.find("AS(", search_start)
-        if idx == -1:
-            break
-        # Jump past "AS (" or "AS("
-        paren_start = upper.index("(", idx) + 1
-        body = upper[paren_start:].lstrip()
-        first_word = body.split()[0].strip("()") if body.split() else ""
+    for m in _AS_PAREN_RE.finditer(stmt):
+        words = stmt[m.end() :].split(None, 1)
+        first_word = words[0].upper().strip("()") if words else ""
         if first_word in _CTE_WRITE_INDICATORS:
             return first_word
-        search_start = paren_start
     return None
+
+
+def _find_cte_body_end(stmt: str, start: int) -> int:
+    """Return the index just past the ``)`` that closes a CTE body.
+
+    *start* is the index of the first character after the opening ``(``.
+    Parentheses inside single-, double- or backtick-quoted text are
+    ignored so a literal such as ``')'`` cannot end the body early.
+    Returns ``len(stmt)`` if the body (or a quote inside it) is never
+    closed.
+    """
+    depth = 1
+    i = start
+    while i < len(stmt) and depth > 0:
+        ch = stmt[i]
+        if ch in _SQL_QUOTE_CHARS:
+            # Jump to the matching quote. A doubled quote ('') simply
+            # closes and immediately reopens the literal.
+            closing = stmt.find(ch, i + 1)
+            if closing == -1:
+                return len(stmt)
+            i = closing + 1
+            continue
+        if ch == "(":
+            depth += 1
+        elif ch == ")":
+            depth -= 1
+        i += 1
+    return i
+
+
+def _extract_main_query_after_cte(stmt: str) -> str | None:
+    """Extract the main (outer) query that follows all CTE definitions.
+
+    For ``WITH cte AS (SELECT 1) DELETE FROM users``, returns ``DELETE FROM users``.
+    Returns None if no main query is found after the last CTE body.
+
+    CTE definitions are walked in order: a body followed by a comma means
+    another definition follows, anything else starts the main query. An
+    ``AS (`` appearing later (e.g. inside a string literal of the main
+    query) is therefore never mistaken for a CTE body.
+    """
+    m = _AS_PAREN_RE.search(stmt)
+    while m is not None:
+        body_end = _find_cte_body_end(stmt, m.end())
+        remainder = stmt[body_end:].strip()
+        if not remainder.startswith(","):
+            return remainder or None
+        # A comma introduces the next "name AS (...)" definition.
+        m = _AS_PAREN_RE.search(stmt, body_end)
+        if m is None:
+            # Malformed input (comma with no CTE after it): hand back the
+            # leftover text so the caller can still inspect its command.
+            return remainder.lstrip(",").strip() or None
+    return None
 
 
@@ -249,6 +303,7 @@ class NL2SQLTool(BaseTool):
         # WITH starts a CTE. Read-only CTEs are fine; writable CTEs
         # (e.g. WITH d AS (DELETE …) SELECT …) must be blocked in read-only mode.
         if command == "WITH":
+            # Check for write commands inside CTE bodies.
             write_found = _detect_writable_cte(stmt)
             if write_found:
                 found = write_found
@@ -263,7 +318,24 @@ class NL2SQLTool(BaseTool):
                     "NL2SQLTool: executing writable CTE with '%s' because allow_dml=True.",
                     found,
                 )
-            # Both read-only and writable-but-permitted CTEs need no further checks.
+                return
+
+            # Check the main query after the CTE definitions.
+            main_query = _extract_main_query_after_cte(stmt)
+            if main_query:
+                main_cmd = main_query.split()[0].upper().rstrip(";")
+                if main_cmd in _WRITE_COMMANDS:
+                    if not self.allow_dml:
+                        raise ValueError(
+                            f"NL2SQLTool is configured in read-only mode and blocked a "
+                            f"'{main_cmd}' statement after a CTE. To allow write "
+                            f"operations set allow_dml=True or "
+                            f"CREWAI_NL2SQL_ALLOW_DML=true."
+                        )
+                    logger.warning(
+                        "NL2SQLTool: executing '%s' after CTE because allow_dml=True.",
+                        main_cmd,
+                    )
             return
 
         if command in _WRITE_COMMANDS:
@@ -360,11 +432,20 @@ class NL2SQLTool(BaseTool):
         # Check ALL statements so that e.g. "SELECT 1; DROP TABLE t" triggers a
         # commit when allow_dml=True, regardless of statement order.
         _stmts = [s.strip() for s in sql_query.split(";") if s.strip()]
-        is_write = any(
-            self._extract_command(s) in _WRITE_COMMANDS
-            or (self._extract_command(s) == "WITH" and _detect_writable_cte(s))
-            for s in _stmts
-        )
+
+        def _is_write_stmt(s: str) -> bool:
+            cmd = self._extract_command(s)
+            if cmd in _WRITE_COMMANDS:
+                return True
+            if cmd == "WITH":
+                if _detect_writable_cte(s):
+                    return True
+                main_q = _extract_main_query_after_cte(s)
+                if main_q:
+                    return main_q.split()[0].upper().rstrip(";") in _WRITE_COMMANDS
+            return False
+
+        is_write = any(_is_write_stmt(s) for s in _stmts)
 
         engine = create_engine(self.db_uri)
         Session = sessionmaker(bind=engine)  # noqa: N806
diff --git a/lib/crewai-tools/tests/tools/test_nl2sql_security.py b/lib/crewai-tools/tests/tools/test_nl2sql_security.py
index d5db86c91..0a2ea34a1 100644
--- a/lib/crewai-tools/tests/tools/test_nl2sql_security.py
+++ b/lib/crewai-tools/tests/tools/test_nl2sql_security.py
@@ -379,6 +379,54 @@ class TestWritableCTE:
 # ---------------------------------------------------------------------------
 
 
+    def test_cte_with_write_main_query_blocked(self):
+        """WITH cte AS (SELECT 1) DELETE FROM users — main query must be caught."""
+        tool = _make_tool(allow_dml=False)
+        with pytest.raises(ValueError, match="read-only mode"):
+            tool._validate_query(
+                "WITH cte AS (SELECT 1) DELETE FROM users"
+            )
+
+    def test_cte_with_write_main_query_allowed_with_dml(self):
+        """Main query write after CTE should pass when allow_dml=True."""
+        tool = _make_tool(allow_dml=True)
+        tool._validate_query(
+            "WITH cte AS (SELECT id FROM users) INSERT INTO archive SELECT * FROM cte"
+        )
+
+    def test_cte_with_newline_before_paren_blocked(self):
+        """AS followed by newline then ( should still detect writable CTE."""
+        tool = _make_tool(allow_dml=False)
+        with pytest.raises(ValueError, match="read-only mode"):
+            tool._validate_query(
+                "WITH cte AS\n(DELETE FROM users RETURNING *) SELECT * FROM cte"
+            )
+
+    def test_cte_with_tab_before_paren_blocked(self):
+        """AS followed by tab then ( should still detect writable CTE."""
+        tool = _make_tool(allow_dml=False)
+        with pytest.raises(ValueError, match="read-only mode"):
+            tool._validate_query(
+                "WITH cte AS\t(DELETE FROM users RETURNING *) SELECT * FROM cte"
+            )
+
+    def test_cte_body_with_paren_in_string_literal_blocked(self):
+        """A ')' inside a string literal must not end the CTE body early."""
+        tool = _make_tool(allow_dml=False)
+        with pytest.raises(ValueError, match="read-only mode"):
+            tool._validate_query(
+                "WITH cte AS (SELECT ')' AS x) DELETE FROM users"
+            )
+
+    def test_materialized_cte_with_write_main_query_blocked(self):
+        """AS MATERIALIZED ( must be recognised as the start of a CTE body."""
+        tool = _make_tool(allow_dml=False)
+        with pytest.raises(ValueError, match="read-only mode"):
+            tool._validate_query(
+                "WITH cte AS MATERIALIZED (SELECT 1) DELETE FROM users"
+            )
+
+
 class TestExplainAnalyze:
     def test_explain_analyze_delete_blocked_in_read_only(self):
         """EXPLAIN ANALYZE DELETE actually runs the delete — block it."""