fix: harden NL2SQLTool — read-only by default, parameterized queries, query validation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Alex
2026-04-06 23:08:40 -07:00
parent 25e7ca03c4
commit 446d4e1267
2 changed files with 452 additions and 11 deletions

View File

@@ -1,7 +1,9 @@
import logging
import os
from typing import Any from typing import Any
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, model_validator
try: try:
@@ -12,6 +14,29 @@ try:
except ImportError: except ImportError:
SQLALCHEMY_AVAILABLE = False SQLALCHEMY_AVAILABLE = False
logger = logging.getLogger(__name__)

# Leading SQL keywords accepted while the tool is in read-only mode.
# A frozenset is used so this security allow-list cannot be mutated at runtime.
_READ_ONLY_COMMANDS = frozenset(
    {"SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN", "WITH"}
)

# Leading SQL keywords that mutate data, schema or privileges.  Statements
# starting with one of these are blocked by default.
_WRITE_COMMANDS = frozenset(
    {
        "INSERT",
        "UPDATE",
        "DELETE",
        "DROP",
        "ALTER",
        "CREATE",
        "TRUNCATE",
        "GRANT",
        "REVOKE",
        "EXEC",
        "EXECUTE",
        "CALL",
        "MERGE",
        "REPLACE",
    }
)
class NL2SQLToolInput(BaseModel): class NL2SQLToolInput(BaseModel):
sql_query: str = Field( sql_query: str = Field(
@@ -21,20 +46,65 @@ class NL2SQLToolInput(BaseModel):
class NL2SQLTool(BaseTool): class NL2SQLTool(BaseTool):
"""Tool that converts natural language to SQL and executes it against a database.
By default the tool operates in **read-only mode**: only SELECT, SHOW,
DESCRIBE, EXPLAIN, and WITH (CTE) statements are permitted. Write
operations (INSERT, UPDATE, DELETE, DROP, ALTER, CREATE, TRUNCATE, …) are
blocked unless ``allow_dml=True`` is set explicitly or the environment
variable ``CREWAI_NL2SQL_ALLOW_DML=true`` is present.
The ``_fetch_all_available_columns`` helper uses parameterised queries so
that table names coming from the database catalogue cannot be used as an
injection vector.
"""
# Tool identity.
name: str = "NL2SQLTool"
description: str = (
    "Converts natural language to SQL queries and executes them against a "
    "database. Read-only by default — only SELECT/SHOW/DESCRIBE/EXPLAIN "
    "queries are allowed unless the tool is configured with allow_dml=True."
)
# Connection string handed to SQLAlchemy's create_engine() by execute_sql().
db_uri: str = Field(
    title="Database URI",
    description="The URI of the database to connect to.",
)
# Opt-in switch for write statements; may also be forced on through the
# CREWAI_NL2SQL_ALLOW_DML environment variable (see _apply_env_override).
allow_dml: bool = Field(
    default=False,
    title="Allow DML",
    description=(
        "When False (default) only read statements are permitted. "
        "Set to True to allow INSERT/UPDATE/DELETE/DROP and other "
        "write operations."
    ),
)
# Schema snapshot assigned in model_post_init and quoted in the retry hint
# that _run() returns when a query fails.
tables: list[dict[str, Any]] = Field(default_factory=list)
columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict)
args_schema: type[BaseModel] = NL2SQLToolInput
@model_validator(mode="after")
def _apply_env_override(self) -> "NL2SQLTool":
    """Switch ``allow_dml`` on when ``CREWAI_NL2SQL_ALLOW_DML`` equals "true".

    The comparison ignores surrounding whitespace and letter case.  A warning
    is logged only when the environment variable actually changes the
    setting, i.e. when ``allow_dml`` was not already enabled.
    """
    raw_flag = os.environ.get("CREWAI_NL2SQL_ALLOW_DML", "")
    env_requests_dml = raw_flag.strip().lower() == "true"

    # Nothing to override: either the variable is not "true" or DML is already on.
    if not env_requests_dml or self.allow_dml:
        return self

    logger.warning(
        "NL2SQLTool: CREWAI_NL2SQL_ALLOW_DML env var is set — "
        "DML/DDL operations are enabled. Ensure this is intentional."
    )
    self.allow_dml = True
    return self
def model_post_init(self, __context: Any) -> None: def model_post_init(self, __context: Any) -> None:
if not SQLALCHEMY_AVAILABLE: if not SQLALCHEMY_AVAILABLE:
raise ImportError( raise ImportError(
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`" "sqlalchemy is not installed. Please install it with "
"`pip install crewai-tools[sqlalchemy]`"
)
if self.allow_dml:
logger.warning(
"NL2SQLTool: allow_dml=True — write operations (INSERT/UPDATE/"
"DELETE/DROP/…) are permitted. Use with caution."
) )
data: dict[str, list[dict[str, Any]] | str] = {} data: dict[str, list[dict[str, Any]] | str] = {}
@@ -50,42 +120,122 @@ class NL2SQLTool(BaseTool):
self.tables = tables self.tables = tables
self.columns = data self.columns = data
# ------------------------------------------------------------------
# Query validation
# ------------------------------------------------------------------
def _validate_query(self, sql_query: str) -> None:
    """Raise ``ValueError`` if *sql_query* is not permitted under the current config.

    The leading SQL keyword decides the outcome:

    * a write keyword is rejected unless ``allow_dml`` is true, in which case
      a warning is logged and the statement passes;
    * any other keyword passes unconditionally when ``allow_dml`` is true;
    * in read-only mode the keyword must be one of the read-only commands.

    In read-only mode two further checks close gaps that a first-keyword
    test alone leaves open:

    * more than one statement in a single string (``SELECT 1; DROP TABLE t``);
    * a write keyword anywhere inside a ``WITH`` or ``EXPLAIN`` statement
      (data-modifying CTEs, ``EXPLAIN ANALYZE DELETE ...``).

    Both extra checks work on the raw text and deliberately fail closed:
    a semicolon or a write keyword inside a string literal, comment or quoted
    identifier is rejected as well.  This is a best-effort filter, not a SQL
    parser; a database role without write privileges remains the reliable
    way to enforce read-only access.

    Raises
    ------
    ValueError
        If the statement is not allowed under the current configuration.
    """
    import re  # function-scope import: the module does not import ``re``

    command = self._extract_command(sql_query)

    if command in _WRITE_COMMANDS:
        if not self.allow_dml:
            raise ValueError(
                f"NL2SQLTool is configured in read-only mode and blocked a "
                f"'{command}' statement. To allow write operations set "
                f"allow_dml=True or CREWAI_NL2SQL_ALLOW_DML=true."
            )
        logger.warning(
            "NL2SQLTool: executing write statement '%s' because allow_dml=True.",
            command,
        )
        return

    # With DML enabled every remaining statement is allowed as-is.
    if self.allow_dml:
        return

    if command not in _READ_ONLY_COMMANDS:
        # Unknown command — block by default unless DML is explicitly enabled
        raise ValueError(
            f"NL2SQLTool blocked an unrecognised SQL command '{command}'. "
            f"Only {sorted(_READ_ONLY_COMMANDS)} are allowed in read-only "
            f"mode."
        )

    # A semicolon that is not purely trailing means a second statement may
    # follow the harmless-looking first one.
    if ";" in sql_query.rstrip("; \t\r\n"):
        raise ValueError(
            "NL2SQLTool is configured in read-only mode and blocked a query "
            "that contains more than one statement. To allow write "
            "operations set allow_dml=True or CREWAI_NL2SQL_ALLOW_DML=true."
        )

    if command in {"WITH", "EXPLAIN"}:
        # REPLACE followed by "(" is the scalar string function, not the
        # write statement, so only that one keyword gets a call exemption.
        plain_keywords = "|".join(sorted(_WRITE_COMMANDS - {"REPLACE"}))
        embedded_write = re.search(
            rf"\b(?:(?:{plain_keywords})\b|REPLACE\b(?!\s*\())",
            sql_query,
            flags=re.IGNORECASE,
        )
        if embedded_write:
            keyword = embedded_write.group(0).upper()
            raise ValueError(
                f"NL2SQLTool is configured in read-only mode and blocked a "
                f"'{command}' statement that contains the write keyword "
                f"'{keyword}'. To allow write operations set "
                f"allow_dml=True or CREWAI_NL2SQL_ALLOW_DML=true."
            )
@staticmethod
def _extract_command(sql_query: str) -> str:
"""Return the uppercased first keyword of *sql_query*."""
stripped = sql_query.strip().lstrip("(")
first_token = stripped.split()[0] if stripped.split() else ""
return first_token.upper().rstrip(";")
# ------------------------------------------------------------------
# Schema introspection helpers
# ------------------------------------------------------------------
def _fetch_available_tables(self) -> list[dict[str, Any]] | str: def _fetch_available_tables(self) -> list[dict[str, Any]] | str:
return self.execute_sql( return self.execute_sql(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';" "SELECT table_name FROM information_schema.tables "
"WHERE table_schema = 'public';"
) )
def _fetch_all_available_columns( def _fetch_all_available_columns(
self, table_name: str self, table_name: str
) -> list[dict[str, Any]] | str: ) -> list[dict[str, Any]] | str:
"""Fetch columns for *table_name* using a parameterised query.
The table name is bound via SQLAlchemy's ``:param`` syntax to prevent
SQL injection from catalogue values.
"""
return self.execute_sql( return self.execute_sql(
f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';" # noqa: S608 "SELECT column_name, data_type FROM information_schema.columns "
"WHERE table_name = :table_name",
params={"table_name": table_name},
) )
# ------------------------------------------------------------------
# Core execution
# ------------------------------------------------------------------
def _run(self, sql_query: str) -> list[dict[str, Any]] | str: def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
try: try:
self._validate_query(sql_query)
data = self.execute_sql(sql_query) data = self.execute_sql(sql_query)
except ValueError:
raise
except Exception as exc: except Exception as exc:
data = ( data = (
f"Based on these tables {self.tables} and columns {self.columns}, " f"Based on these tables {self.tables} and columns {self.columns}, "
"you can create SQL queries to retrieve data from the database." "you can create SQL queries to retrieve data from the database. "
f"Get the original request {sql_query} and the error {exc} and create the correct SQL query." f"Get the original request {sql_query} and the error {exc} and "
"create the correct SQL query."
) )
return data return data
def execute_sql(self, sql_query: str) -> list[dict[str, Any]] | str: def execute_sql(
self,
sql_query: str,
params: dict[str, Any] | None = None,
) -> list[dict[str, Any]] | str:
"""Execute *sql_query* and return the results as a list of dicts.
Parameters
----------
sql_query:
The SQL statement to run.
params:
Optional mapping of bind parameters (e.g. ``{"table_name": "users"}``).
"""
if not SQLALCHEMY_AVAILABLE: if not SQLALCHEMY_AVAILABLE:
raise ImportError( raise ImportError(
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`" "sqlalchemy is not installed. Please install it with "
"`pip install crewai-tools[sqlalchemy]`"
) )
is_write = self._extract_command(sql_query) in _WRITE_COMMANDS
engine = create_engine(self.db_uri) engine = create_engine(self.db_uri)
Session = sessionmaker(bind=engine) # noqa: N806 Session = sessionmaker(bind=engine) # noqa: N806
session = Session() session = Session()
try: try:
result = session.execute(text(sql_query)) result = session.execute(text(sql_query), params or {})
session.commit()
# Only commit when the operation actually mutates state
if self.allow_dml and is_write:
session.commit()
if result.returns_rows: # type: ignore[attr-defined] if result.returns_rows: # type: ignore[attr-defined]
columns = result.keys() columns = result.keys()

View File

@@ -0,0 +1,291 @@
"""Security tests for NL2SQLTool.
Uses an in-memory SQLite database so no external service is needed.
SQLite does not have information_schema, so we patch the schema-introspection
helpers to avoid bootstrap failures and focus purely on the security logic.
"""
import os
from unittest.mock import MagicMock, patch
import pytest
# Skip the entire module if SQLAlchemy is not installed
pytest.importorskip("sqlalchemy")
from sqlalchemy import create_engine, text # noqa: E402
from sqlalchemy.orm import sessionmaker # noqa: E402
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool # noqa: E402
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Database URI passed to NL2SQLTool by _make_tool(); no file path is given.
SQLITE_URI = "sqlite://" # in-memory
def _make_tool(allow_dml: bool = False, **kwargs) -> NL2SQLTool:
    """Build an NL2SQLTool pointed at ``SQLITE_URI``.

    Both schema-introspection helpers are stubbed to return empty lists for
    the duration of construction, so no ``information_schema`` is required.
    """
    tables_stub = patch.object(
        NL2SQLTool, "_fetch_available_tables", return_value=[]
    )
    columns_stub = patch.object(
        NL2SQLTool, "_fetch_all_available_columns", return_value=[]
    )
    with tables_stub, columns_stub:
        tool = NL2SQLTool(db_uri=SQLITE_URI, allow_dml=allow_dml, **kwargs)
    return tool
def _seed_db(uri: str) -> None:
    """Create a tiny ``users`` table at *uri* and insert one row for DML tests.

    NOTE(review): with the in-memory ``SQLITE_URI`` each engine is expected
    to get its own private database, so rows seeded here would not be visible
    to the tool's engine; use a file-backed URI when the tool must see them.
    """
    engine = create_engine(uri)
    try:
        # engine.begin() commits on success and rolls back on error.
        with engine.begin() as conn:
            conn.execute(
                text(
                    "CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY, name TEXT)"
                )
            )
            conn.execute(text("INSERT INTO users VALUES (1, 'alice')"))
    finally:
        # Release pooled connections so a file-backed database is not left open.
        engine.dispose()
# ---------------------------------------------------------------------------
# Read-only enforcement (allow_dml=False)
# ---------------------------------------------------------------------------
class TestReadOnlyMode:
    """Validation behaviour with the default ``allow_dml=False``."""

    def test_select_allowed_by_default(self):
        tool = _make_tool()
        # The validator itself must accept the statement ...
        tool._validate_query("SELECT 1 AS val")
        # ... and SQLite supports SELECT without information_schema
        result = tool.execute_sql("SELECT 1 AS val")
        assert result == [{"val": 1}]

    @pytest.mark.parametrize(
        "stmt",
        [
            "INSERT INTO t VALUES (1)",
            "UPDATE t SET col = 1",
            "DELETE FROM t",
            "DROP TABLE t",
            "ALTER TABLE t ADD col TEXT",
            "CREATE TABLE t (id INTEGER)",
            "TRUNCATE TABLE t",
            "GRANT SELECT ON t TO user1",
            "REVOKE SELECT ON t FROM user1",
            "EXEC sp_something",
            "EXECUTE sp_something",
            "CALL proc()",
            "MERGE INTO t USING s ON t.id = s.id WHEN MATCHED THEN DELETE",
            "REPLACE INTO t VALUES (1)",
            # Keyword matching ignores case and leading whitespace.
            "delete from t",
            "   DROP TABLE t",
        ],
    )
    def test_write_statements_blocked_by_default(self, stmt: str):
        tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            tool._validate_query(stmt)

    @pytest.mark.parametrize(
        "stmt",
        [
            "VACUUM",
            "PRAGMA table_info(users)",
            "",
        ],
    )
    def test_unknown_command_blocked_by_default(self, stmt: str):
        """Statements that are neither known reads nor known writes are rejected."""
        tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="unrecognised SQL command"):
            tool._validate_query(stmt)

    def test_explain_allowed(self):
        tool = _make_tool()
        # Should not raise
        tool._validate_query("EXPLAIN SELECT 1")

    def test_with_cte_allowed(self):
        tool = _make_tool()
        tool._validate_query("WITH cte AS (SELECT 1) SELECT * FROM cte")

    def test_show_allowed(self):
        tool = _make_tool()
        tool._validate_query("SHOW TABLES")

    def test_describe_allowed(self):
        tool = _make_tool()
        tool._validate_query("DESCRIBE users")

    def test_desc_allowed(self):
        tool = _make_tool()
        tool._validate_query("DESC users")
