mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
fix: harden NL2SQLTool — read-only by default, parameterized queries, query validation
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
291
lib/crewai-tools/tests/tools/test_nl2sql_security.py
Normal file
291
lib/crewai-tools/tests/tools/test_nl2sql_security.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""Security tests for NL2SQLTool.
|
||||
|
||||
Uses an in-memory SQLite database so no external service is needed.
|
||||
SQLite does not have information_schema, so we patch the schema-introspection
|
||||
helpers to avoid bootstrap failures and focus purely on the security logic.
|
||||
"""
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Skip the entire module if SQLAlchemy is not installed
|
||||
pytest.importorskip("sqlalchemy")
|
||||
|
||||
from sqlalchemy import create_engine, text # noqa: E402
|
||||
from sqlalchemy.orm import sessionmaker # noqa: E402
|
||||
|
||||
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool # noqa: E402
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SQLITE_URI = "sqlite://" # in-memory
|
||||
|
||||
|
||||
def _make_tool(allow_dml: bool = False, **kwargs) -> NL2SQLTool:
|
||||
"""Return a NL2SQLTool wired to an in-memory SQLite DB.
|
||||
|
||||
Schema-introspection is patched out so we can create the tool without a
|
||||
real PostgreSQL information_schema.
|
||||
"""
|
||||
with (
|
||||
patch.object(NL2SQLTool, "_fetch_available_tables", return_value=[]),
|
||||
patch.object(NL2SQLTool, "_fetch_all_available_columns", return_value=[]),
|
||||
):
|
||||
return NL2SQLTool(db_uri=SQLITE_URI, allow_dml=allow_dml, **kwargs)
|
||||
|
||||
|
||||
def _seed_db(uri: str) -> None:
|
||||
"""Create a tiny table in the target database for DML tests."""
|
||||
engine = create_engine(uri)
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY, name TEXT)"))
|
||||
conn.execute(text("INSERT INTO users VALUES (1, 'alice')"))
|
||||
conn.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read-only enforcement (allow_dml=False)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadOnlyMode:
|
||||
def test_select_allowed_by_default(self):
|
||||
tool = _make_tool()
|
||||
# SQLite supports SELECT without information_schema
|
||||
result = tool.execute_sql("SELECT 1 AS val")
|
||||
assert result == [{"val": 1}]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"stmt",
|
||||
[
|
||||
"INSERT INTO t VALUES (1)",
|
||||
"UPDATE t SET col = 1",
|
||||
"DELETE FROM t",
|
||||
"DROP TABLE t",
|
||||
"ALTER TABLE t ADD col TEXT",
|
||||
"CREATE TABLE t (id INTEGER)",
|
||||
"TRUNCATE TABLE t",
|
||||
"GRANT SELECT ON t TO user1",
|
||||
"REVOKE SELECT ON t FROM user1",
|
||||
"EXEC sp_something",
|
||||
"EXECUTE sp_something",
|
||||
"CALL proc()",
|
||||
],
|
||||
)
|
||||
def test_write_statements_blocked_by_default(self, stmt: str):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query(stmt)
|
||||
|
||||
def test_explain_allowed(self):
|
||||
tool = _make_tool()
|
||||
# Should not raise
|
||||
tool._validate_query("EXPLAIN SELECT 1")
|
||||
|
||||
def test_with_cte_allowed(self):
|
||||
tool = _make_tool()
|
||||
tool._validate_query("WITH cte AS (SELECT 1) SELECT * FROM cte")
|
||||
|
||||
def test_show_allowed(self):
|
||||
tool = _make_tool()
|
||||
tool._validate_query("SHOW TABLES")
|
||||
|
||||
def test_describe_allowed(self):
|
||||
tool = _make_tool()
|
||||
tool._validate_query("DESCRIBE users")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DML enabled (allow_dml=True)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDMLEnabled:
|
||||
def test_insert_allowed_when_dml_enabled(self):
|
||||
tool = _make_tool(allow_dml=True)
|
||||
# Should not raise
|
||||
tool._validate_query("INSERT INTO t VALUES (1)")
|
||||
|
||||
def test_delete_allowed_when_dml_enabled(self):
|
||||
tool = _make_tool(allow_dml=True)
|
||||
tool._validate_query("DELETE FROM t WHERE id = 1")
|
||||
|
||||
def test_drop_allowed_when_dml_enabled(self):
|
||||
tool = _make_tool(allow_dml=True)
|
||||
tool._validate_query("DROP TABLE t")
|
||||
|
||||
def test_dml_actually_persists(self):
|
||||
"""End-to-end: INSERT commits when allow_dml=True."""
|
||||
# Use a file-based SQLite so we can verify persistence across sessions
|
||||
import tempfile, os
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
db_path = f.name
|
||||
uri = f"sqlite:///{db_path}"
|
||||
try:
|
||||
tool = _make_tool(allow_dml=True)
|
||||
tool.db_uri = uri
|
||||
|
||||
engine = create_engine(uri)
|
||||
with engine.connect() as conn:
|
||||
conn.execute(text("CREATE TABLE items (id INTEGER PRIMARY KEY)"))
|
||||
conn.commit()
|
||||
|
||||
tool.execute_sql("INSERT INTO items VALUES (42)")
|
||||
|
||||
with engine.connect() as conn:
|
||||
rows = conn.execute(text("SELECT id FROM items")).fetchall()
|
||||
assert (42,) in rows
|
||||
finally:
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parameterised query — SQL injection prevention
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParameterisedQueries:
|
||||
def test_table_name_is_parameterised(self):
|
||||
"""_fetch_all_available_columns must not interpolate table_name into SQL."""
|
||||
tool = _make_tool()
|
||||
captured_calls = []
|
||||
|
||||
def recording_execute_sql(self_inner, sql_query, params=None):
|
||||
captured_calls.append((sql_query, params))
|
||||
return []
|
||||
|
||||
with patch.object(NL2SQLTool, "execute_sql", recording_execute_sql):
|
||||
tool._fetch_all_available_columns("users'; DROP TABLE users; --")
|
||||
|
||||
assert len(captured_calls) == 1
|
||||
sql, params = captured_calls[0]
|
||||
# The raw SQL must NOT contain the injected string
|
||||
assert "DROP" not in sql
|
||||
# The table name must be passed as a parameter
|
||||
assert params is not None
|
||||
assert params.get("table_name") == "users'; DROP TABLE users; --"
|
||||
# The SQL template must use the :param syntax
|
||||
assert ":table_name" in sql
|
||||
|
||||
def test_injection_string_not_in_sql_template(self):
|
||||
"""The f-string vulnerability is gone — table name never lands in the SQL."""
|
||||
tool = _make_tool()
|
||||
injection = "'; DROP TABLE users; --"
|
||||
captured = {}
|
||||
|
||||
def spy(self_inner, sql_query, params=None):
|
||||
captured["sql"] = sql_query
|
||||
captured["params"] = params
|
||||
return []
|
||||
|
||||
with patch.object(NL2SQLTool, "execute_sql", spy):
|
||||
tool._fetch_all_available_columns(injection)
|
||||
|
||||
assert injection not in captured["sql"]
|
||||
assert captured["params"]["table_name"] == injection
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session.commit() not called for read-only queries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNoCommitForReadOnly:
|
||||
def test_select_does_not_commit(self):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.returns_rows = True
|
||||
mock_result.keys.return_value = ["val"]
|
||||
mock_result.fetchall.return_value = [(1,)]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
mock_session_cls = MagicMock(return_value=mock_session)
|
||||
|
||||
with (
|
||||
patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"),
|
||||
patch(
|
||||
"crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
|
||||
return_value=mock_session_cls,
|
||||
),
|
||||
):
|
||||
tool.execute_sql("SELECT 1")
|
||||
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
def test_write_with_dml_enabled_does_commit(self):
|
||||
tool = _make_tool(allow_dml=True)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.returns_rows = False
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
mock_session_cls = MagicMock(return_value=mock_session)
|
||||
|
||||
with (
|
||||
patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"),
|
||||
patch(
|
||||
"crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
|
||||
return_value=mock_session_cls,
|
||||
),
|
||||
):
|
||||
tool.execute_sql("INSERT INTO t VALUES (1)")
|
||||
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment-variable escape hatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnvVarEscapeHatch:
|
||||
def test_env_var_enables_dml(self):
|
||||
with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "true"}):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
assert tool.allow_dml is True
|
||||
|
||||
def test_env_var_case_insensitive(self):
|
||||
with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "TRUE"}):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
assert tool.allow_dml is True
|
||||
|
||||
def test_env_var_absent_keeps_default(self):
|
||||
env = {k: v for k, v in os.environ.items() if k != "CREWAI_NL2SQL_ALLOW_DML"}
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
assert tool.allow_dml is False
|
||||
|
||||
def test_env_var_false_does_not_enable_dml(self):
|
||||
with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "false"}):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
assert tool.allow_dml is False
|
||||
|
||||
def test_dml_write_blocked_without_env_var(self):
|
||||
env = {k: v for k, v in os.environ.items() if k != "CREWAI_NL2SQL_ALLOW_DML"}
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._validate_query("DROP TABLE sensitive_data")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run() propagates ValueError from _validate_query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunValidation:
|
||||
def test_run_raises_on_blocked_query(self):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._run("DELETE FROM users")
|
||||
|
||||
def test_run_returns_results_for_select(self):
|
||||
tool = _make_tool(allow_dml=False)
|
||||
result = tool._run("SELECT 1 AS n")
|
||||
assert result == [{"n": 1}]
|
||||
Reference in New Issue
Block a user