mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
fix: harden NL2SQLTool — read-only by default, parameterized queries, query validation
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,9 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -12,6 +14,29 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
SQLALCHEMY_AVAILABLE = False
|
SQLALCHEMY_AVAILABLE = False
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Leading keywords accepted in read-only mode.
# Frozen so the allow-list cannot be widened by accident at runtime.
_READ_ONLY_COMMANDS: frozenset[str] = frozenset(
    {"SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN", "WITH"}
)

# Leading keywords that mutate state; blocked unless DML is explicitly enabled.
_WRITE_COMMANDS: frozenset[str] = frozenset(
    {
        "INSERT",
        "UPDATE",
        "DELETE",
        "DROP",
        "ALTER",
        "CREATE",
        "TRUNCATE",
        "GRANT",
        "REVOKE",
        "EXEC",
        "EXECUTE",
        "CALL",
        "MERGE",
        "REPLACE",
    }
)
|
||||||
|
|
||||||
|
|
||||||
class NL2SQLToolInput(BaseModel):
|
class NL2SQLToolInput(BaseModel):
|
||||||
sql_query: str = Field(
|
sql_query: str = Field(
|
||||||
@@ -21,20 +46,65 @@ class NL2SQLToolInput(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class NL2SQLTool(BaseTool):
|
class NL2SQLTool(BaseTool):
|
||||||
|
"""Tool that converts natural language to SQL and executes it against a database.
|
||||||
|
|
||||||
|
By default the tool operates in **read-only mode**: only SELECT, SHOW,
|
||||||
|
DESCRIBE, EXPLAIN, and WITH (CTE) statements are permitted. Write
|
||||||
|
operations (INSERT, UPDATE, DELETE, DROP, ALTER, CREATE, TRUNCATE, …) are
|
||||||
|
blocked unless ``allow_dml=True`` is set explicitly or the environment
|
||||||
|
variable ``CREWAI_NL2SQL_ALLOW_DML=true`` is present.
|
||||||
|
|
||||||
|
The ``_fetch_all_available_columns`` helper uses parameterised queries so
|
||||||
|
that table names coming from the database catalogue cannot be used as an
|
||||||
|
injection vector.
|
||||||
|
"""
|
||||||
|
|
||||||
name: str = "NL2SQLTool"
|
name: str = "NL2SQLTool"
|
||||||
description: str = "Converts natural language to SQL queries and executes them."
|
description: str = (
|
||||||
|
"Converts natural language to SQL queries and executes them against a "
|
||||||
|
"database. Read-only by default — only SELECT/SHOW/DESCRIBE/EXPLAIN "
|
||||||
|
"queries are allowed unless the tool is configured with allow_dml=True."
|
||||||
|
)
|
||||||
db_uri: str = Field(
|
db_uri: str = Field(
|
||||||
title="Database URI",
|
title="Database URI",
|
||||||
description="The URI of the database to connect to.",
|
description="The URI of the database to connect to.",
|
||||||
)
|
)
|
||||||
|
allow_dml: bool = Field(
|
||||||
|
default=False,
|
||||||
|
title="Allow DML",
|
||||||
|
description=(
|
||||||
|
"When False (default) only read statements are permitted. "
|
||||||
|
"Set to True to allow INSERT/UPDATE/DELETE/DROP and other "
|
||||||
|
"write operations."
|
||||||
|
),
|
||||||
|
)
|
||||||
tables: list[dict[str, Any]] = Field(default_factory=list)
|
tables: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict)
|
columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict)
|
||||||
args_schema: type[BaseModel] = NL2SQLToolInput
|
args_schema: type[BaseModel] = NL2SQLToolInput
|
||||||
|
|
||||||
|
@model_validator(mode="after")
def _apply_env_override(self) -> "NL2SQLTool":
    """Switch ``allow_dml`` on when ``CREWAI_NL2SQL_ALLOW_DML`` equals "true".

    The environment value is compared after trimming whitespace and
    lower-casing; any other value leaves the field untouched. A warning is
    logged when the variable flips a tool that was configured read-only.
    The check runs as an "after" model validator.
    """
    env_value = os.environ.get("CREWAI_NL2SQL_ALLOW_DML", "")
    env_enabled = env_value.strip().lower() == "true"

    # Nothing to do unless the variable is set AND the tool is still read-only.
    if not env_enabled or self.allow_dml:
        return self

    logger.warning(
        "NL2SQLTool: CREWAI_NL2SQL_ALLOW_DML env var is set — "
        "DML/DDL operations are enabled. Ensure this is intentional."
    )
    self.allow_dml = True
    return self
|
||||||
|
|
||||||
def model_post_init(self, __context: Any) -> None:
|
def model_post_init(self, __context: Any) -> None:
|
||||||
if not SQLALCHEMY_AVAILABLE:
|
if not SQLALCHEMY_AVAILABLE:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"
|
"sqlalchemy is not installed. Please install it with "
|
||||||
|
"`pip install crewai-tools[sqlalchemy]`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.allow_dml:
|
||||||
|
logger.warning(
|
||||||
|
"NL2SQLTool: allow_dml=True — write operations (INSERT/UPDATE/"
|
||||||
|
"DELETE/DROP/…) are permitted. Use with caution."
|
||||||
)
|
)
|
||||||
|
|
||||||
data: dict[str, list[dict[str, Any]] | str] = {}
|
data: dict[str, list[dict[str, Any]] | str] = {}
|
||||||
@@ -50,42 +120,122 @@ class NL2SQLTool(BaseTool):
|
|||||||
self.tables = tables
|
self.tables = tables
|
||||||
self.columns = data
|
self.columns = data
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Query validation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _validate_query(self, sql_query: str) -> None:
|
||||||
|
"""Raise ValueError if *sql_query* is not permitted under the current config.
|
||||||
|
|
||||||
|
Parses the leading SQL command keyword and checks it against the
|
||||||
|
allowed set. When ``allow_dml=False`` (the default) only read
|
||||||
|
statements pass. When ``allow_dml=True`` all statements are allowed
|
||||||
|
but a warning is emitted for write operations.
|
||||||
|
"""
|
||||||
|
command = self._extract_command(sql_query)
|
||||||
|
|
||||||
|
if command in _WRITE_COMMANDS:
|
||||||
|
if not self.allow_dml:
|
||||||
|
raise ValueError(
|
||||||
|
f"NL2SQLTool is configured in read-only mode and blocked a "
|
||||||
|
f"'{command}' statement. To allow write operations set "
|
||||||
|
f"allow_dml=True or CREWAI_NL2SQL_ALLOW_DML=true."
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"NL2SQLTool: executing write statement '%s' because allow_dml=True.",
|
||||||
|
command,
|
||||||
|
)
|
||||||
|
elif command not in _READ_ONLY_COMMANDS:
|
||||||
|
# Unknown command — block by default unless DML is explicitly enabled
|
||||||
|
if not self.allow_dml:
|
||||||
|
raise ValueError(
|
||||||
|
f"NL2SQLTool blocked an unrecognised SQL command '{command}'. "
|
||||||
|
f"Only {sorted(_READ_ONLY_COMMANDS)} are allowed in read-only "
|
||||||
|
f"mode."
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_command(sql_query: str) -> str:
|
||||||
|
"""Return the uppercased first keyword of *sql_query*."""
|
||||||
|
stripped = sql_query.strip().lstrip("(")
|
||||||
|
first_token = stripped.split()[0] if stripped.split() else ""
|
||||||
|
return first_token.upper().rstrip(";")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Schema introspection helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def _fetch_available_tables(self) -> list[dict[str, Any]] | str:
|
def _fetch_available_tables(self) -> list[dict[str, Any]] | str:
|
||||||
return self.execute_sql(
|
return self.execute_sql(
|
||||||
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
|
"SELECT table_name FROM information_schema.tables "
|
||||||
|
"WHERE table_schema = 'public';"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _fetch_all_available_columns(
|
def _fetch_all_available_columns(
|
||||||
self, table_name: str
|
self, table_name: str
|
||||||
) -> list[dict[str, Any]] | str:
|
) -> list[dict[str, Any]] | str:
|
||||||
|
"""Fetch columns for *table_name* using a parameterised query.
|
||||||
|
|
||||||
|
The table name is bound via SQLAlchemy's ``:param`` syntax to prevent
|
||||||
|
SQL injection from catalogue values.
|
||||||
|
"""
|
||||||
return self.execute_sql(
|
return self.execute_sql(
|
||||||
f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';" # noqa: S608
|
"SELECT column_name, data_type FROM information_schema.columns "
|
||||||
|
"WHERE table_name = :table_name",
|
||||||
|
params={"table_name": table_name},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Core execution
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
|
def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
|
||||||
try:
|
try:
|
||||||
|
self._validate_query(sql_query)
|
||||||
data = self.execute_sql(sql_query)
|
data = self.execute_sql(sql_query)
|
||||||
|
except ValueError:
|
||||||
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
data = (
|
data = (
|
||||||
f"Based on these tables {self.tables} and columns {self.columns}, "
|
f"Based on these tables {self.tables} and columns {self.columns}, "
|
||||||
"you can create SQL queries to retrieve data from the database."
|
"you can create SQL queries to retrieve data from the database. "
|
||||||
f"Get the original request {sql_query} and the error {exc} and create the correct SQL query."
|
f"Get the original request {sql_query} and the error {exc} and "
|
||||||
|
"create the correct SQL query."
|
||||||
)
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def execute_sql(self, sql_query: str) -> list[dict[str, Any]] | str:
|
def execute_sql(
|
||||||
|
self,
|
||||||
|
sql_query: str,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
) -> list[dict[str, Any]] | str:
|
||||||
|
"""Execute *sql_query* and return the results as a list of dicts.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sql_query:
|
||||||
|
The SQL statement to run.
|
||||||
|
params:
|
||||||
|
Optional mapping of bind parameters (e.g. ``{"table_name": "users"}``).
|
||||||
|
"""
|
||||||
if not SQLALCHEMY_AVAILABLE:
|
if not SQLALCHEMY_AVAILABLE:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"
|
"sqlalchemy is not installed. Please install it with "
|
||||||
|
"`pip install crewai-tools[sqlalchemy]`"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
is_write = self._extract_command(sql_query) in _WRITE_COMMANDS
|
||||||
|
|
||||||
engine = create_engine(self.db_uri)
|
engine = create_engine(self.db_uri)
|
||||||
Session = sessionmaker(bind=engine) # noqa: N806
|
Session = sessionmaker(bind=engine) # noqa: N806
|
||||||
session = Session()
|
session = Session()
|
||||||
try:
|
try:
|
||||||
result = session.execute(text(sql_query))
|
result = session.execute(text(sql_query), params or {})
|
||||||
session.commit()
|
|
||||||
|
# Only commit when the operation actually mutates state
|
||||||
|
if self.allow_dml and is_write:
|
||||||
|
session.commit()
|
||||||
|
|
||||||
if result.returns_rows: # type: ignore[attr-defined]
|
if result.returns_rows: # type: ignore[attr-defined]
|
||||||
columns = result.keys()
|
columns = result.keys()
|
||||||
|
|||||||
291
lib/crewai-tools/tests/tools/test_nl2sql_security.py
Normal file
291
lib/crewai-tools/tests/tools/test_nl2sql_security.py
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
"""Security tests for NL2SQLTool.
|
||||||
|
|
||||||
|
Uses an in-memory SQLite database so no external service is needed.
|
||||||
|
SQLite does not have information_schema, so we patch the schema-introspection
|
||||||
|
helpers to avoid bootstrap failures and focus purely on the security logic.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Skip the entire module if SQLAlchemy is not installed
|
||||||
|
pytest.importorskip("sqlalchemy")
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine, text # noqa: E402
|
||||||
|
from sqlalchemy.orm import sessionmaker # noqa: E402
|
||||||
|
|
||||||
|
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool # noqa: E402
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
SQLITE_URI = "sqlite://" # in-memory
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tool(allow_dml: bool = False, **kwargs) -> NL2SQLTool:
    """Build an NL2SQLTool pointed at the in-memory SQLite URI.

    Both schema-introspection helpers are replaced with stubs returning an
    empty list while the tool is constructed, so no information_schema is
    needed. Extra keyword arguments are forwarded to the constructor.
    """
    stub_tables = patch.object(NL2SQLTool, "_fetch_available_tables", return_value=[])
    stub_columns = patch.object(
        NL2SQLTool, "_fetch_all_available_columns", return_value=[]
    )
    with stub_tables, stub_columns:
        tool = NL2SQLTool(db_uri=SQLITE_URI, allow_dml=allow_dml, **kwargs)
    return tool
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_db(uri: str) -> None:
    """Create a tiny ``users`` table with one row in the database at *uri*.

    The engine created here is disposed before returning so its connection
    pool does not outlive the call.

    NOTE(review): presumably only useful with a file-backed URI — the rows
    are written through this function's own engine; confirm before using it
    with the in-memory ``SQLITE_URI``.
    """
    engine = create_engine(uri)
    try:
        with engine.connect() as conn:
            conn.execute(
                text("CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY, name TEXT)")
            )
            conn.execute(text("INSERT INTO users VALUES (1, 'alice')"))
            conn.commit()
    finally:
        engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Read-only enforcement (allow_dml=False)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadOnlyMode:
    """Read-only enforcement when the tool is built with ``allow_dml=False``."""

    def test_select_allowed_by_default(self):
        tool = _make_tool()
        # The validator must accept the statement before it is executed.
        tool._validate_query("SELECT 1 AS val")
        # SQLite supports SELECT without information_schema
        result = tool.execute_sql("SELECT 1 AS val")
        assert result == [{"val": 1}]

    @pytest.mark.parametrize(
        "stmt",
        [
            "INSERT INTO t VALUES (1)",
            "UPDATE t SET col = 1",
            "DELETE FROM t",
            "DROP TABLE t",
            "ALTER TABLE t ADD col TEXT",
            "CREATE TABLE t (id INTEGER)",
            "TRUNCATE TABLE t",
            "GRANT SELECT ON t TO user1",
            "REVOKE SELECT ON t FROM user1",
            "EXEC sp_something",
            "EXECUTE sp_something",
            "CALL proc()",
            # MERGE and REPLACE are in the write list as well; cover them too.
            "MERGE INTO t USING s ON t.id = s.id WHEN MATCHED THEN DELETE",
            "REPLACE INTO t VALUES (1)",
        ],
    )
    def test_write_statements_blocked_by_default(self, stmt: str):
        tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            tool._validate_query(stmt)

    def test_explain_allowed(self):
        tool = _make_tool()
        # Should not raise
        tool._validate_query("EXPLAIN SELECT 1")

    def test_with_cte_allowed(self):
        tool = _make_tool()
        tool._validate_query("WITH cte AS (SELECT 1) SELECT * FROM cte")

    def test_show_allowed(self):
        tool = _make_tool()
        tool._validate_query("SHOW TABLES")

    def test_describe_allowed(self):
        tool = _make_tool()
        tool._validate_query("DESCRIBE users")

    def test_desc_alias_allowed(self):
        # "DESC" is listed separately from "DESCRIBE" in the read-only set.
        tool = _make_tool()
        tool._validate_query("DESC users")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DML enabled (allow_dml=True)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDMLEnabled:
    """Behaviour when the tool is built with ``allow_dml=True``."""

    def test_insert_allowed_when_dml_enabled(self):
        tool = _make_tool(allow_dml=True)
        # Should not raise
        tool._validate_query("INSERT INTO t VALUES (1)")

    def test_delete_allowed_when_dml_enabled(self):
        tool = _make_tool(allow_dml=True)
        tool._validate_query("DELETE FROM t WHERE id = 1")

    def test_drop_allowed_when_dml_enabled(self):
        tool = _make_tool(allow_dml=True)
        tool._validate_query("DROP TABLE t")

    def test_dml_actually_persists(self, tmp_path):
        """End-to-end: INSERT commits when allow_dml=True."""
        # A file-based SQLite database lets a second connection verify that
        # the row was persisted. pytest's tmp_path owns the file's lifetime,
        # so no manual unlink is needed.
        uri = f"sqlite:///{tmp_path / 'items.db'}"

        tool = _make_tool(allow_dml=True)
        tool.db_uri = uri

        engine = create_engine(uri)
        try:
            with engine.connect() as conn:
                conn.execute(text("CREATE TABLE items (id INTEGER PRIMARY KEY)"))
                conn.commit()

            tool.execute_sql("INSERT INTO items VALUES (42)")

            with engine.connect() as conn:
                rows = conn.execute(text("SELECT id FROM items")).fetchall()
            assert (42,) in rows
        finally:
            # Release the pooled connection so the database file is not
            # held open after the test finishes.
            engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Parameterised query — SQL injection prevention
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestParameterisedQueries:
    """The column lookup must bind the table name instead of formatting it in."""

    def test_table_name_is_parameterised(self):
        """_fetch_all_available_columns must not interpolate table_name into SQL."""
        hostile_name = "users'; DROP TABLE users; --"
        tool = _make_tool()
        seen = []

        def recorder(self_inner, sql_query, params=None):
            seen.append((sql_query, params))
            return []

        with patch.object(NL2SQLTool, "execute_sql", recorder):
            tool._fetch_all_available_columns(hostile_name)

        assert len(seen) == 1
        sql, params = seen[0]
        # The raw SQL must NOT contain the injected string
        assert "DROP" not in sql
        # The table name must be passed as a parameter
        assert params is not None
        assert params.get("table_name") == hostile_name
        # The SQL template must use the :param syntax
        assert ":table_name" in sql

    def test_injection_string_not_in_sql_template(self):
        """The f-string vulnerability is gone — table name never lands in the SQL."""
        injection = "'; DROP TABLE users; --"
        tool = _make_tool()
        last_call = {}

        def spy(self_inner, sql_query, params=None):
            last_call.update(sql=sql_query, params=params)
            return []

        with patch.object(NL2SQLTool, "execute_sql", spy):
            tool._fetch_all_available_columns(injection)

        assert injection not in last_call["sql"]
        assert last_call["params"]["table_name"] == injection
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# session.commit() not called for read-only queries
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoCommitForReadOnly:
    """``session.commit()`` is reserved for write statements with DML enabled."""

    @staticmethod
    def _execute_with_fake_session(tool, statement, fake_result):
        """Run *statement* through ``tool.execute_sql`` with SQLAlchemy mocked.

        ``create_engine`` and ``sessionmaker`` are patched in the tool's
        module so the session is a MagicMock whose ``execute`` returns
        *fake_result*. That session is returned for inspection.
        """
        fake_session = MagicMock()
        fake_session.execute.return_value = fake_result
        session_factory = MagicMock(return_value=fake_session)

        with (
            patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"),
            patch(
                "crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker",
                return_value=session_factory,
            ),
        ):
            tool.execute_sql(statement)

        return fake_session

    def test_select_does_not_commit(self):
        tool = _make_tool(allow_dml=False)

        rows_result = MagicMock()
        rows_result.returns_rows = True
        rows_result.keys.return_value = ["val"]
        rows_result.fetchall.return_value = [(1,)]

        session = self._execute_with_fake_session(tool, "SELECT 1", rows_result)

        session.commit.assert_not_called()

    def test_write_with_dml_enabled_does_commit(self):
        tool = _make_tool(allow_dml=True)

        no_rows_result = MagicMock()
        no_rows_result.returns_rows = False

        session = self._execute_with_fake_session(
            tool, "INSERT INTO t VALUES (1)", no_rows_result
        )

        session.commit.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Environment-variable escape hatch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnvVarEscapeHatch:
    """``CREWAI_NL2SQL_ALLOW_DML`` can switch DML on for a read-only tool."""

    _ENV_NAME = "CREWAI_NL2SQL_ALLOW_DML"

    @classmethod
    def _tool_built_under(cls, env_value):
        """Build a read-only tool while the variable is *env_value*.

        ``None`` means the variable is removed for the duration of the build.
        ``patch.dict`` restores the real environment afterwards.
        """
        with patch.dict(os.environ):
            if env_value is None:
                os.environ.pop(cls._ENV_NAME, None)
            else:
                os.environ[cls._ENV_NAME] = env_value
            return _make_tool(allow_dml=False)

    def test_env_var_enables_dml(self):
        tool = self._tool_built_under("true")
        assert tool.allow_dml is True

    def test_env_var_case_insensitive(self):
        tool = self._tool_built_under("TRUE")
        assert tool.allow_dml is True

    def test_env_var_absent_keeps_default(self):
        tool = self._tool_built_under(None)
        assert tool.allow_dml is False

    def test_env_var_false_does_not_enable_dml(self):
        tool = self._tool_built_under("false")
        assert tool.allow_dml is False

    def test_dml_write_blocked_without_env_var(self):
        tool = self._tool_built_under(None)
        with pytest.raises(ValueError, match="read-only mode"):
            tool._validate_query("DROP TABLE sensitive_data")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _run() propagates ValueError from _validate_query
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunValidation:
    """``_run`` applies query validation before touching the database."""

    def test_run_raises_on_blocked_query(self):
        read_only_tool = _make_tool(allow_dml=False)
        with pytest.raises(ValueError, match="read-only mode"):
            read_only_tool._run("DELETE FROM users")

    def test_run_returns_results_for_select(self):
        read_only_tool = _make_tool(allow_dml=False)
        rows = read_only_tool._run("SELECT 1 AS n")
        assert rows == [{"n": 1}]
|
||||||
Reference in New Issue
Block a user