mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-22 11:48:16 +00:00
Compare commits
2 Commits
main
...
devin/1774
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
287ffe2f6d | ||
|
|
9a9cb48d09 |
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -52,7 +53,18 @@ class NL2SQLTool(BaseTool):
|
||||
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_identifier(value: str, name: str) -> str:
|
||||
"""Validate a SQL identifier to prevent SQL injection."""
|
||||
if not re.match(r"^[A-Za-z_][A-Za-z0-9_$]*$", value):
|
||||
raise ValueError(
|
||||
f"Invalid {name}: {value!r}. "
|
||||
f"Only alphanumeric characters, underscores, and dollar signs are allowed."
|
||||
)
|
||||
return value
|
||||
|
||||
def _fetch_all_available_columns(self, table_name: str):
    """Look up the columns of ``table_name`` in ``information_schema.columns``.

    The table name is passed through ``_validate_identifier`` before it is
    interpolated into the statement; an invalid name makes that helper
    raise instead of reaching the database.

    Args:
        table_name: Name of the table whose columns are requested.

    Returns:
        Whatever ``execute_sql`` returns for the ``column_name`` /
        ``data_type`` query.
    """
    table_name = self._validate_identifier(table_name, "table_name")
    # Identifier has been validated above, so the interpolation is safe.
    columns_query = f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';"  # noqa: S608
    return self.execute_sql(columns_query)
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -71,8 +72,16 @@ class SnowflakeSearchToolInput(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
query: str = Field(..., description="SQL query or semantic search query to execute")
|
||||
database: str | None = Field(None, description="Override default database")
|
||||
snowflake_schema: str | None = Field(None, description="Override default schema")
|
||||
database: str | None = Field(
|
||||
None,
|
||||
description="Override default database",
|
||||
pattern=r"^[A-Za-z_][A-Za-z0-9_$]*$",
|
||||
)
|
||||
snowflake_schema: str | None = Field(
|
||||
None,
|
||||
description="Override default schema",
|
||||
pattern=r"^[A-Za-z_][A-Za-z0-9_$]*$",
|
||||
)
|
||||
timeout: int | None = Field(300, description="Query timeout in seconds")
|
||||
|
||||
|
||||
@@ -247,6 +256,16 @@ class SnowflakeSearchTool(BaseTool):
|
||||
continue
|
||||
raise RuntimeError("Query failed after all retries")
|
||||
|
||||
@staticmethod
|
||||
def _validate_identifier(value: str, name: str) -> str:
|
||||
"""Validate and sanitize a Snowflake identifier to prevent SQL injection."""
|
||||
if not re.match(r"^[A-Za-z_][A-Za-z0-9_$]*$", value):
|
||||
raise ValueError(
|
||||
f"Invalid {name}: {value!r}. "
|
||||
f"Only alphanumeric characters, underscores, and dollar signs are allowed."
|
||||
)
|
||||
return value
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
query: str,
|
||||
@@ -259,9 +278,11 @@ class SnowflakeSearchTool(BaseTool):
|
||||
try:
|
||||
# Override database/schema if provided
|
||||
if database:
|
||||
await self._execute_query(f"USE DATABASE {database}")
|
||||
database = self._validate_identifier(database, "database")
|
||||
await self._execute_query(f'USE DATABASE "{database}"')
|
||||
if snowflake_schema:
|
||||
await self._execute_query(f"USE SCHEMA {snowflake_schema}")
|
||||
snowflake_schema = self._validate_identifier(snowflake_schema, "schema")
|
||||
await self._execute_query(f'USE SCHEMA "{snowflake_schema}"')
|
||||
|
||||
return await self._execute_query(query, timeout)
|
||||
except Exception as e:
|
||||
|
||||
72
lib/crewai-tools/tests/tools/nl2sql_tool_test.py
Normal file
72
lib/crewai-tools/tests/tools/nl2sql_tool_test.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool
|
||||
|
||||
|
||||
class TestNL2SQLToolValidateIdentifier:
    """Exercise NL2SQLTool._validate_identifier with safe and hostile names."""

    def test_valid_identifiers(self):
        # Every well-formed identifier must come back unchanged.
        for identifier in ("users", "MY_TABLE", "table$1", "_private"):
            assert NL2SQLTool._validate_identifier(identifier, "table_name") == identifier

    def test_rejects_sql_injection_with_semicolon(self):
        payload = "users; DROP TABLE users;--"
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")

    def test_rejects_sql_injection_with_quotes(self):
        payload = "users'--"
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")

    def test_rejects_sql_injection_with_spaces(self):
        payload = "users DROP TABLE"
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")

    def test_rejects_leading_number(self):
        payload = "1table"
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")

    def test_rejects_empty_string(self):
        payload = ""
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")

    def test_rejects_parentheses(self):
        payload = "users()"
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")

    def test_rejects_dash_comment(self):
        payload = "users--comment"
        with pytest.raises(ValueError, match="Invalid table_name"):
            NL2SQLTool._validate_identifier(payload, "table_name")
|
||||
|
||||
|
||||
@patch("crewai_tools.tools.nl2sql.nl2sql_tool.SQLALCHEMY_AVAILABLE", True)
class TestNL2SQLToolFetchColumns:
    """Check that _fetch_all_available_columns validates the table name it is given."""

    def _make_tool(self):
        """Build an NL2SQLTool with model_post_init patched out during construction."""
        with patch.object(NL2SQLTool, "model_post_init"):
            return NL2SQLTool(
                db_uri="sqlite:///:memory:",
                name="NL2SQLTool",
                description="test",
            )

    def test_rejects_malicious_table_name(self):
        nl2sql = self._make_tool()
        payload = "users'; DROP TABLE users;--"
        with pytest.raises(ValueError, match="Invalid table_name"):
            nl2sql._fetch_all_available_columns(payload)

    def test_accepts_valid_table_name(self):
        nl2sql = self._make_tool()
        # execute_sql is stubbed so that only the generated SQL is inspected.
        with patch.object(NL2SQLTool, "execute_sql", return_value=[]) as mock_exec:
            outcome = nl2sql._fetch_all_available_columns("valid_table")
            mock_exec.assert_called_once()
            issued_sql = mock_exec.call_args[0][0]
            assert "valid_table" in issued_sql
            assert outcome == []
|
||||
@@ -2,6 +2,9 @@ import asyncio
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai_tools import SnowflakeConfig, SnowflakeSearchTool
|
||||
from crewai_tools.tools.snowflake_search_tool.snowflake_search_tool import (
|
||||
SnowflakeSearchToolInput,
|
||||
)
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -100,3 +103,136 @@ def test_config_validation():
|
||||
# Test missing authentication
|
||||
with pytest.raises(ValueError):
|
||||
SnowflakeConfig(account="test_account", user="test_user")
|
||||
|
||||
|
||||
# SQL Injection Prevention Tests
|
||||
class TestSnowflakeSearchToolInputValidation:
    """Check that the input schema only admits plain database/schema identifiers."""

    def test_valid_database_identifier(self):
        parsed = SnowflakeSearchToolInput(query="SELECT 1", database="my_database")
        assert parsed.database == "my_database"

    def test_valid_schema_identifier(self):
        parsed = SnowflakeSearchToolInput(query="SELECT 1", snowflake_schema="public")
        assert parsed.snowflake_schema == "public"

    def test_valid_identifier_with_dollar_sign(self):
        parsed = SnowflakeSearchToolInput(query="SELECT 1", database="my$db")
        assert parsed.database == "my$db"

    def test_database_with_sql_injection_semicolon(self):
        payload = "test_db; DROP TABLE users; --"
        with pytest.raises(ValueError):
            SnowflakeSearchToolInput(query="SELECT 1", database=payload)

    def test_schema_with_sql_injection_semicolon(self):
        payload = "public; DROP TABLE users; --"
        with pytest.raises(ValueError):
            SnowflakeSearchToolInput(query="SELECT 1", snowflake_schema=payload)

    def test_database_with_sql_injection_spaces(self):
        payload = "test_db DROP TABLE"
        with pytest.raises(ValueError):
            SnowflakeSearchToolInput(query="SELECT 1", database=payload)

    def test_schema_with_sql_injection_quotes(self):
        payload = "public\"--"
        with pytest.raises(ValueError):
            SnowflakeSearchToolInput(query="SELECT 1", snowflake_schema=payload)

    def test_database_with_sql_injection_dash_comment(self):
        payload = "test--comment"
        with pytest.raises(ValueError):
            SnowflakeSearchToolInput(query="SELECT 1", database=payload)

    def test_database_starting_with_number(self):
        payload = "1invalid"
        with pytest.raises(ValueError):
            SnowflakeSearchToolInput(query="SELECT 1", database=payload)

    def test_none_database_is_allowed(self):
        # None must stay accepted so the override remains optional.
        parsed = SnowflakeSearchToolInput(query="SELECT 1", database=None)
        assert parsed.database is None

    def test_none_schema_is_allowed(self):
        # None must stay accepted so the override remains optional.
        parsed = SnowflakeSearchToolInput(query="SELECT 1", snowflake_schema=None)
        assert parsed.snowflake_schema is None
|
||||
|
||||
|
||||
class TestSnowflakeSearchToolValidateIdentifier:
    """Exercise SnowflakeSearchTool._validate_identifier directly."""

    def test_valid_identifiers(self):
        # (identifier, label) pairs that must be returned unchanged.
        accepted = (
            ("my_db", "database"),
            ("PROD_DB", "database"),
            ("schema$1", "schema"),
            ("_private", "schema"),
        )
        for identifier, label in accepted:
            assert SnowflakeSearchTool._validate_identifier(identifier, label) == identifier

    def test_rejects_semicolons(self):
        payload = "db; DROP TABLE users;--"
        with pytest.raises(ValueError, match="Invalid database"):
            SnowflakeSearchTool._validate_identifier(payload, "database")

    def test_rejects_spaces(self):
        payload = "public schema"
        with pytest.raises(ValueError, match="Invalid schema"):
            SnowflakeSearchTool._validate_identifier(payload, "schema")

    def test_rejects_quotes(self):
        payload = 'db"--'
        with pytest.raises(ValueError, match="Invalid database"):
            SnowflakeSearchTool._validate_identifier(payload, "database")

    def test_rejects_leading_number(self):
        payload = "1db"
        with pytest.raises(ValueError, match="Invalid database"):
            SnowflakeSearchTool._validate_identifier(payload, "database")

    def test_rejects_empty_string(self):
        payload = ""
        with pytest.raises(ValueError, match="Invalid database"):
            SnowflakeSearchTool._validate_identifier(payload, "database")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_uses_quoted_identifiers(snowflake_tool, mock_snowflake_connection):
    """Check that _run emits USE statements with double-quoted database/schema names."""
    with patch.object(
        snowflake_tool, "_create_connection", return_value=mock_snowflake_connection
    ):
        await snowflake_tool._run(
            query="SELECT 1",
            database="my_db",
            snowflake_schema="my_schema",
        )

        # The first positional argument of every execute() call is the SQL text.
        recorded_calls = mock_snowflake_connection.cursor().execute.call_args_list
        issued_sql = [recorded[0][0] for recorded in recorded_calls]
        assert 'USE DATABASE "my_db"' in issued_sql
        assert 'USE SCHEMA "my_schema"' in issued_sql
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_rejects_malicious_database(snowflake_tool, mock_snowflake_connection):
    """Check that _run raises ValueError when the database override carries injected SQL."""
    hostile_database = "test_db; DROP TABLE users; --"
    with patch.object(
        snowflake_tool, "_create_connection", return_value=mock_snowflake_connection
    ):
        with pytest.raises(ValueError, match="Invalid database"):
            await snowflake_tool._run(
                query="SELECT 1",
                database=hostile_database,
            )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_rejects_malicious_schema(snowflake_tool, mock_snowflake_connection):
    """Check that _run raises ValueError when the schema override carries injected SQL."""
    hostile_schema = "public; DROP TABLE users; --"
    with patch.object(
        snowflake_tool, "_create_connection", return_value=mock_snowflake_connection
    ):
        with pytest.raises(ValueError, match="Invalid schema"):
            await snowflake_tool._run(
                query="SELECT 1",
                snowflake_schema=hostile_schema,
            )
|
||||
|
||||
@@ -21672,6 +21672,7 @@
|
||||
"database": {
|
||||
"anyOf": [
|
||||
{
|
||||
"pattern": "^[A-Za-z_][A-Za-z0-9_$]*$",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
@@ -21690,6 +21691,7 @@
|
||||
"snowflake_schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"pattern": "^[A-Za-z_][A-Za-z0-9_$]*$",
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user