diff --git a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py index bfb9c02dd..84b4bd772 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py @@ -1,7 +1,9 @@ +import logging +import os from typing import Any from crewai.tools import BaseTool -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator try: @@ -12,6 +14,29 @@ try: except ImportError: SQLALCHEMY_AVAILABLE = False +logger = logging.getLogger(__name__) + +# Commands allowed in read-only mode +_READ_ONLY_COMMANDS = {"SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN", "WITH"} + +# Commands that mutate state and are blocked by default +_WRITE_COMMANDS = { + "INSERT", + "UPDATE", + "DELETE", + "DROP", + "ALTER", + "CREATE", + "TRUNCATE", + "GRANT", + "REVOKE", + "EXEC", + "EXECUTE", + "CALL", + "MERGE", + "REPLACE", +} + class NL2SQLToolInput(BaseModel): sql_query: str = Field( @@ -21,20 +46,65 @@ class NL2SQLToolInput(BaseModel): class NL2SQLTool(BaseTool): + """Tool that converts natural language to SQL and executes it against a database. + + By default the tool operates in **read-only mode**: only SELECT, SHOW, + DESCRIBE, EXPLAIN, and WITH (CTE) statements are permitted. Write + operations (INSERT, UPDATE, DELETE, DROP, ALTER, CREATE, TRUNCATE, …) are + blocked unless ``allow_dml=True`` is set explicitly or the environment + variable ``CREWAI_NL2SQL_ALLOW_DML=true`` is present. + + The ``_fetch_all_available_columns`` helper uses parameterised queries so + that table names coming from the database catalogue cannot be used as an + injection vector. + """ + name: str = "NL2SQLTool" - description: str = "Converts natural language to SQL queries and executes them." + description: str = ( + "Converts natural language to SQL queries and executes them against a " + "database. Read-only by default — only SELECT/SHOW/DESCRIBE/EXPLAIN " + "queries are allowed unless the tool is configured with allow_dml=True." + ) db_uri: str = Field( title="Database URI", description="The URI of the database to connect to.", ) + allow_dml: bool = Field( + default=False, + title="Allow DML", + description=( + "When False (default) only read statements are permitted. " + "Set to True to allow INSERT/UPDATE/DELETE/DROP and other " + "write operations." + ), + ) tables: list[dict[str, Any]] = Field(default_factory=list) columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict) args_schema: type[BaseModel] = NL2SQLToolInput + @model_validator(mode="after") + def _apply_env_override(self) -> "NL2SQLTool": + """Allow CREWAI_NL2SQL_ALLOW_DML=true to override allow_dml at runtime.""" + if os.environ.get("CREWAI_NL2SQL_ALLOW_DML", "").strip().lower() == "true": + if not self.allow_dml: + logger.warning( + "NL2SQLTool: CREWAI_NL2SQL_ALLOW_DML env var is set — " + "DML/DDL operations are enabled. Ensure this is intentional." + ) + self.allow_dml = True + return self + def model_post_init(self, __context: Any) -> None: if not SQLALCHEMY_AVAILABLE: raise ImportError( - "sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`" + "sqlalchemy is not installed. Please install it with " + "`pip install crewai-tools[sqlalchemy]`" + ) + + if self.allow_dml: + logger.warning( + "NL2SQLTool: allow_dml=True — write operations (INSERT/UPDATE/" + "DELETE/DROP/…) are permitted. Use with caution." ) data: dict[str, list[dict[str, Any]] | str] = {} @@ -50,42 +120,122 @@ class NL2SQLTool(BaseTool): self.tables = tables self.columns = data + # ------------------------------------------------------------------ + # Query validation + # ------------------------------------------------------------------ + + def _validate_query(self, sql_query: str) -> None: + """Raise ValueError if *sql_query* is not permitted under the current config. + + Parses the leading SQL command keyword and checks it against the + allowed set. When ``allow_dml=False`` (the default) only read + statements pass. When ``allow_dml=True`` all statements are allowed + but a warning is emitted for write operations. + """ + command = self._extract_command(sql_query) + + if command in _WRITE_COMMANDS: + if not self.allow_dml: + raise ValueError( + f"NL2SQLTool is configured in read-only mode and blocked a " + f"'{command}' statement. To allow write operations set " + f"allow_dml=True or CREWAI_NL2SQL_ALLOW_DML=true." + ) + logger.warning( + "NL2SQLTool: executing write statement '%s' because allow_dml=True.", + command, + ) + elif command not in _READ_ONLY_COMMANDS: + # Unknown command — block by default unless DML is explicitly enabled + if not self.allow_dml: + raise ValueError( + f"NL2SQLTool blocked an unrecognised SQL command '{command}'. " + f"Only {sorted(_READ_ONLY_COMMANDS)} are allowed in read-only " + f"mode." + ) + + @staticmethod + def _extract_command(sql_query: str) -> str: + """Return the uppercased first keyword of *sql_query*.""" + stripped = sql_query.strip().lstrip("(") + first_token = stripped.split()[0] if stripped.split() else "" + return first_token.upper().rstrip(";") + + # ------------------------------------------------------------------ + # Schema introspection helpers + # ------------------------------------------------------------------ + def _fetch_available_tables(self) -> list[dict[str, Any]] | str: return self.execute_sql( - "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';" + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema = 'public';" ) def _fetch_all_available_columns( self, table_name: str ) -> list[dict[str, Any]] | str: + """Fetch columns for *table_name* using a parameterised query. + + The table name is bound via SQLAlchemy's ``:param`` syntax to prevent + SQL injection from catalogue values. + """ return self.execute_sql( - f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';" # noqa: S608 + "SELECT column_name, data_type FROM information_schema.columns " + "WHERE table_name = :table_name", + params={"table_name": table_name}, ) + # ------------------------------------------------------------------ + # Core execution + # ------------------------------------------------------------------ + def _run(self, sql_query: str) -> list[dict[str, Any]] | str: try: + self._validate_query(sql_query) data = self.execute_sql(sql_query) + except ValueError: + raise except Exception as exc: data = ( f"Based on these tables {self.tables} and columns {self.columns}, " - "you can create SQL queries to retrieve data from the database." - f"Get the original request {sql_query} and the error {exc} and create the correct SQL query." + "you can create SQL queries to retrieve data from the database. " + f"Get the original request {sql_query} and the error {exc} and " + "create the correct SQL query." ) return data - def execute_sql(self, sql_query: str) -> list[dict[str, Any]] | str: + def execute_sql( + self, + sql_query: str, + params: dict[str, Any] | None = None, + ) -> list[dict[str, Any]] | str: + """Execute *sql_query* and return the results as a list of dicts. + + Parameters + ---------- + sql_query: + The SQL statement to run. + params: + Optional mapping of bind parameters (e.g. ``{"table_name": "users"}``). + """ if not SQLALCHEMY_AVAILABLE: raise ImportError( - "sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`" + "sqlalchemy is not installed. Please install it with " + "`pip install crewai-tools[sqlalchemy]`" ) + is_write = self._extract_command(sql_query) in _WRITE_COMMANDS + engine = create_engine(self.db_uri) Session = sessionmaker(bind=engine) # noqa: N806 session = Session() try: - result = session.execute(text(sql_query)) - session.commit() + result = session.execute(text(sql_query), params or {}) + + # Only commit when the operation actually mutates state + if self.allow_dml and is_write: + session.commit() if result.returns_rows: # type: ignore[attr-defined] columns = result.keys() diff --git a/lib/crewai-tools/tests/tools/test_nl2sql_security.py b/lib/crewai-tools/tests/tools/test_nl2sql_security.py new file mode 100644 index 000000000..f13a7c8ea --- /dev/null +++ b/lib/crewai-tools/tests/tools/test_nl2sql_security.py @@ -0,0 +1,291 @@ +"""Security tests for NL2SQLTool. + +Uses an in-memory SQLite database so no external service is needed. +SQLite does not have information_schema, so we patch the schema-introspection +helpers to avoid bootstrap failures and focus purely on the security logic. +""" +import os +from unittest.mock import MagicMock, patch + +import pytest + +# Skip the entire module if SQLAlchemy is not installed +pytest.importorskip("sqlalchemy") + +from sqlalchemy import create_engine, text # noqa: E402 +from sqlalchemy.orm import sessionmaker # noqa: E402 + +from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool # noqa: E402 + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +SQLITE_URI = "sqlite://" # in-memory + + +def _make_tool(allow_dml: bool = False, **kwargs) -> NL2SQLTool: + """Return a NL2SQLTool wired to an in-memory SQLite DB. + + Schema-introspection is patched out so we can create the tool without a + real PostgreSQL information_schema. + """ + with ( + patch.object(NL2SQLTool, "_fetch_available_tables", return_value=[]), + patch.object(NL2SQLTool, "_fetch_all_available_columns", return_value=[]), + ): + return NL2SQLTool(db_uri=SQLITE_URI, allow_dml=allow_dml, **kwargs) + + +def _seed_db(uri: str) -> None: + """Create a tiny table in the target database for DML tests.""" + engine = create_engine(uri) + with engine.connect() as conn: + conn.execute(text("CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY, name TEXT)")) + conn.execute(text("INSERT INTO users VALUES (1, 'alice')")) + conn.commit() + + +# --------------------------------------------------------------------------- +# Read-only enforcement (allow_dml=False) +# --------------------------------------------------------------------------- + + +class TestReadOnlyMode: + def test_select_allowed_by_default(self): + tool = _make_tool() + # SQLite supports SELECT without information_schema + result = tool.execute_sql("SELECT 1 AS val") + assert result == [{"val": 1}] + + @pytest.mark.parametrize( + "stmt", + [ + "INSERT INTO t VALUES (1)", + "UPDATE t SET col = 1", + "DELETE FROM t", + "DROP TABLE t", + "ALTER TABLE t ADD col TEXT", + "CREATE TABLE t (id INTEGER)", + "TRUNCATE TABLE t", + "GRANT SELECT ON t TO user1", + "REVOKE SELECT ON t FROM user1", + "EXEC sp_something", + "EXECUTE sp_something", + "CALL proc()", + ], + ) + def test_write_statements_blocked_by_default(self, stmt: str): + tool = _make_tool(allow_dml=False) + with pytest.raises(ValueError, match="read-only mode"): + tool._validate_query(stmt) + + def test_explain_allowed(self): + tool = _make_tool() + # Should not raise + tool._validate_query("EXPLAIN SELECT 1") + + def test_with_cte_allowed(self): + tool = _make_tool() + tool._validate_query("WITH cte AS (SELECT 1) SELECT * FROM cte") + + def test_show_allowed(self): + tool = _make_tool() + tool._validate_query("SHOW TABLES") + + def test_describe_allowed(self): + tool = _make_tool() + tool._validate_query("DESCRIBE users") + + +# --------------------------------------------------------------------------- +# DML enabled (allow_dml=True) +# --------------------------------------------------------------------------- + + +class TestDMLEnabled: + def test_insert_allowed_when_dml_enabled(self): + tool = _make_tool(allow_dml=True) + # Should not raise + tool._validate_query("INSERT INTO t VALUES (1)") + + def test_delete_allowed_when_dml_enabled(self): + tool = _make_tool(allow_dml=True) + tool._validate_query("DELETE FROM t WHERE id = 1") + + def test_drop_allowed_when_dml_enabled(self): + tool = _make_tool(allow_dml=True) + tool._validate_query("DROP TABLE t") + + def test_dml_actually_persists(self): + """End-to-end: INSERT commits when allow_dml=True.""" + # Use a file-based SQLite so we can verify persistence across sessions + import tempfile, os + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + uri = f"sqlite:///{db_path}" + try: + tool = _make_tool(allow_dml=True) + tool.db_uri = uri + + engine = create_engine(uri) + with engine.connect() as conn: + conn.execute(text("CREATE TABLE items (id INTEGER PRIMARY KEY)")) + conn.commit() + + tool.execute_sql("INSERT INTO items VALUES (42)") + + with engine.connect() as conn: + rows = conn.execute(text("SELECT id FROM items")).fetchall() + assert (42,) in rows + finally: + os.unlink(db_path) + + +# --------------------------------------------------------------------------- +# Parameterised query — SQL injection prevention +# --------------------------------------------------------------------------- + + +class TestParameterisedQueries: + def test_table_name_is_parameterised(self): + """_fetch_all_available_columns must not interpolate table_name into SQL.""" + tool = _make_tool() + captured_calls = [] + + def recording_execute_sql(self_inner, sql_query, params=None): + captured_calls.append((sql_query, params)) + return [] + + with patch.object(NL2SQLTool, "execute_sql", recording_execute_sql): + tool._fetch_all_available_columns("users'; DROP TABLE users; --") + + assert len(captured_calls) == 1 + sql, params = captured_calls[0] + # The raw SQL must NOT contain the injected string + assert "DROP" not in sql + # The table name must be passed as a parameter + assert params is not None + assert params.get("table_name") == "users'; DROP TABLE users; --" + # The SQL template must use the :param syntax + assert ":table_name" in sql + + def test_injection_string_not_in_sql_template(self): + """The f-string vulnerability is gone — table name never lands in the SQL.""" + tool = _make_tool() + injection = "'; DROP TABLE users; --" + captured = {} + + def spy(self_inner, sql_query, params=None): + captured["sql"] = sql_query + captured["params"] = params + return [] + + with patch.object(NL2SQLTool, "execute_sql", spy): + tool._fetch_all_available_columns(injection) + + assert injection not in captured["sql"] + assert captured["params"]["table_name"] == injection + + +# --------------------------------------------------------------------------- +# session.commit() not called for read-only queries +# --------------------------------------------------------------------------- + + +class TestNoCommitForReadOnly: + def test_select_does_not_commit(self): + tool = _make_tool(allow_dml=False) + + mock_session = MagicMock() + mock_result = MagicMock() + mock_result.returns_rows = True + mock_result.keys.return_value = ["val"] + mock_result.fetchall.return_value = [(1,)] + mock_session.execute.return_value = mock_result + + mock_session_cls = MagicMock(return_value=mock_session) + + with ( + patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"), + patch( + "crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker", + return_value=mock_session_cls, + ), + ): + tool.execute_sql("SELECT 1") + + mock_session.commit.assert_not_called() + + def test_write_with_dml_enabled_does_commit(self): + tool = _make_tool(allow_dml=True) + + mock_session = MagicMock() + mock_result = MagicMock() + mock_result.returns_rows = False + mock_session.execute.return_value = mock_result + + mock_session_cls = MagicMock(return_value=mock_session) + + with ( + patch("crewai_tools.tools.nl2sql.nl2sql_tool.create_engine"), + patch( + "crewai_tools.tools.nl2sql.nl2sql_tool.sessionmaker", + return_value=mock_session_cls, + ), + ): + tool.execute_sql("INSERT INTO t VALUES (1)") + + mock_session.commit.assert_called_once() + + +# --------------------------------------------------------------------------- +# Environment-variable escape hatch +# --------------------------------------------------------------------------- + + +class TestEnvVarEscapeHatch: + def test_env_var_enables_dml(self): + with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "true"}): + tool = _make_tool(allow_dml=False) + assert tool.allow_dml is True + + def test_env_var_case_insensitive(self): + with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "TRUE"}): + tool = _make_tool(allow_dml=False) + assert tool.allow_dml is True + + def test_env_var_absent_keeps_default(self): + env = {k: v for k, v in os.environ.items() if k != "CREWAI_NL2SQL_ALLOW_DML"} + with patch.dict(os.environ, env, clear=True): + tool = _make_tool(allow_dml=False) + assert tool.allow_dml is False + + def test_env_var_false_does_not_enable_dml(self): + with patch.dict(os.environ, {"CREWAI_NL2SQL_ALLOW_DML": "false"}): + tool = _make_tool(allow_dml=False) + assert tool.allow_dml is False + + def test_dml_write_blocked_without_env_var(self): + env = {k: v for k, v in os.environ.items() if k != "CREWAI_NL2SQL_ALLOW_DML"} + with patch.dict(os.environ, env, clear=True): + tool = _make_tool(allow_dml=False) + with pytest.raises(ValueError, match="read-only mode"): + tool._validate_query("DROP TABLE sensitive_data") + + +# --------------------------------------------------------------------------- +# _run() propagates ValueError from _validate_query +# --------------------------------------------------------------------------- + + +class TestRunValidation: + def test_run_raises_on_blocked_query(self): + tool = _make_tool(allow_dml=False) + with pytest.raises(ValueError, match="read-only mode"): + tool._run("DELETE FROM users") + + def test_run_returns_results_for_select(self): + tool = _make_tool(allow_dml=False) + result = tool._run("SELECT 1 AS n") + assert result == [{"n": 1}]