# ---------------------------------------------------------------------------
# DML enabled (allow_dml=True)
# ---------------------------------------------------------------------------
class TestDMLEnabled:
    """Validation and persistence behaviour with ``allow_dml=True``."""

    def test_insert_allowed_when_dml_enabled(self):
        tool = _make_tool(allow_dml=True)
        # Should not raise
        tool._validate_query("INSERT INTO t VALUES (1)")

    def test_delete_allowed_when_dml_enabled(self):
        tool = _make_tool(allow_dml=True)
        tool._validate_query("DELETE FROM t WHERE id = 1")

    def test_drop_allowed_when_dml_enabled(self):
        tool = _make_tool(allow_dml=True)
        tool._validate_query("DROP TABLE t")

    def test_dml_actually_persists(self, tmp_path):
        """End-to-end: INSERT commits when allow_dml=True."""
        # File-based SQLite so persistence can be verified across sessions;
        # pytest removes tmp_path itself, so no manual unlink is needed.
        db_path = tmp_path / "items.db"
        uri = f"sqlite:///{db_path.as_posix()}"

        tool = _make_tool(allow_dml=True)
        tool.db_uri = uri

        engine = create_engine(uri)
        try:
            with engine.connect() as conn:
                conn.execute(text("CREATE TABLE items (id INTEGER PRIMARY KEY)"))
                conn.commit()

            tool.execute_sql("INSERT INTO items VALUES (42)")

            with engine.connect() as conn:
                rows = conn.execute(text("SELECT id FROM items")).fetchall()
        finally:
            # Drop pooled connections so the database file is not left open.
            engine.dispose()

        assert (42,) in rows
