Files
crewAI/lib/crewai-tools/tests/tools/nl2sql_tool_test.py
Devin AI 9a9cb48d09 fix: prevent SQL injection in SnowflakeSearchTool and NL2SQLTool
- Add identifier validation regex to database and snowflake_schema fields
  in SnowflakeSearchToolInput to reject malicious values at schema level
- Add _validate_identifier() runtime check in SnowflakeSearchTool._run()
  and double-quote identifiers in USE DATABASE/SCHEMA SQL statements
- Add _validate_identifier() to NL2SQLTool to sanitize table_name in
  _fetch_all_available_columns() preventing second-order SQL injection
- Add comprehensive tests for both tools covering injection vectors

Closes #4993

Co-Authored-By: João <joao@crewai.com>
2026-03-20 20:15:14 +00:00

73 lines
3.0 KiB
Python

from unittest.mock import MagicMock, patch
import pytest
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool
class TestNL2SQLToolValidateIdentifier:
"""Tests for SQL injection prevention via identifier validation."""
def test_valid_identifiers(self):
assert NL2SQLTool._validate_identifier("users", "table_name") == "users"
assert NL2SQLTool._validate_identifier("MY_TABLE", "table_name") == "MY_TABLE"
assert NL2SQLTool._validate_identifier("table$1", "table_name") == "table$1"
assert NL2SQLTool._validate_identifier("_private", "table_name") == "_private"
def test_rejects_sql_injection_with_semicolon(self):
with pytest.raises(ValueError, match="Invalid table_name"):
NL2SQLTool._validate_identifier("users; DROP TABLE users;--", "table_name")
def test_rejects_sql_injection_with_quotes(self):
with pytest.raises(ValueError, match="Invalid table_name"):
NL2SQLTool._validate_identifier("users'--", "table_name")
def test_rejects_sql_injection_with_spaces(self):
with pytest.raises(ValueError, match="Invalid table_name"):
NL2SQLTool._validate_identifier("users DROP TABLE", "table_name")
def test_rejects_leading_number(self):
with pytest.raises(ValueError, match="Invalid table_name"):
NL2SQLTool._validate_identifier("1table", "table_name")
def test_rejects_empty_string(self):
with pytest.raises(ValueError, match="Invalid table_name"):
NL2SQLTool._validate_identifier("", "table_name")
def test_rejects_parentheses(self):
with pytest.raises(ValueError, match="Invalid table_name"):
NL2SQLTool._validate_identifier("users()", "table_name")
def test_rejects_dash_comment(self):
with pytest.raises(ValueError, match="Invalid table_name"):
NL2SQLTool._validate_identifier("users--comment", "table_name")
@patch("crewai_tools.tools.nl2sql.nl2sql_tool.SQLALCHEMY_AVAILABLE", True)
class TestNL2SQLToolFetchColumns:
"""Tests that _fetch_all_available_columns validates table names."""
def _make_tool(self):
"""Create an NL2SQLTool instance bypassing model_post_init DB calls."""
with patch.object(NL2SQLTool, "model_post_init"):
tool = NL2SQLTool(
db_uri="sqlite:///:memory:",
name="NL2SQLTool",
description="test",
)
return tool
def test_rejects_malicious_table_name(self):
tool = self._make_tool()
with pytest.raises(ValueError, match="Invalid table_name"):
tool._fetch_all_available_columns("users'; DROP TABLE users;--")
def test_accepts_valid_table_name(self):
tool = self._make_tool()
with patch.object(NL2SQLTool, "execute_sql", return_value=[]) as mock_exec:
result = tool._fetch_all_available_columns("valid_table")
mock_exec.assert_called_once()
call_sql = mock_exec.call_args[0][0]
assert "valid_table" in call_sql
assert result == []