mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-10 04:52:40 +00:00
- Add an identifier validation regex to the database and snowflake_schema fields in SnowflakeSearchToolInput to reject malicious values at the schema level
- Add a _validate_identifier() runtime check in SnowflakeSearchTool._run() and double-quote identifiers in USE DATABASE/SCHEMA SQL statements
- Add _validate_identifier() to NL2SQLTool to sanitize table_name in _fetch_all_available_columns(), preventing second-order SQL injection
- Add comprehensive tests for both tools covering injection vectors

Closes #4993

Co-Authored-By: João <joao@crewai.com>
73 lines
3.0 KiB
Python
73 lines
3.0 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool
|
|
|
|
|
|
class TestNL2SQLToolValidateIdentifier:
    """Tests for SQL injection prevention via identifier validation.

    ``NL2SQLTool._validate_identifier(value, field_name)`` is called directly on
    the class, so none of these tests need a database connection.
    """

    def test_valid_identifiers(self):
        # Accepted identifiers are returned unchanged.
        assert NL2SQLTool._validate_identifier("users", "table_name") == "users"
        assert NL2SQLTool._validate_identifier("MY_TABLE", "table_name") == "MY_TABLE"
        assert NL2SQLTool._validate_identifier("table$1", "table_name") == "table$1"
        assert NL2SQLTool._validate_identifier("_private", "table_name") == "_private"

    def test_rejects_sql_injection_with_semicolon(self):
        # Statement terminator used to chain a second statement.
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier("users; DROP TABLE users;--", "table_name")

    def test_rejects_sql_injection_with_quotes(self):
        # Single quote would terminate a string literal around the name.
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier("users'--", "table_name")

    def test_rejects_sql_injection_with_spaces(self):
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier("users DROP TABLE", "table_name")

    def test_rejects_leading_number(self):
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier("1table", "table_name")

    def test_rejects_empty_string(self):
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier("", "table_name")

    def test_rejects_parentheses(self):
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier("users()", "table_name")

    def test_rejects_dash_comment(self):
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier("users--comment", "table_name")

    @pytest.mark.parametrize(
        "malicious",
        [
            # A double quote could close a double-quoted identifier and let
            # the rest of the value run as SQL.
            'users"; DROP TABLE users;--',
            # Backtick is the identifier quote in MySQL-style dialects.
            "users`; DROP TABLE users;--",
            # Block comments can hide or splice SQL without using spaces.
            "users/*comment*/",
            # Embedded whitespace other than a plain space.
            "users\nDROP TABLE users",
            "users\tDROP",
            # NUL byte; some drivers truncate the string here.
            "users\x00",
        ],
        ids=[
            "double-quote",
            "backtick",
            "block-comment",
            "newline",
            "tab",
            "nul-byte",
        ],
    )
    def test_rejects_other_injection_vectors(self, malicious):
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(malicious, "table_name")
|
|
|
|
|
|
@patch("crewai_tools.tools.nl2sql.nl2sql_tool.SQLALCHEMY_AVAILABLE", True)
class TestNL2SQLToolFetchColumns:
    """Tests that _fetch_all_available_columns validates table names.

    Used as a class decorator, ``patch`` wraps only the ``test*`` methods.
    ``SQLALCHEMY_AVAILABLE`` is therefore forced to True for the duration of
    each test, including the ``_make_tool`` calls made from inside a test.
    """

    def _make_tool(self):
        """Create an NL2SQLTool instance bypassing model_post_init DB calls."""
        with patch.object(NL2SQLTool, "model_post_init"):
            return NL2SQLTool(
                db_uri="sqlite:///:memory:",
                name="NL2SQLTool",
                description="test",
            )

    def test_rejects_malicious_table_name(self):
        tool = self._make_tool()
        # Stub execute_sql so the test can prove the payload never reaches the
        # database. Validation must run *before* any query is executed;
        # raising after execution would still let the injection happen.
        with patch.object(NL2SQLTool, "execute_sql", return_value=[]) as mock_exec:
            with pytest.raises(ValueError, match="Invalid table_name"):
                tool._fetch_all_available_columns("users'; DROP TABLE users;--")
        mock_exec.assert_not_called()

    def test_accepts_valid_table_name(self):
        tool = self._make_tool()
        with patch.object(NL2SQLTool, "execute_sql", return_value=[]) as mock_exec:
            result = tool._fetch_all_available_columns("valid_table")
        # Exactly one query is issued, and it references the validated name.
        mock_exec.assert_called_once()
        call_sql = mock_exec.call_args[0][0]
        assert "valid_table" in call_sql
        assert result == []
|