# ---------------------------------------------------------------------------
# Parameterised query — SQL injection prevention
# ---------------------------------------------------------------------------
class TestParameterisedQueries:
    """The column lookup must bind the table name instead of interpolating it."""

    @staticmethod
    def _capture_column_lookup(table_name):
        """Run the column lookup with execute_sql replaced; return its calls."""
        tool = _make_tool()
        recorded = []

        def fake_execute_sql(self_inner, sql_query, params=None):
            recorded.append((sql_query, params))
            return []

        with patch.object(NL2SQLTool, "execute_sql", fake_execute_sql):
            tool._fetch_all_available_columns(table_name)
        return recorded

    def test_table_name_is_parameterised(self):
        """_fetch_all_available_columns must not interpolate table_name into SQL."""
        hostile_name = "users'; DROP TABLE users; --"
        recorded = self._capture_column_lookup(hostile_name)

        assert len(recorded) == 1
        sql, params = recorded[0]
        # The raw SQL must NOT contain the injected string
        assert "DROP" not in sql
        # The table name must be passed as a parameter
        assert params is not None
        assert params.get("table_name") == hostile_name
        # The SQL template must use the :param syntax
        assert ":table_name" in sql

    def test_injection_string_not_in_sql_template(self):
        """The f-string vulnerability is gone — table name never lands in the SQL."""
        injection = "'; DROP TABLE users; --"
        sql, params = self._capture_column_lookup(injection)[-1]

        assert injection not in sql
        assert params["table_name"] == injection
