Compare commits

...

2 Commits

Author SHA1 Message Date
github-actions[bot]
287ffe2f6d chore: update tool specifications 2026-03-20 20:17:10 +00:00
Devin AI
9a9cb48d09 fix: prevent SQL injection in SnowflakeSearchTool and NL2SQLTool
- Add identifier validation regex to database and snowflake_schema fields
  in SnowflakeSearchToolInput to reject malicious values at schema level
- Add _validate_identifier() runtime check in SnowflakeSearchTool._run()
  and double-quote identifiers in USE DATABASE/SCHEMA SQL statements
- Add _validate_identifier() to NL2SQLTool to sanitize table_name in
  _fetch_all_available_columns() preventing second-order SQL injection
- Add comprehensive tests for both tools covering injection vectors

Closes #4993

Co-Authored-By: João <joao@crewai.com>
2026-03-20 20:15:14 +00:00
5 changed files with 247 additions and 4 deletions

View File

@@ -1,3 +1,4 @@
import re
from typing import Any
from crewai.tools import BaseTool
@@ -52,7 +53,18 @@ class NL2SQLTool(BaseTool):
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
)
@staticmethod
def _validate_identifier(value: str, name: str) -> str:
"""Validate a SQL identifier to prevent SQL injection."""
if not re.match(r"^[A-Za-z_][A-Za-z0-9_$]*$", value):
raise ValueError(
f"Invalid {name}: {value!r}. "
f"Only alphanumeric characters, underscores, and dollar signs are allowed."
)
return value
def _fetch_all_available_columns(self, table_name: str):
table_name = self._validate_identifier(table_name, "table_name")
return self.execute_sql(
f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';" # noqa: S608
)

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
import re
import threading
from typing import TYPE_CHECKING, Any
@@ -71,8 +72,16 @@ class SnowflakeSearchToolInput(BaseModel):
model_config = ConfigDict(protected_namespaces=())
query: str = Field(..., description="SQL query or semantic search query to execute")
database: str | None = Field(None, description="Override default database")
snowflake_schema: str | None = Field(None, description="Override default schema")
database: str | None = Field(
None,
description="Override default database",
pattern=r"^[A-Za-z_][A-Za-z0-9_$]*$",
)
snowflake_schema: str | None = Field(
None,
description="Override default schema",
pattern=r"^[A-Za-z_][A-Za-z0-9_$]*$",
)
timeout: int | None = Field(300, description="Query timeout in seconds")
@@ -247,6 +256,16 @@ class SnowflakeSearchTool(BaseTool):
continue
raise RuntimeError("Query failed after all retries")
@staticmethod
def _validate_identifier(value: str, name: str) -> str:
"""Validate and sanitize a Snowflake identifier to prevent SQL injection."""
if not re.match(r"^[A-Za-z_][A-Za-z0-9_$]*$", value):
raise ValueError(
f"Invalid {name}: {value!r}. "
f"Only alphanumeric characters, underscores, and dollar signs are allowed."
)
return value
async def _run(
self,
query: str,
@@ -259,9 +278,11 @@ class SnowflakeSearchTool(BaseTool):
try:
# Override database/schema if provided
if database:
await self._execute_query(f"USE DATABASE {database}")
database = self._validate_identifier(database, "database")
await self._execute_query(f'USE DATABASE "{database}"')
if snowflake_schema:
await self._execute_query(f"USE SCHEMA {snowflake_schema}")
snowflake_schema = self._validate_identifier(snowflake_schema, "schema")
await self._execute_query(f'USE SCHEMA "{snowflake_schema}"')
return await self._execute_query(query, timeout)
except Exception as e:

View File

@@ -0,0 +1,72 @@
from unittest.mock import MagicMock, patch
import pytest
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool
class TestNL2SQLToolValidateIdentifier:
    """Identifier validation in NL2SQLTool: plain names pass, payloads fail."""

    def test_valid_identifiers(self):
        # Each well-formed identifier must be returned unchanged.
        for identifier in ("users", "MY_TABLE", "table$1", "_private"):
            assert NL2SQLTool._validate_identifier(identifier, "table_name") == identifier

    def test_rejects_sql_injection_with_semicolon(self):
        payload = "users; DROP TABLE users;--"
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")

    def test_rejects_sql_injection_with_quotes(self):
        payload = "users'--"
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")

    def test_rejects_sql_injection_with_spaces(self):
        payload = "users DROP TABLE"
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")

    def test_rejects_leading_number(self):
        payload = "1table"
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")

    def test_rejects_empty_string(self):
        payload = ""
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")

    def test_rejects_parentheses(self):
        payload = "users()"
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")

    def test_rejects_dash_comment(self):
        payload = "users--comment"
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")
@patch("crewai_tools.tools.nl2sql.nl2sql_tool.SQLALCHEMY_AVAILABLE", True)
class TestNL2SQLToolFetchColumns:
    """Checks that _fetch_all_available_columns validates its table name."""

    def _make_tool(self):
        """Build an NL2SQLTool with model_post_init patched out."""
        with patch.object(NL2SQLTool, "model_post_init"):
            return NL2SQLTool(
                db_uri="sqlite:///:memory:",
                name="NL2SQLTool",
                description="test",
            )

    def test_rejects_malicious_table_name(self):
        nl2sql = self._make_tool()
        payload = "users'; DROP TABLE users;--"
        with pytest.raises(ValueError, match="Invalid table_name"):
            nl2sql._fetch_all_available_columns(payload)

    def test_accepts_valid_table_name(self):
        nl2sql = self._make_tool()
        with patch.object(NL2SQLTool, "execute_sql", return_value=[]) as mock_exec:
            fetched = nl2sql._fetch_all_available_columns("valid_table")
        # The mock keeps its call record after the patch is undone.
        mock_exec.assert_called_once()
        positional_args, _ = mock_exec.call_args
        executed_sql = positional_args[0]
        assert "valid_table" in executed_sql
        assert fetched == []

View File

@@ -2,6 +2,9 @@ import asyncio
from unittest.mock import MagicMock, patch
from crewai_tools import SnowflakeConfig, SnowflakeSearchTool
from crewai_tools.tools.snowflake_search_tool.snowflake_search_tool import (
SnowflakeSearchToolInput,
)
import pytest
@@ -100,3 +103,136 @@ def test_config_validation():
# Test missing authentication
with pytest.raises(ValueError):
SnowflakeConfig(account="test_account", user="test_user")
# SQL Injection Prevention Tests
class TestSnowflakeSearchToolInputValidation:
    """Input-schema checks on the database and snowflake_schema fields."""

    def test_valid_database_identifier(self):
        parsed = SnowflakeSearchToolInput(query="SELECT 1", database="my_database")
        assert parsed.database == "my_database"

    def test_valid_schema_identifier(self):
        parsed = SnowflakeSearchToolInput(query="SELECT 1", snowflake_schema="public")
        assert parsed.snowflake_schema == "public"

    def test_valid_identifier_with_dollar_sign(self):
        parsed = SnowflakeSearchToolInput(query="SELECT 1", database="my$db")
        assert parsed.database == "my$db"

    def test_database_with_sql_injection_semicolon(self):
        payload = "test_db; DROP TABLE users; --"
        with pytest.raises(ValueError):
            SnowflakeSearchToolInput(query="SELECT 1", database=payload)

    def test_schema_with_sql_injection_semicolon(self):
        payload = "public; DROP TABLE users; --"
        with pytest.raises(ValueError):
            SnowflakeSearchToolInput(query="SELECT 1", snowflake_schema=payload)

    def test_database_with_sql_injection_spaces(self):
        payload = "test_db DROP TABLE"
        with pytest.raises(ValueError):
            SnowflakeSearchToolInput(query="SELECT 1", database=payload)

    def test_schema_with_sql_injection_quotes(self):
        payload = "public\"--"
        with pytest.raises(ValueError):
            SnowflakeSearchToolInput(query="SELECT 1", snowflake_schema=payload)

    def test_database_with_sql_injection_dash_comment(self):
        payload = "test--comment"
        with pytest.raises(ValueError):
            SnowflakeSearchToolInput(query="SELECT 1", database=payload)

    def test_database_starting_with_number(self):
        payload = "1invalid"
        with pytest.raises(ValueError):
            SnowflakeSearchToolInput(query="SELECT 1", database=payload)

    def test_none_database_is_allowed(self):
        parsed = SnowflakeSearchToolInput(query="SELECT 1", database=None)
        assert parsed.database is None

    def test_none_schema_is_allowed(self):
        parsed = SnowflakeSearchToolInput(query="SELECT 1", snowflake_schema=None)
        assert parsed.snowflake_schema is None
class TestSnowflakeSearchToolValidateIdentifier:
    """Direct checks of SnowflakeSearchTool._validate_identifier."""

    def test_valid_identifiers(self):
        # Each (value, label) pair must come back unchanged.
        accepted = (
            ("my_db", "database"),
            ("PROD_DB", "database"),
            ("schema$1", "schema"),
            ("_private", "schema"),
        )
        for identifier, label in accepted:
            assert SnowflakeSearchTool._validate_identifier(identifier, label) == identifier

    def test_rejects_semicolons(self):
        payload = "db; DROP TABLE users;--"
        with pytest.raises(ValueError, match="Invalid database"):
            SnowflakeSearchTool._validate_identifier(payload, "database")

    def test_rejects_spaces(self):
        payload = "public schema"
        with pytest.raises(ValueError, match="Invalid schema"):
            SnowflakeSearchTool._validate_identifier(payload, "schema")

    def test_rejects_quotes(self):
        payload = 'db"--'
        with pytest.raises(ValueError, match="Invalid database"):
            SnowflakeSearchTool._validate_identifier(payload, "database")

    def test_rejects_leading_number(self):
        payload = "1db"
        with pytest.raises(ValueError, match="Invalid database"):
            SnowflakeSearchTool._validate_identifier(payload, "database")

    def test_rejects_empty_string(self):
        payload = ""
        with pytest.raises(ValueError, match="Invalid database"):
            SnowflakeSearchTool._validate_identifier(payload, "database")
@pytest.mark.asyncio
async def test_run_uses_quoted_identifiers(snowflake_tool, mock_snowflake_connection):
    """_run must issue USE DATABASE / USE SCHEMA with double-quoted names."""
    with patch.object(
        snowflake_tool, "_create_connection", return_value=mock_snowflake_connection
    ):
        await snowflake_tool._run(
            query="SELECT 1",
            database="my_db",
            snowflake_schema="my_schema",
        )

    # Collect the first positional argument of every cursor.execute() call.
    recorded_calls = mock_snowflake_connection.cursor().execute.call_args_list
    executed_sql = [recorded[0][0] for recorded in recorded_calls]
    assert 'USE DATABASE "my_db"' in executed_sql
    assert 'USE SCHEMA "my_schema"' in executed_sql
@pytest.mark.asyncio
async def test_run_rejects_malicious_database(snowflake_tool, mock_snowflake_connection):
    """_run must raise ValueError when database carries an injection payload."""
    payload = "test_db; DROP TABLE users; --"
    with patch.object(
        snowflake_tool, "_create_connection", return_value=mock_snowflake_connection
    ):
        with pytest.raises(ValueError, match="Invalid database"):
            await snowflake_tool._run(query="SELECT 1", database=payload)
@pytest.mark.asyncio
async def test_run_rejects_malicious_schema(snowflake_tool, mock_snowflake_connection):
    """_run must raise ValueError when snowflake_schema carries an injection payload."""
    payload = "public; DROP TABLE users; --"
    with patch.object(
        snowflake_tool, "_create_connection", return_value=mock_snowflake_connection
    ):
        with pytest.raises(ValueError, match="Invalid schema"):
            await snowflake_tool._run(query="SELECT 1", snowflake_schema=payload)

View File

@@ -21672,6 +21672,7 @@
"database": {
"anyOf": [
{
"pattern": "^[A-Za-z_][A-Za-z0-9_$]*$",
"type": "string"
},
{
@@ -21690,6 +21691,7 @@
"snowflake_schema": {
"anyOf": [
{
"pattern": "^[A-Za-z_][A-Za-z0-9_$]*$",
"type": "string"
},
{