mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-04 08:42:38 +00:00
fix: prevent SQL injection in SnowflakeSearchTool and NL2SQLTool
- Add identifier validation regex to database and snowflake_schema fields in SnowflakeSearchToolInput to reject malicious values at schema level - Add _validate_identifier() runtime check in SnowflakeSearchTool._run() and double-quote identifiers in USE DATABASE/SCHEMA SQL statements - Add _validate_identifier() to NL2SQLTool to sanitize table_name in _fetch_all_available_columns() preventing second-order SQL injection - Add comprehensive tests for both tools covering injection vectors Closes #4993 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
72
lib/crewai-tools/tests/tools/nl2sql_tool_test.py
Normal file
72
lib/crewai-tools/tests/tools/nl2sql_tool_test.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool
|
||||
|
||||
|
||||
class TestNL2SQLToolValidateIdentifier:
|
||||
"""Tests for SQL injection prevention via identifier validation."""
|
||||
|
||||
def test_valid_identifiers(self):
|
||||
assert NL2SQLTool._validate_identifier("users", "table_name") == "users"
|
||||
assert NL2SQLTool._validate_identifier("MY_TABLE", "table_name") == "MY_TABLE"
|
||||
assert NL2SQLTool._validate_identifier("table$1", "table_name") == "table$1"
|
||||
assert NL2SQLTool._validate_identifier("_private", "table_name") == "_private"
|
||||
|
||||
def test_rejects_sql_injection_with_semicolon(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("users; DROP TABLE users;--", "table_name")
|
||||
|
||||
def test_rejects_sql_injection_with_quotes(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("users'--", "table_name")
|
||||
|
||||
def test_rejects_sql_injection_with_spaces(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("users DROP TABLE", "table_name")
|
||||
|
||||
def test_rejects_leading_number(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("1table", "table_name")
|
||||
|
||||
def test_rejects_empty_string(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("", "table_name")
|
||||
|
||||
def test_rejects_parentheses(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("users()", "table_name")
|
||||
|
||||
def test_rejects_dash_comment(self):
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
NL2SQLTool._validate_identifier("users--comment", "table_name")
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.nl2sql.nl2sql_tool.SQLALCHEMY_AVAILABLE", True)
|
||||
class TestNL2SQLToolFetchColumns:
|
||||
"""Tests that _fetch_all_available_columns validates table names."""
|
||||
|
||||
def _make_tool(self):
|
||||
"""Create an NL2SQLTool instance bypassing model_post_init DB calls."""
|
||||
with patch.object(NL2SQLTool, "model_post_init"):
|
||||
tool = NL2SQLTool(
|
||||
db_uri="sqlite:///:memory:",
|
||||
name="NL2SQLTool",
|
||||
description="test",
|
||||
)
|
||||
return tool
|
||||
|
||||
def test_rejects_malicious_table_name(self):
|
||||
tool = self._make_tool()
|
||||
with pytest.raises(ValueError, match="Invalid table_name"):
|
||||
tool._fetch_all_available_columns("users'; DROP TABLE users;--")
|
||||
|
||||
def test_accepts_valid_table_name(self):
|
||||
tool = self._make_tool()
|
||||
with patch.object(NL2SQLTool, "execute_sql", return_value=[]) as mock_exec:
|
||||
result = tool._fetch_all_available_columns("valid_table")
|
||||
mock_exec.assert_called_once()
|
||||
call_sql = mock_exec.call_args[0][0]
|
||||
assert "valid_table" in call_sql
|
||||
assert result == []
|
||||
Reference in New Issue
Block a user