Files
crewAI/lib/crewai-tools/tests/tools/test_nl2sql_security.py
Alex fec180ea5c fix: handle parenthesized EXPLAIN options syntax; remove unused _seed_db
_validate_statement now strips parenthesized options from EXPLAIN (e.g.
EXPLAIN (ANALYZE) DELETE, EXPLAIN (ANALYZE, VERBOSE) DELETE) before
checking whether ANALYZE/ANALYSE is present — closing the bypass where
the options-list form was silently allowed in read-only mode.

Adds three new tests:
  - EXPLAIN (ANALYZE) DELETE  → blocked
  - EXPLAIN (ANALYZE, VERBOSE) DELETE  → blocked
  - EXPLAIN (VERBOSE) SELECT  → allowed

Also removes the unused _seed_db helper from test_nl2sql_security.py.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 09:17:34 -07:00

494 lines
19 KiB
Python

"""Security tests for NL2SQLTool.
Uses an in-memory SQLite database so no external service is needed.
SQLite does not have information_schema, so we patch the schema-introspection
helpers to avoid bootstrap failures and focus purely on the security logic.
"""
import os
from unittest.mock import MagicMock, patch
import pytest
# Skip the entire module if SQLAlchemy is not installed
pytest.importorskip("sqlalchemy")
from sqlalchemy import create_engine, text # noqa: E402
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool # noqa: E402
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
SQLITE_URI = "sqlite://" # in-memory
def _make_tool(allow_dml: bool = False, **kwargs) -> NL2SQLTool:
    """Build an NL2SQLTool pointed at the in-memory SQLite URI.

    Both schema-introspection helpers are stubbed to return empty lists while
    the tool is being constructed, so no real PostgreSQL information_schema is
    required. Any extra keyword arguments are forwarded to the constructor.
    """
    stub_tables = patch.object(NL2SQLTool, "_fetch_available_tables", return_value=[])
    stub_columns = patch.object(
        NL2SQLTool, "_fetch_all_available_columns", return_value=[]
    )
    with stub_tables, stub_columns:
        tool = NL2SQLTool(db_uri=SQLITE_URI, allow_dml=allow_dml, **kwargs)
    return tool
# ---------------------------------------------------------------------------
# Read-only enforcement (allow_dml=False)
# ---------------------------------------------------------------------------
class TestReadOnlyMode:
    """Validation behaviour when the tool is built with DML disabled."""

    def test_select_allowed_by_default(self):
        """A bare SELECT runs end-to-end; SQLite needs no information_schema for it."""
        rows = _make_tool().execute_sql("SELECT 1 AS val")
        assert rows == [{"val": 1}]

    @pytest.mark.parametrize(
        "stmt",
        [
            "INSERT INTO t VALUES (1)",
            "UPDATE t SET col = 1",
            "DELETE FROM t",
            "DROP TABLE t",
            "ALTER TABLE t ADD col TEXT",
            "CREATE TABLE t (id INTEGER)",
            "TRUNCATE TABLE t",
            "GRANT SELECT ON t TO user1",
            "REVOKE SELECT ON t FROM user1",
            "EXEC sp_something",
            "EXECUTE sp_something",
            "CALL proc()",
        ],
    )
    def test_write_statements_blocked_by_default(self, stmt: str):
        """Every listed statement must be rejected with a read-only error."""
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._validate_query(stmt)

    def test_explain_allowed(self):
        """EXPLAIN of a SELECT validates without raising."""
        _make_tool()._validate_query("EXPLAIN SELECT 1")

    def test_read_only_cte_allowed(self):
        """A CTE whose body only selects validates without raising."""
        query = "WITH cte AS (SELECT 1) SELECT * FROM cte"
        _make_tool()._validate_query(query)

    def test_show_allowed(self):
        """SHOW validates without raising."""
        _make_tool()._validate_query("SHOW TABLES")

    def test_describe_allowed(self):
        """DESCRIBE validates without raising."""
        _make_tool()._validate_query("DESCRIBE users")
