diff --git a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py
index 2566b7825..436fb5471 100644
--- a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py
+++ b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py
@@ -35,6 +35,18 @@ _WRITE_COMMANDS = {
     "CALL",
     "MERGE",
     "REPLACE",
+    "UPSERT",
+    "LOAD",
+    "COPY",
+    "VACUUM",
+    "ANALYZE",
+    "ANALYSE",
+    "REINDEX",
+    "CLUSTER",
+    "REFRESH",
+    "COMMENT",
+    "SET",
+    "RESET",
 }
 
 
@@ -127,12 +139,40 @@ class NL2SQLTool(BaseTool):
     def _validate_query(self, sql_query: str) -> None:
         """Raise ValueError if *sql_query* is not permitted under the current config.
 
-        Parses the leading SQL command keyword and checks it against the
-        allowed set. When ``allow_dml=False`` (the default) only read
-        statements pass. When ``allow_dml=True`` all statements are allowed
-        but a warning is emitted for write operations.
+        Splits the query on semicolons and validates each statement
+        independently. When ``allow_dml=False`` (the default), multi-statement
+        queries are rejected outright to prevent ``SELECT 1; DROP TABLE users``
+        style bypasses; a single statement with a trailing semicolon still
+        passes. When ``allow_dml=True`` every statement is checked and a
+        warning is emitted for write operations.
+
+        The split is intentionally not quote-aware: a semicolon inside a string
+        literal or comment also counts as a separator, so read-only mode fails
+        closed rather than depending on one dialect's escaping rules.
+
+        Raises:
+            ValueError: If the query contains no statement, contains more than
+                one statement while ``allow_dml=False``, or contains a
+                statement rejected by ``_validate_statement``.
         """
-        command = self._extract_command(sql_query)
+        statements = [s.strip() for s in sql_query.split(";") if s.strip()]
+
+        if not statements:
+            raise ValueError("NL2SQLTool received an empty SQL query.")
+
+        if not self.allow_dml and len(statements) > 1:
+            raise ValueError(
+                "NL2SQLTool blocked a multi-statement query in read-only mode. "
+                "Only a single statement (optionally ending with a semicolon) "
+                "is permitted when allow_dml=False."
+            )
+
+        for stmt in statements:
+            self._validate_statement(stmt)
+
+    def _validate_statement(self, stmt: str) -> None:
+        """Validate a single SQL statement (no semicolons)."""
+        command = self._extract_command(stmt)
 
         if command in _WRITE_COMMANDS:
             if not self.allow_dml:
diff --git a/lib/crewai-tools/tests/tools/test_nl2sql_security.py b/lib/crewai-tools/tests/tools/test_nl2sql_security.py
index 91838c2ab..0c7a62112 100644
--- a/lib/crewai-tools/tests/tools/test_nl2sql_security.py
+++ b/lib/crewai-tools/tests/tools/test_nl2sql_security.py
@@ -288,3 +288,82 @@ class TestRunValidation:
         tool = _make_tool(allow_dml=False)
         result = tool._run("SELECT 1 AS n")
         assert result == [{"n": 1}]
+
+
+# ---------------------------------------------------------------------------
+# Multi-statement / semicolon injection prevention
+# ---------------------------------------------------------------------------
+
+
+class TestSemicolonInjection:
+    def test_multi_statement_blocked_in_read_only_mode(self):
+        """SELECT 1; DROP TABLE users must be rejected when allow_dml=False."""
+        tool = _make_tool(allow_dml=False)
+        with pytest.raises(ValueError, match="multi-statement"):
+            tool._validate_query("SELECT 1; DROP TABLE users")
+
+    def test_multi_statement_blocked_even_with_only_selects(self):
+        """Two SELECT statements are still rejected in read-only mode."""
+        tool = _make_tool(allow_dml=False)
+        with pytest.raises(ValueError, match="multi-statement"):
+            tool._validate_query("SELECT 1; SELECT 2")
+
+    def test_trailing_semicolon_allowed_single_statement(self):
+        """A single statement with a trailing semicolon should pass."""
+        tool = _make_tool(allow_dml=False)
+        # Should not raise — the part after the semicolon is empty
+        tool._validate_query("SELECT 1;")
+
+    def test_multi_statement_allowed_when_dml_enabled(self):
+        """Multiple statements are permitted when allow_dml=True."""
+        tool = _make_tool(allow_dml=True)
+        # Should not raise
+        tool._validate_query("SELECT 1; INSERT INTO t VALUES (1)")
+
+    def test_single_write_statement_still_blocked(self):
+        """A lone write statement is still rejected when allow_dml=False."""
+        tool = _make_tool(allow_dml=False)
+        with pytest.raises(ValueError, match="read-only mode"):
+            tool._validate_query("DROP TABLE users")
+
+    def test_semicolon_in_string_literal_blocked_in_read_only_mode(self):
+        """The split is not quote-aware, so read-only mode fails closed."""
+        tool = _make_tool(allow_dml=False)
+        with pytest.raises(ValueError, match="multi-statement"):
+            tool._validate_query("SELECT 'a;b'")
+
+    @pytest.mark.parametrize("query", ["", "   ", ";", " ; ; "])
+    def test_empty_query_rejected(self, query: str):
+        """A query that contains no statement at all is rejected."""
+        tool = _make_tool(allow_dml=False)
+        with pytest.raises(ValueError, match="empty SQL query"):
+            tool._validate_query(query)
+
+
+# ---------------------------------------------------------------------------
+# Extended _WRITE_COMMANDS coverage
+# ---------------------------------------------------------------------------
+
+
+class TestExtendedWriteCommands:
+    @pytest.mark.parametrize(
+        "stmt",
+        [
+            "UPSERT INTO t VALUES (1)",
+            "LOAD DATA INFILE 'f.csv' INTO TABLE t",
+            "COPY t FROM '/tmp/f.csv'",
+            "VACUUM ANALYZE t",
+            "ANALYZE t",
+            "ANALYSE t",
+            "REINDEX TABLE t",
+            "CLUSTER t USING idx",
+            "REFRESH MATERIALIZED VIEW v",
+            "COMMENT ON TABLE t IS 'desc'",
+            "SET search_path = myschema",
+            "RESET search_path",
+        ],
+    )
+    def test_extended_write_commands_blocked_by_default(self, stmt: str):
+        tool = _make_tool(allow_dml=False)
+        with pytest.raises(ValueError, match="read-only mode"):
+            tool._validate_query(stmt)