mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 14:52:36 +00:00
fix: expand _WRITE_COMMANDS and block multi-statement semicolon injection
- Add missing write commands: UPSERT, LOAD, COPY, VACUUM, ANALYZE, ANALYSE, REINDEX, CLUSTER, REFRESH, COMMENT, SET, RESET - _validate_query() now splits on ';' and validates each statement independently; multi-statement queries are rejected outright in read-only mode to prevent 'SELECT 1; DROP TABLE users' bypass - Extract single-statement logic into _validate_statement() helper - Add TestSemicolonInjection and TestExtendedWriteCommands test classes Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -35,6 +35,18 @@ _WRITE_COMMANDS = {
|
||||
"CALL",
|
||||
"MERGE",
|
||||
"REPLACE",
|
||||
"UPSERT",
|
||||
"LOAD",
|
||||
"COPY",
|
||||
"VACUUM",
|
||||
"ANALYZE",
|
||||
"ANALYSE",
|
||||
"REINDEX",
|
||||
"CLUSTER",
|
||||
"REFRESH",
|
||||
"COMMENT",
|
||||
"SET",
|
||||
"RESET",
|
||||
}
|
||||
|
||||
|
||||
@@ -127,12 +139,29 @@ class NL2SQLTool(BaseTool):
|
||||
def _validate_query(self, sql_query: str) -> None:
|
||||
"""Raise ValueError if *sql_query* is not permitted under the current config.
|
||||
|
||||
Parses the leading SQL command keyword and checks it against the
|
||||
allowed set. When ``allow_dml=False`` (the default) only read
|
||||
statements pass. When ``allow_dml=True`` all statements are allowed
|
||||
but a warning is emitted for write operations.
|
||||
Splits the query on semicolons and validates each statement
|
||||
independently. When ``allow_dml=False`` (the default), multi-statement
|
||||
queries are rejected outright to prevent ``SELECT 1; DROP TABLE users``
|
||||
style bypasses. When ``allow_dml=True`` every statement is checked and
|
||||
a warning is emitted for write operations.
|
||||
"""
|
||||
command = self._extract_command(sql_query)
|
||||
statements = [s.strip() for s in sql_query.split(";") if s.strip()]
|
||||
|
||||
if not statements:
|
||||
raise ValueError("NL2SQLTool received an empty SQL query.")
|
||||
|
||||
if not self.allow_dml and len(statements) > 1:
|
||||
raise ValueError(
|
||||
"NL2SQLTool blocked a multi-statement query in read-only mode. "
|
||||
"Semicolons are not permitted when allow_dml=False."
|
||||
)
|
||||
|
||||
for stmt in statements:
|
||||
self._validate_statement(stmt)
|
||||
|
||||
def _validate_statement(self, stmt: str) -> None:
|
||||
"""Validate a single SQL statement (no semicolons)."""
|
||||
command = self._extract_command(stmt)
|
||||
|
||||
if command in _WRITE_COMMANDS:
|
||||
if not self.allow_dml:
|
||||
|
||||
@@ -288,3 +288,69 @@ class TestRunValidation:
|
||||
tool = _make_tool(allow_dml=False)
|
||||
result = tool._run("SELECT 1 AS n")
|
||||
assert result == [{"n": 1}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multi-statement / semicolon injection prevention
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSemicolonInjection:
|
||||
def test_multi_statement_blocked_in_read_only_mode(self):
|
||||
"""SELECT 1; DROP TABLE users must be rejected when allow_dml=False."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="multi-statement"):
|
||||
tool._validate_query("SELECT 1; DROP TABLE users")
|
||||
|
||||
def test_multi_statement_blocked_even_with_only_selects(self):
|
||||
"""Two SELECT statements are still rejected in read-only mode."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="multi-statement"):
|
||||
tool._validate_query("SELECT 1; SELECT 2")
|
||||
|
||||
def test_trailing_semicolon_allowed_single_statement(self):
|
||||
"""A single statement with a trailing semicolon should pass."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
# Should not raise — the part after the semicolon is empty
|
||||
tool._validate_query("SELECT 1;")
|
||||
|
||||
def test_multi_statement_allowed_when_dml_enabled(self):
|
||||
"""Multiple statements are permitted when allow_dml=True."""
|
||||
tool = _make_tool(allow_dml=True)
|
||||
# Should not raise
|
||||
tool._validate_query("SELECT 1; INSERT INTO t VALUES (1)")
|
||||
|
||||
def test_multi_statement_write_still_blocked_individually(self):
|
||||
"""Even with allow_dml=False, a single write statement is blocked."""
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query("DROP TABLE users")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extended _WRITE_COMMANDS coverage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtendedWriteCommands:
|
||||
@pytest.mark.parametrize(
|
||||
"stmt",
|
||||
[
|
||||
"UPSERT INTO t VALUES (1)",
|
||||
"LOAD DATA INFILE 'f.csv' INTO TABLE t",
|
||||
"COPY t FROM '/tmp/f.csv'",
|
||||
"VACUUM ANALYZE t",
|
||||
"ANALYZE t",
|
||||
"ANALYSE t",
|
||||
"REINDEX TABLE t",
|
||||
"CLUSTER t USING idx",
|
||||
"REFRESH MATERIALIZED VIEW v",
|
||||
"COMMENT ON TABLE t IS 'desc'",
|
||||
"SET search_path = myschema",
|
||||
"RESET search_path",
|
||||
],
|
||||
)
|
||||
def test_extended_write_commands_blocked_by_default(self, stmt: str):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(stmt)
|
||||
|
||||
Reference in New Issue
Block a user