# ---------------------------------------------------------------------------
# DML enabled (allow_dml=True)
# ---------------------------------------------------------------------------
class TestDMLEnabled:
    """Write statements pass validation once the tool is built with allow_dml=True."""

    def test_insert_allowed_when_dml_enabled(self):
        tool = _make_tool(allow_dml=True)
        # Should not raise
        tool._validate_query("INSERT INTO t VALUES (1)")

    def test_delete_allowed_when_dml_enabled(self):
        tool = _make_tool(allow_dml=True)
        tool._validate_query("DELETE FROM t WHERE id = 1")

    def test_drop_allowed_when_dml_enabled(self):
        tool = _make_tool(allow_dml=True)
        tool._validate_query("DROP TABLE t")

    def test_dml_actually_persists(self):
        """End-to-end: INSERT commits when allow_dml=True."""
        # Only tempfile needs a local import; os is already imported at module level.
        import tempfile

        # Use a file-based SQLite so the inserted row can be read back through
        # a connection separate from the one the tool used.
        with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
            db_path = f.name
        uri = f"sqlite:///{db_path}"
        engine = None
        try:
            tool = _make_tool(allow_dml=True)
            tool.db_uri = uri
            engine = create_engine(uri)
            with engine.connect() as conn:
                conn.execute(text("CREATE TABLE items (id INTEGER PRIMARY KEY)"))
                conn.commit()
            tool.execute_sql("INSERT INTO items VALUES (42)")
            with engine.connect() as conn:
                rows = conn.execute(text("SELECT id FROM items")).fetchall()
            assert (42,) in rows
        finally:
            # Close this test's pooled connections before deleting the file;
            # an open handle makes os.unlink fail on Windows.
            if engine is not None:
                engine.dispose()
            os.unlink(db_path)
# ---------------------------------------------------------------------------
# Parameterised query — SQL injection prevention
# ---------------------------------------------------------------------------
class TestParameterisedQueries:
    """Column introspection must bind the table name rather than splice it into SQL."""

    def test_table_name_is_parameterised(self):
        """_fetch_all_available_columns must not interpolate table_name into SQL."""
        hostile_name = "users'; DROP TABLE users; --"
        tool = _make_tool()
        seen = []

        def recorder(_tool, sql_query, params=None):
            # Stand-in for execute_sql: remember what it was asked to run.
            seen.append((sql_query, params))
            return []

        with patch.object(NL2SQLTool, "execute_sql", recorder):
            tool._fetch_all_available_columns(hostile_name)

        assert len(seen) == 1
        statement, bound = seen[0]
        # The injected text must be absent from the SQL itself ...
        assert "DROP" not in statement
        # ... and must travel as a bound parameter instead.
        assert bound is not None
        assert bound.get("table_name") == hostile_name
        # The template has to reference the bind placeholder.
        assert ":table_name" in statement

    def test_injection_string_not_in_sql_template(self):
        """The f-string vulnerability is gone — table name never lands in the SQL."""
        payload = "'; DROP TABLE users; --"
        tool = _make_tool()
        last_call = {}

        def recorder(_tool, sql_query, params=None):
            # Keep only the most recent statement and its parameters.
            last_call["sql"] = sql_query
            last_call["params"] = params
            return []

        with patch.object(NL2SQLTool, "execute_sql", recorder):
            tool._fetch_all_available_columns(payload)

        assert payload not in last_call["sql"]
        assert last_call["params"]["table_name"] == payload
