diff --git a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py index 051f85ddb..c742c26b7 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py @@ -23,7 +23,9 @@ except ImportError: logger = logging.getLogger(__name__) # Commands allowed in read-only mode -_READ_ONLY_COMMANDS = {"SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN", "WITH"} +# NOTE: WITH is intentionally excluded — writable CTEs start with WITH, so the +# CTE body must be inspected separately (see _validate_statement). +_READ_ONLY_COMMANDS = {"SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"} # Commands that mutate state and are blocked by default _WRITE_COMMANDS = { @@ -67,11 +69,15 @@ class NL2SQLTool(BaseTool): """Tool that converts natural language to SQL and executes it against a database. By default the tool operates in **read-only mode**: only SELECT, SHOW, - DESCRIBE, EXPLAIN, and WITH (CTE) statements are permitted. Write + DESCRIBE, EXPLAIN, and read-only CTEs (WITH … SELECT) are permitted. Write operations (INSERT, UPDATE, DELETE, DROP, ALTER, CREATE, TRUNCATE, …) are blocked unless ``allow_dml=True`` is set explicitly or the environment variable ``CREWAI_NL2SQL_ALLOW_DML=true`` is present. + Writable CTEs (``WITH d AS (DELETE …) SELECT …``) and + ``EXPLAIN ANALYZE `` are treated as write operations and are + blocked in read-only mode. + The ``_fetch_all_available_columns`` helper uses parameterised queries so that table names coming from the database catalogue cannot be used as an injection vector. @@ -81,7 +87,8 @@ class NL2SQLTool(BaseTool): description: str = ( "Converts natural language to SQL queries and executes them against a " "database. Read-only by default — only SELECT/SHOW/DESCRIBE/EXPLAIN " - "queries are allowed unless the tool is configured with allow_dml=True." + "queries (and read-only CTEs) are allowed unless configured with " + "allow_dml=True." ) db_uri: str = Field( title="Database URI", @@ -169,6 +176,40 @@ class NL2SQLTool(BaseTool): """Validate a single SQL statement (no semicolons).""" command = self._extract_command(stmt) + # EXPLAIN ANALYZE / EXPLAIN ANALYSE actually *executes* the underlying + # query. Resolve the real command so write operations are caught. + if command == "EXPLAIN": + tokens = stmt.strip().lstrip("(").split() + if len(tokens) >= 2 and tokens[1].upper().rstrip(";") in ( + "ANALYZE", + "ANALYSE", + ): + # The statement being explained starts at the third token. + if len(tokens) >= 3: + command = tokens[2].upper().rstrip(";") + # else: bare "EXPLAIN ANALYZE" with no query — treat as read-only. + + # WITH starts a CTE. Read-only CTEs are fine; writable CTEs + # (e.g. WITH d AS (DELETE …) SELECT …) must be blocked in read-only mode. + if command == "WITH": + tokens_upper = {t.upper().strip("();,") for t in stmt.split()} + write_found = tokens_upper & _WRITE_COMMANDS + if write_found: + found = next(iter(write_found)) + if not self.allow_dml: + raise ValueError( + f"NL2SQLTool is configured in read-only mode and blocked a " + f"writable CTE containing a '{found}' statement. To allow " + f"write operations set allow_dml=True or " + f"CREWAI_NL2SQL_ALLOW_DML=true." + ) + logger.warning( + "NL2SQLTool: executing writable CTE with '%s' because allow_dml=True.", + found, + ) + # Both read-only and writable-but-permitted CTEs need no further checks. + return + if command in _WRITE_COMMANDS: if not self.allow_dml: raise ValueError( @@ -260,7 +301,10 @@ class NL2SQLTool(BaseTool): "`pip install crewai-tools[sqlalchemy]`" ) - is_write = self._extract_command(sql_query) in _WRITE_COMMANDS + # Check ALL statements so that e.g. "SELECT 1; DROP TABLE t" triggers a + # commit when allow_dml=True, regardless of statement order. + _stmts = [s.strip() for s in sql_query.split(";") if s.strip()] + is_write = any(self._extract_command(s) in _WRITE_COMMANDS for s in _stmts) engine = create_engine(self.db_uri) Session = sessionmaker(bind=engine) # noqa: N806 diff --git a/lib/crewai-tools/tests/tools/test_nl2sql_security.py b/lib/crewai-tools/tests/tools/test_nl2sql_security.py index 0c7a62112..60d4f2959 100644 --- a/lib/crewai-tools/tests/tools/test_nl2sql_security.py +++ b/lib/crewai-tools/tests/tools/test_nl2sql_security.py @@ -84,7 +84,7 @@ class TestReadOnlyMode: # Should not raise tool._validate_query("EXPLAIN SELECT 1") - def test_with_cte_allowed(self): + def test_read_only_cte_allowed(self): tool = _make_tool() tool._validate_query("WITH cte AS (SELECT 1) SELECT * FROM cte") @@ -327,6 +327,135 @@ class TestSemicolonInjection: tool._validate_query("DROP TABLE users") +# --------------------------------------------------------------------------- +# Writable CTEs (WITH … DELETE/INSERT/UPDATE) +# --------------------------------------------------------------------------- + + +class TestWritableCTE: + def test_writable_cte_delete_blocked_in_read_only(self): + """WITH d AS (DELETE FROM users RETURNING *) SELECT * FROM d — blocked.""" + tool = _make_tool(allow_dml=False) + with pytest.raises(ValueError, match="read-only mode"): + tool._validate_query( + "WITH deleted AS (DELETE FROM users RETURNING *) SELECT * FROM deleted" + ) + + def test_writable_cte_insert_blocked_in_read_only(self): + tool = _make_tool(allow_dml=False) + with pytest.raises(ValueError, match="read-only mode"): + tool._validate_query( + "WITH ins AS (INSERT INTO t VALUES (1) RETURNING id) SELECT * FROM ins" + ) + + def test_writable_cte_update_blocked_in_read_only(self): + tool = _make_tool(allow_dml=False) + with pytest.raises(ValueError, match="read-only mode"): + tool._validate_query( + "WITH upd AS (UPDATE t SET x=1 RETURNING id) SELECT * FROM upd" + ) + + def test_writable_cte_allowed_when_dml_enabled(self): + tool = _make_tool(allow_dml=True) + # Should not raise + tool._validate_query( + "WITH deleted AS (DELETE FROM users RETURNING *) SELECT * FROM deleted" + ) + + def test_plain_read_only_cte_still_allowed(self): + tool = _make_tool(allow_dml=False) + # No write commands in the CTE body — must pass + tool._validate_query("WITH cte AS (SELECT id FROM users) SELECT * FROM cte") + + +# --------------------------------------------------------------------------- +# EXPLAIN ANALYZE executes the underlying query +# --------------------------------------------------------------------------- + + +class TestExplainAnalyze: + def test_explain_analyze_delete_blocked_in_read_only(self): + """EXPLAIN ANALYZE DELETE actually runs the delete — block it.""" + tool = _make_tool(allow_dml=False) + with pytest.raises(ValueError, match="read-only mode"): + tool._validate_query("EXPLAIN ANALYZE DELETE FROM users") + + def test_explain_analyse_delete_blocked_in_read_only(self): + """British spelling ANALYSE is also caught.""" + tool = _make_tool(allow_dml=False) + with pytest.raises(ValueError, match="read-only mode"): + tool._validate_query("EXPLAIN ANALYSE DELETE FROM users") + + def test_explain_analyze_drop_blocked_in_read_only(self): + tool = _make_tool(allow_dml=False) + with pytest.raises(ValueError, match="read-only mode"): + tool._validate_query("EXPLAIN ANALYZE DROP TABLE users") + + def test_explain_analyze_select_allowed_in_read_only(self): + """EXPLAIN ANALYZE on a SELECT is safe — must be permitted.""" + tool = _make_tool(allow_dml=False) + tool._validate_query("EXPLAIN ANALYZE SELECT * FROM users") + + def test_explain_without_analyze_allowed(self): + tool = _make_tool(allow_dml=False) + tool._validate_query("EXPLAIN SELECT * FROM users") + + def test_explain_analyze_delete_allowed_when_dml_enabled(self): + tool = _make_tool(allow_dml=True) + tool._validate_query("EXPLAIN ANALYZE DELETE FROM users") + + +# --------------------------------------------------------------------------- +# Multi-statement commit covers ALL statements (not just the first) +# --------------------------------------------------------------------------- + + +class TestMultiStatementCommit: + def test_select_then_insert_triggers_commit(self): + """SELECT 1; INSERT … — commit must happen because INSERT is a write.""" + tool = _make_tool(allow_dml=True) + + mock_session = MagicMock() + mock_result = MagicMock() + mock_result.returns_rows = False + mock_session.execute.return_value = mock_result + mock_session_cls = MagicMock(return_value=mock_session) + + with ( + patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"), + patch( + "crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker", + return_value=mock_session_cls, + ), + ): + tool.execute_sql("SELECT 1; INSERT INTO t VALUES (1)") + + mock_session.commit.assert_called_once() + + def test_select_only_multi_statement_does_not_commit(self): + """Two SELECTs must not trigger a commit even when allow_dml=True.""" + tool = _make_tool(allow_dml=True) + + mock_session = MagicMock() + mock_result = MagicMock() + mock_result.returns_rows = True + mock_result.keys.return_value = ["v"] + mock_result.fetchall.return_value = [(1,)] + mock_session.execute.return_value = mock_result + mock_session_cls = MagicMock(return_value=mock_session) + + with ( + patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"), + patch( + "crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker", + return_value=mock_session_cls, + ), + ): + tool.execute_sql("SELECT 1; SELECT 2") + + mock_session.commit.assert_not_called() + + # --------------------------------------------------------------------------- # Extended _WRITE_COMMANDS coverage # ---------------------------------------------------------------------------