# ---------------------------------------------------------------------------
# session.commit() not called for read-only queries
# ---------------------------------------------------------------------------
class TestNoCommitForReadOnly:
    """execute_sql commits only for write statements when DML is enabled."""

    @staticmethod
    def _execute_with_fake_session(tool, sql, result):
        """Run ``tool.execute_sql(sql)`` against a mocked session; return the session."""
        fake_session = MagicMock()
        fake_session.execute.return_value = result
        session_factory = MagicMock(return_value=fake_session)

        with (
            patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"),
            patch(
                "crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
                return_value=session_factory,
            ),
        ):
            tool.execute_sql(sql)
        return fake_session

    def test_select_does_not_commit(self):
        row_result = MagicMock()
        row_result.returns_rows = True
        row_result.keys.return_value = ["val"]
        row_result.fetchall.return_value = [(1,)]

        session = self._execute_with_fake_session(
            _make_tool(allow_dml=False), "SELECT 1", row_result
        )

        session.commit.assert_not_called()

    def test_write_with_dml_enabled_does_commit(self):
        no_rows_result = MagicMock()
        no_rows_result.returns_rows = False

        session = self._execute_with_fake_session(
            _make_tool(allow_dml=True), "INSERT INTO t VALUES (1)", no_rows_result
        )

        session.commit.assert_called_once()