# ---------------------------------------------------------------------------
# session.commit() not called for read-only queries
# ---------------------------------------------------------------------------
class TestNoCommitForReadOnly:
    """Commit behaviour of execute_sql, observed through a mocked session."""

    def test_select_does_not_commit(self):
        """A SELECT in read-only mode must leave session.commit untouched."""
        tool = _make_tool(allow_dml=False)

        # Fake result that looks like a one-row SELECT.
        result = MagicMock()
        result.returns_rows = True
        result.keys.return_value = ["val"]
        result.fetchall.return_value = [(1,)]
        session = MagicMock()
        session.execute.return_value = result
        session_factory = MagicMock(return_value=session)

        engine_patch = patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine")
        maker_patch = patch(
            "crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
            return_value=session_factory,
        )
        with engine_patch, maker_patch:
            tool.execute_sql("SELECT 1")

        session.commit.assert_not_called()

    def test_write_with_dml_enabled_does_commit(self):
        """An INSERT with DML enabled must commit exactly once."""
        tool = _make_tool(allow_dml=True)

        # Fake result that returns no rows, as a write statement would.
        result = MagicMock()
        result.returns_rows = False
        session = MagicMock()
        session.execute.return_value = result
        session_factory = MagicMock(return_value=session)

        engine_patch = patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine")
        maker_patch = patch(
            "crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
            return_value=session_factory,
        )
        with engine_patch, maker_patch:
            tool.execute_sql("INSERT INTO t VALUES (1)")

        session.commit.assert_called_once()
# ---------------------------------------------------------------------------
# Environment-variable escape hatch
# ---------------------------------------------------------------------------
class TestEnvVarEscapeHatch:
    """How CREWAI_NL2SQL_ALLOW_DML interacts with the allow_dml constructor flag."""

    def test_env_var_enables_dml(self):
        """'true' in the environment turns DML on even though allow_dml=False was passed."""
        overrides = {"CREWAI_NL2SQL_ALLOW_DML": "true"}
        with patch.dict(os.environ, overrides):
            tool = _make_tool(allow_dml=False)
        assert tool.allow_dml is True

    def test_env_var_case_insensitive(self):
        """An upper-case 'TRUE' is honoured as well."""
        overrides = {"CREWAI_NL2SQL_ALLOW_DML": "TRUE"}
        with patch.dict(os.environ, overrides):
            tool = _make_tool(allow_dml=False)
        assert tool.allow_dml is True

    def test_env_var_absent_keeps_default(self):
        """With the variable removed, the constructor flag is left as passed."""
        scrubbed = dict(os.environ)
        scrubbed.pop("CREWAI_NL2SQL_ALLOW_DML", None)
        with patch.dict(os.environ, scrubbed, clear=True):
            tool = _make_tool(allow_dml=False)
        assert tool.allow_dml is False

    def test_env_var_false_does_not_enable_dml(self):
        """'false' in the environment does not switch DML on."""
        overrides = {"CREWAI_NL2SQL_ALLOW_DML": "false"}
        with patch.dict(os.environ, overrides):
            tool = _make_tool(allow_dml=False)
        assert tool.allow_dml is False

    def test_dml_write_blocked_without_env_var(self):
        """Without the variable, a DROP is still rejected as read-only."""
        scrubbed = dict(os.environ)
        scrubbed.pop("CREWAI_NL2SQL_ALLOW_DML", None)
        with patch.dict(os.environ, scrubbed, clear=True):
            tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            tool._validate_query("DROP TABLE sensitive_data")
# ---------------------------------------------------------------------------
# _run() propagates ValueError from _validate_query
# ---------------------------------------------------------------------------
class TestRunValidation:
    """_run must surface validation errors and still return rows for reads."""

    def test_run_raises_on_blocked_query(self):
        """A DELETE passed to _run raises the read-only ValueError."""
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._run("DELETE FROM users")

    def test_run_returns_results_for_select(self):
        """A SELECT passed to _run comes back as a list of row dicts."""
        rows = _make_tool(allow_dml=False)._run("SELECT 1 AS n")
        assert rows == [{"n": 1}]
