fix: expand _WRITE_COMMANDS and block multi-statement semicolon injection

- Add missing write commands: UPSERT, LOAD, COPY, VACUUM, ANALYZE,
  ANALYSE, REINDEX, CLUSTER, REFRESH, COMMENT, SET, RESET
- _validate_query() now splits on ';' and validates each statement
  independently; multi-statement queries are rejected outright in
  read-only mode to prevent 'SELECT 1; DROP TABLE users' bypass
- Extract single-statement logic into _validate_statement() helper
- Add TestSemicolonInjection and TestExtendedWriteCommands test classes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Alex
2026-04-06 23:29:25 -07:00
parent 84014abe03
commit c8bb781604
2 changed files with 100 additions and 5 deletions

View File

@@ -35,6 +35,18 @@ _WRITE_COMMANDS = {
"CALL",
"MERGE",
"REPLACE",
"UPSERT",
"LOAD",
"COPY",
"VACUUM",
"ANALYZE",
"ANALYSE",
"REINDEX",
"CLUSTER",
"REFRESH",
"COMMENT",
"SET",
"RESET",
}
@@ -127,12 +139,29 @@ class NL2SQLTool(BaseTool):
def _validate_query(self, sql_query: str) -> None:
"""Raise ValueError if *sql_query* is not permitted under the current config.
Parses the leading SQL command keyword and checks it against the
allowed set. When ``allow_dml=False`` (the default) only read
statements pass. When ``allow_dml=True`` all statements are allowed
but a warning is emitted for write operations.
Splits the query on semicolons and validates each statement
independently. When ``allow_dml=False`` (the default), multi-statement
queries are rejected outright to prevent ``SELECT 1; DROP TABLE users``
style bypasses. When ``allow_dml=True`` every statement is checked and
a warning is emitted for write operations.
"""
command = self._extract_command(sql_query)
statements = [s.strip() for s in sql_query.split(";") if s.strip()]
if not statements:
raise ValueError("NL2SQLTool received an empty SQL query.")
if not self.allow_dml and len(statements) > 1:
raise ValueError(
"NL2SQLTool blocked a multi-statement query in read-only mode. "
"Semicolons are not permitted when allow_dml=False."
)
for stmt in statements:
self._validate_statement(stmt)
def _validate_statement(self, stmt: str) -> None:
"""Validate a single SQL statement (no semicolons)."""
command = self._extract_command(stmt)
if command in _WRITE_COMMANDS:
if not self.allow_dml:

View File

@@ -288,3 +288,69 @@ class TestRunValidation:
tool = _make_tool(allow_dml=False)
result = tool._run("SELECT 1 AS n")
assert result == [{"n": 1}]
# ---------------------------------------------------------------------------
# Multi-statement / semicolon injection prevention
# ---------------------------------------------------------------------------
class TestSemicolonInjection:
def test_multi_statement_blocked_in_read_only_mode(self):
"""SELECT 1; DROP TABLE users must be rejected when allow_dml=False."""
tool = _make_tool(allow_dml=False)
with pytest.raises(ValueError, match="multi-statement"):
tool._validate_query("SELECT 1; DROP TABLE users")
def test_multi_statement_blocked_even_with_only_selects(self):
"""Two SELECT statements are still rejected in read-only mode."""
tool = _make_tool(allow_dml=False)
with pytest.raises(ValueError, match="multi-statement"):
tool._validate_query("SELECT 1; SELECT 2")
def test_trailing_semicolon_allowed_single_statement(self):
"""A single statement with a trailing semicolon should pass."""
tool = _make_tool(allow_dml=False)
# Should not raise — the part after the semicolon is empty
tool._validate_query("SELECT 1;")
def test_multi_statement_allowed_when_dml_enabled(self):
"""Multiple statements are permitted when allow_dml=True."""
tool = _make_tool(allow_dml=True)
# Should not raise
tool._validate_query("SELECT 1; INSERT INTO t VALUES (1)")
def test_multi_statement_write_still_blocked_individually(self):
"""Even with allow_dml=False, a single write statement is blocked."""
tool = _make_tool(allow_dml=False)
with pytest.raises(ValueError, match="read-only mode"):
tool._validate_query("DROP TABLE users")
# ---------------------------------------------------------------------------
# Extended _WRITE_COMMANDS coverage
# ---------------------------------------------------------------------------
class TestExtendedWriteCommands:
@pytest.mark.parametrize(
"stmt",
[
"UPSERT INTO t VALUES (1)",
"LOAD DATA INFILE 'f.csv' INTO TABLE t",
"COPY t FROM '/tmp/f.csv'",
"VACUUM ANALYZE t",
"ANALYZE t",
"ANALYSE t",
"REINDEX TABLE t",
"CLUSTER t USING idx",
"REFRESH MATERIALIZED VIEW v",
"COMMENT ON TABLE t IS 'desc'",
"SET search_path = myschema",
"RESET search_path",
],
)
def test_extended_write_commands_blocked_by_default(self, stmt: str):
tool = _make_tool(allow_dml=False)
with pytest.raises(ValueError, match="read-only mode"):
tool._validate_query(stmt)