diff --git a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py index 54024ba71..3d6cc7359 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py @@ -58,6 +58,48 @@ _WRITE_COMMANDS = { } +# Subset of write commands that can realistically appear *inside* a CTE body. +# Narrower than _WRITE_COMMANDS to avoid false positives on identifiers like +# ``comment``, ``set``, or ``reset`` which are common column/table names. +_CTE_WRITE_INDICATORS = { + "INSERT", + "UPDATE", + "DELETE", + "DROP", + "ALTER", + "CREATE", + "TRUNCATE", + "MERGE", +} + + +def _detect_writable_cte(stmt: str) -> str | None: + """Return the first write command inside a CTE body, or None. + + Instead of tokenizing the whole statement (which falsely matches column + names like ``comment``), this walks through parenthesized CTE bodies and + checks only the *first keyword after* an opening ``AS (`` for a write + command. + """ + upper = stmt.upper() + search_start = 0 + while True: + # Find the next "AS (" which introduces a CTE body. + idx = upper.find("AS (", search_start) + if idx == -1: + idx = upper.find("AS(", search_start) + if idx == -1: + break + # Jump past "AS (" or "AS(" + paren_start = upper.index("(", idx) + 1 + body = upper[paren_start:].lstrip() + first_word = body.split()[0].strip("()") if body.split() else "" + if first_word in _CTE_WRITE_INDICATORS: + return first_word + search_start = paren_start + return None + + class NL2SQLToolInput(BaseModel): sql_query: str = Field( title="SQL Query", @@ -181,7 +223,7 @@ class NL2SQLTool(BaseTool): # Handles both space-separated ("EXPLAIN ANALYZE DELETE …") and # parenthesized ("EXPLAIN (ANALYZE) DELETE …", "EXPLAIN (ANALYZE, VERBOSE) DELETE …"). if command == "EXPLAIN": - rest = stmt.strip()[len("EXPLAIN"):].strip() + rest = stmt.strip()[len("EXPLAIN") :].strip() analyze_found = False if rest.startswith("("): @@ -193,13 +235,13 @@ class NL2SQLTool(BaseTool): opt.strip() in ("ANALYZE", "ANALYSE") for opt in options_str.split(",") ) - rest = rest[close + 1:].strip() + rest = rest[close + 1 :].strip() else: # Space-separated: EXPLAIN ANALYZE first_opt = rest.split()[0].upper().rstrip(";") if rest.split() else "" if first_opt in ("ANALYZE", "ANALYSE"): analyze_found = True - rest = rest[len(first_opt):].strip() + rest = rest[len(first_opt) :].strip() if analyze_found and rest: command = rest.split()[0].upper().rstrip(";") @@ -207,10 +249,9 @@ class NL2SQLTool(BaseTool): # WITH starts a CTE. Read-only CTEs are fine; writable CTEs # (e.g. WITH d AS (DELETE …) SELECT …) must be blocked in read-only mode. if command == "WITH": - tokens_upper = {t.upper().strip("();,") for t in stmt.split()} - write_found = tokens_upper & _WRITE_COMMANDS + write_found = _detect_writable_cte(stmt) if write_found: - found = next(iter(write_found)) + found = write_found if not self.allow_dml: raise ValueError( f"NL2SQLTool is configured in read-only mode and blocked a " @@ -319,7 +360,11 @@ class NL2SQLTool(BaseTool): # Check ALL statements so that e.g. "SELECT 1; DROP TABLE t" triggers a # commit when allow_dml=True, regardless of statement order. _stmts = [s.strip() for s in sql_query.split(";") if s.strip()] - is_write = any(self._extract_command(s) in _WRITE_COMMANDS for s in _stmts) + is_write = any( + self._extract_command(s) in _WRITE_COMMANDS + or (self._extract_command(s) == "WITH" and _detect_writable_cte(s)) + for s in _stmts + ) engine = create_engine(self.db_uri) Session = sessionmaker(bind=engine) # noqa: N806 diff --git a/lib/crewai-tools/tests/tools/test_nl2sql_security.py b/lib/crewai-tools/tests/tools/test_nl2sql_security.py index 614c82d6f..d5db86c91 100644 --- a/lib/crewai-tools/tests/tools/test_nl2sql_security.py +++ b/lib/crewai-tools/tests/tools/test_nl2sql_security.py @@ -358,6 +358,21 @@ class TestWritableCTE: # No write commands in the CTE body — must pass tool._validate_query("WITH cte AS (SELECT id FROM users) SELECT * FROM cte") + def test_cte_with_comment_column_not_false_positive(self): + """Column named 'comment' should NOT trigger writable CTE detection.""" + tool = _make_tool(allow_dml=False) + # 'comment' is a column name, not a SQL command + tool._validate_query( + "WITH cte AS (SELECT comment FROM posts) SELECT * FROM cte" + ) + + def test_cte_with_set_column_not_false_positive(self): + """Column named 'set' should NOT trigger writable CTE detection.""" + tool = _make_tool(allow_dml=False) + tool._validate_query( + "WITH cte AS (SELECT set, reset FROM config) SELECT * FROM cte" + ) + # --------------------------------------------------------------------------- # EXPLAIN ANALYZE executes the underlying query @@ -461,7 +476,29 @@ class TestMultiStatementCommit: ): tool.execute_sql("SELECT 1; SELECT 2") - mock_session.commit.assert_not_called() + def test_writable_cte_triggers_commit(self): + """WITH d AS (DELETE ...) must trigger commit when allow_dml=True.""" + tool = _make_tool(allow_dml=True) + + mock_session = MagicMock() + mock_result = MagicMock() + mock_result.returns_rows = True + mock_result.keys.return_value = ["id"] + mock_result.fetchall.return_value = [(1,)] + mock_session.execute.return_value = mock_result + mock_session_cls = MagicMock(return_value=mock_session) + + with ( + patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"), + patch( + "crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker", + return_value=mock_session_cls, + ), + ): + tool.execute_sql( + "WITH d AS (DELETE FROM users RETURNING *) SELECT * FROM d" + ) + mock_session.commit.assert_called_once() # ---------------------------------------------------------------------------