# ---------------------------------------------------------------------------
# Multi-statement / semicolon injection prevention
# ---------------------------------------------------------------------------
class TestSemicolonInjection:
    """Handling of semicolon-separated statements by _validate_query."""

    def test_multi_statement_blocked_in_read_only_mode(self):
        """SELECT 1; DROP TABLE users must be rejected when allow_dml=False."""
        query = "SELECT 1; DROP TABLE users"
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="multi-statement"):
            read_only_tool._validate_query(query)

    def test_multi_statement_blocked_even_with_only_selects(self):
        """Two SELECT statements are still rejected in read-only mode."""
        query = "SELECT 1; SELECT 2"
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="multi-statement"):
            read_only_tool._validate_query(query)

    def test_trailing_semicolon_allowed_single_statement(self):
        """A single statement with a trailing semicolon should pass."""
        # Nothing follows the semicolon, so this counts as one statement.
        _make_tool(allow_dml=False)._validate_query("SELECT 1;")

    def test_multi_statement_allowed_when_dml_enabled(self):
        """Multiple statements are permitted when allow_dml=True."""
        query = "SELECT 1; INSERT INTO t VALUES (1)"
        # Must complete without raising.
        _make_tool(allow_dml=True)._validate_query(query)

    def test_multi_statement_write_still_blocked_individually(self):
        """Even with allow_dml=False, a single write statement is blocked."""
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._validate_query("DROP TABLE users")
# ---------------------------------------------------------------------------
# Writable CTEs (WITH … DELETE/INSERT/UPDATE)
# ---------------------------------------------------------------------------
class TestWritableCTE:
    """CTEs whose body writes data are treated as writes by _validate_query."""

    def test_writable_cte_delete_blocked_in_read_only(self):
        """WITH d AS (DELETE FROM users RETURNING *) SELECT * FROM d — blocked."""
        query = "WITH deleted AS (DELETE FROM users RETURNING *) SELECT * FROM deleted"
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._validate_query(query)

    def test_writable_cte_insert_blocked_in_read_only(self):
        """A CTE wrapping INSERT … RETURNING is rejected in read-only mode."""
        query = "WITH ins AS (INSERT INTO t VALUES (1) RETURNING id) SELECT * FROM ins"
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._validate_query(query)

    def test_writable_cte_update_blocked_in_read_only(self):
        """A CTE wrapping UPDATE … RETURNING is rejected in read-only mode."""
        query = "WITH upd AS (UPDATE t SET x=1 RETURNING id) SELECT * FROM upd"
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._validate_query(query)

    def test_writable_cte_allowed_when_dml_enabled(self):
        """The same writable CTE validates once DML is enabled."""
        query = "WITH deleted AS (DELETE FROM users RETURNING *) SELECT * FROM deleted"
        # Must complete without raising.
        _make_tool(allow_dml=True)._validate_query(query)

    def test_plain_read_only_cte_still_allowed(self):
        """A CTE body containing no write command validates in read-only mode."""
        query = "WITH cte AS (SELECT id FROM users) SELECT * FROM cte"
        _make_tool(allow_dml=False)._validate_query(query)