# ---------------------------------------------------------------------------
# Environment-variable escape hatch
# ---------------------------------------------------------------------------
class TestEnvVarEscapeHatch:
    """CREWAI_NL2SQL_ALLOW_DML can switch DML on; anything but "true" cannot."""

    @staticmethod
    def _environ_without_flag():
        """Return a copy of the current environment with the flag removed."""
        cleaned = dict(os.environ)
        cleaned.pop("CREWAI_NL2SQL_ALLOW_DML", None)
        return cleaned

    def test_env_var_enables_dml(self):
        with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "true"}):
            tool = _make_tool(allow_dml=False)
        assert tool.allow_dml is True

    def test_env_var_case_insensitive(self):
        with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "TRUE"}):
            tool = _make_tool(allow_dml=False)
        assert tool.allow_dml is True

    def test_env_var_absent_keeps_default(self):
        with patch.dict(os.environ, self._environ_without_flag(), clear=True):
            tool = _make_tool(allow_dml=False)
        assert tool.allow_dml is False

    def test_env_var_false_does_not_enable_dml(self):
        with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "false"}):
            tool = _make_tool(allow_dml=False)
        assert tool.allow_dml is False

    def test_dml_write_blocked_without_env_var(self):
        with patch.dict(os.environ, self._environ_without_flag(), clear=True):
            tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            tool._validate_query("DROP TABLE sensitive_data")
# ---------------------------------------------------------------------------
# _run() propagates ValueError from _validate_query
# ---------------------------------------------------------------------------
class TestRunValidation:
    """_run() surfaces validation failures and still executes permitted reads."""

    def test_run_raises_on_blocked_query(self):
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._run("DELETE FROM users")

    def test_run_returns_results_for_select(self):
        rows = _make_tool(allow_dml=False)._run("SELECT 1 AS n")
        assert rows == [{"n": 1}]