# ---------------------------------------------------------------------------
# EXPLAIN ANALYZE executes the underlying query
# ---------------------------------------------------------------------------
class TestExplainAnalyze:
    """EXPLAIN with ANALYZE/ANALYSE is validated against the statement it wraps."""

    def test_explain_analyze_delete_blocked_in_read_only(self):
        """EXPLAIN ANALYZE DELETE actually runs the delete — block it."""
        query = "EXPLAIN ANALYZE DELETE FROM users"
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._validate_query(query)

    def test_explain_analyse_delete_blocked_in_read_only(self):
        """British spelling ANALYSE is also caught."""
        query = "EXPLAIN ANALYSE DELETE FROM users"
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._validate_query(query)

    def test_explain_analyze_drop_blocked_in_read_only(self):
        """EXPLAIN ANALYZE wrapping a DROP is rejected in read-only mode."""
        query = "EXPLAIN ANALYZE DROP TABLE users"
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._validate_query(query)

    def test_explain_analyze_select_allowed_in_read_only(self):
        """EXPLAIN ANALYZE on a SELECT is safe — must be permitted."""
        query = "EXPLAIN ANALYZE SELECT * FROM users"
        _make_tool(allow_dml=False)._validate_query(query)

    def test_explain_without_analyze_allowed(self):
        """Plain EXPLAIN of a SELECT validates in read-only mode."""
        query = "EXPLAIN SELECT * FROM users"
        _make_tool(allow_dml=False)._validate_query(query)

    def test_explain_analyze_delete_allowed_when_dml_enabled(self):
        """With DML enabled, EXPLAIN ANALYZE DELETE validates."""
        query = "EXPLAIN ANALYZE DELETE FROM users"
        _make_tool(allow_dml=True)._validate_query(query)

    def test_explain_paren_analyze_delete_blocked_in_read_only(self):
        """EXPLAIN (ANALYZE) DELETE actually runs the delete — block it."""
        query = "EXPLAIN (ANALYZE) DELETE FROM users"
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._validate_query(query)

    def test_explain_paren_analyze_verbose_delete_blocked_in_read_only(self):
        """EXPLAIN (ANALYZE, VERBOSE) DELETE actually runs the delete — block it."""
        query = "EXPLAIN (ANALYZE, VERBOSE) DELETE FROM users"
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._validate_query(query)

    def test_explain_paren_verbose_select_allowed_in_read_only(self):
        """EXPLAIN (VERBOSE) SELECT is safe — no ANALYZE means no execution."""
        query = "EXPLAIN (VERBOSE) SELECT * FROM users"
        _make_tool(allow_dml=False)._validate_query(query)
# ---------------------------------------------------------------------------
# Multi-statement commit covers ALL statements (not just the first)
# ---------------------------------------------------------------------------
class TestMultiStatementCommit:
    """The commit decision must take every statement in a batch into account."""

    def test_select_then_insert_triggers_commit(self):
        """SELECT 1; INSERT … — commit must happen because INSERT is a write."""
        tool = _make_tool(allow_dml=True)

        # Fake result that returns no rows.
        result = MagicMock()
        result.returns_rows = False
        session = MagicMock()
        session.execute.return_value = result
        session_factory = MagicMock(return_value=session)

        engine_patch = patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine")
        maker_patch = patch(
            "crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
            return_value=session_factory,
        )
        with engine_patch, maker_patch:
            tool.execute_sql("SELECT 1; INSERT INTO t VALUES (1)")

        session.commit.assert_called_once()

    def test_select_only_multi_statement_does_not_commit(self):
        """Two SELECTs must not trigger a commit even when allow_dml=True."""
        tool = _make_tool(allow_dml=True)

        # Fake result that looks like a one-row SELECT.
        result = MagicMock()
        result.returns_rows = True
        result.keys.return_value = ["v"]
        result.fetchall.return_value = [(1,)]
        session = MagicMock()
        session.execute.return_value = result
        session_factory = MagicMock(return_value=session)

        engine_patch = patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine")
        maker_patch = patch(
            "crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
            return_value=session_factory,
        )
        with engine_patch, maker_patch:
            tool.execute_sql("SELECT 1; SELECT 2")

        session.commit.assert_not_called()
# ---------------------------------------------------------------------------
# Extended _WRITE_COMMANDS coverage
# ---------------------------------------------------------------------------
class TestExtendedWriteCommands:
    """Further statements that read-only mode must reject."""

    @pytest.mark.parametrize(
        "stmt",
        [
            "UPSERT INTO t VALUES (1)",
            "LOAD DATA INFILE 'f.csv' INTO TABLE t",
            "COPY t FROM '/tmp/f.csv'",
            "VACUUM ANALYZE t",
            "ANALYZE t",
            "ANALYSE t",
            "REINDEX TABLE t",
            "CLUSTER t USING idx",
            "REFRESH MATERIALIZED VIEW v",
            "COMMENT ON TABLE t IS 'desc'",
            "SET search_path = myschema",
            "RESET search_path",
        ],
    )
    def test_extended_write_commands_blocked_by_default(self, stmt: str):
        """Each listed statement raises the read-only ValueError when DML is off."""
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._validate_query(stmt)