diff --git a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py
index 3ddea477b..e4a61d2ca 100644
--- a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py
+++ b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py
@@ -1,3 +1,4 @@
+import re
 from typing import Any
 
 from crewai.tools import BaseTool
@@ -52,7 +53,33 @@ class NL2SQLTool(BaseTool):
             "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
         )
 
+    @staticmethod
+    def _validate_identifier(value: str, name: str) -> str:
+        """Validate a SQL identifier to prevent SQL injection.
+
+        Args:
+            value: Candidate identifier, e.g. a table name.
+            name: Label for ``value`` used in the error message.
+
+        Returns:
+            ``value`` unchanged when it is a valid identifier.
+
+        Raises:
+            ValueError: If ``value`` is not a letter or underscore followed
+                only by letters, digits, underscores, or dollar signs.
+        """
+        # fullmatch rather than match with "$": "$" also matches just before
+        # a trailing newline, which would let "users\n" through.
+        if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_$]*", value):
+            raise ValueError(
+                f"Invalid {name}: {value!r}. "
+                "Must start with a letter or underscore and contain only "
+                "letters, digits, underscores, and dollar signs."
+            )
+        return value
+
     def _fetch_all_available_columns(self, table_name: str):
+        table_name = self._validate_identifier(table_name, "table_name")
         return self.execute_sql(
             f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';"  # noqa: S608
         )
diff --git a/lib/crewai-tools/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py b/lib/crewai-tools/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py
index c54209276..d1f999dc0 100644
--- a/lib/crewai-tools/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py
+++ b/lib/crewai-tools/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
 import logging
+import re
 import threading
 from typing import TYPE_CHECKING, Any
 
@@ -71,8 +72,16 @@ class SnowflakeSearchToolInput(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
     query: str = Field(..., description="SQL query or semantic search query to execute")
-    database: str | None = Field(None, description="Override default database")
-    snowflake_schema: str | None = Field(None, description="Override default schema")
+    database: str | None = Field(
+        None,
+        description="Override default database",
+        pattern=r"^[A-Za-z_][A-Za-z0-9_$]*$",
+    )
+    snowflake_schema: str | None = Field(
+        None,
+        description="Override default schema",
+        pattern=r"^[A-Za-z_][A-Za-z0-9_$]*$",
+    )
     timeout: int | None = Field(300, description="Query timeout in seconds")
 
 
@@ -247,6 +256,31 @@ class SnowflakeSearchTool(BaseTool):
                     continue
         raise RuntimeError("Query failed after all retries")
 
+    @staticmethod
+    def _validate_identifier(value: str, name: str) -> str:
+        """Validate a Snowflake identifier to prevent SQL injection.
+
+        Args:
+            value: Candidate identifier, e.g. a database or schema name.
+            name: Label for ``value`` used in the error message.
+
+        Returns:
+            ``value`` unchanged when it is a valid identifier.
+
+        Raises:
+            ValueError: If ``value`` is not a letter or underscore followed
+                only by letters, digits, underscores, or dollar signs.
+        """
+        # fullmatch rather than match with "$": "$" also matches just before
+        # a trailing newline, which would let "my_db\n" through.
+        if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_$]*", value):
+            raise ValueError(
+                f"Invalid {name}: {value!r}. "
+                "Must start with a letter or underscore and contain only "
+                "letters, digits, underscores, and dollar signs."
+            )
+        return value
+
     async def _run(
         self,
         query: str,
@@ -259,9 +293,15 @@ class SnowflakeSearchTool(BaseTool):
         try:
             # Override database/schema if provided
             if database:
+                # Validated names match Snowflake's unquoted-identifier
+                # grammar, so they are interpolated unquoted: double quotes
+                # would make the lookup case-sensitive and break e.g.
+                # database="my_db" for a database stored as MY_DB.
+                database = self._validate_identifier(database, "database")
                 await self._execute_query(f"USE DATABASE {database}")
             if snowflake_schema:
+                snowflake_schema = self._validate_identifier(snowflake_schema, "schema")
                 await self._execute_query(f"USE SCHEMA {snowflake_schema}")
 
             return await self._execute_query(query, timeout)
         except Exception as e:
diff --git a/lib/crewai-tools/tests/tools/nl2sql_tool_test.py b/lib/crewai-tools/tests/tools/nl2sql_tool_test.py
new file mode 100644
index 000000000..b1fd9235e
--- /dev/null
+++ b/lib/crewai-tools/tests/tools/nl2sql_tool_test.py
@@ -0,0 +1,76 @@
+from unittest.mock import patch
+
+import pytest
+
+from crewai_tools.tools.nl2sql.nl2sql_tool import NL2SQLTool
+
+
+class TestNL2SQLToolValidateIdentifier:
+    """Tests for SQL injection prevention via identifier validation."""
+
+    def test_valid_identifiers(self):
+        assert NL2SQLTool._validate_identifier("users", "table_name") == "users"
+        assert NL2SQLTool._validate_identifier("MY_TABLE", "table_name") == "MY_TABLE"
+        assert NL2SQLTool._validate_identifier("table$1", "table_name") == "table$1"
+        assert NL2SQLTool._validate_identifier("_private", "table_name") == "_private"
+
+    def test_rejects_sql_injection_with_semicolon(self):
+        with pytest.raises(ValueError, match="Invalid table_name"):
+            NL2SQLTool._validate_identifier("users; DROP TABLE users;--", "table_name")
+
+    def test_rejects_sql_injection_with_quotes(self):
+        with pytest.raises(ValueError, match="Invalid table_name"):
+            NL2SQLTool._validate_identifier("users'--", "table_name")
+
+    def test_rejects_sql_injection_with_spaces(self):
+        with pytest.raises(ValueError, match="Invalid table_name"):
+            NL2SQLTool._validate_identifier("users DROP TABLE", "table_name")
+
+    def test_rejects_leading_number(self):
+        with pytest.raises(ValueError, match="Invalid table_name"):
+            NL2SQLTool._validate_identifier("1table", "table_name")
+
+    def test_rejects_empty_string(self):
+        with pytest.raises(ValueError, match="Invalid table_name"):
+            NL2SQLTool._validate_identifier("", "table_name")
+
+    def test_rejects_parentheses(self):
+        with pytest.raises(ValueError, match="Invalid table_name"):
+            NL2SQLTool._validate_identifier("users()", "table_name")
+
+    def test_rejects_dash_comment(self):
+        with pytest.raises(ValueError, match="Invalid table_name"):
+            NL2SQLTool._validate_identifier("users--comment", "table_name")
+
+    def test_rejects_trailing_newline(self):
+        with pytest.raises(ValueError, match="Invalid table_name"):
+            NL2SQLTool._validate_identifier("users\n", "table_name")
+
+
+@patch("crewai_tools.tools.nl2sql.nl2sql_tool.SQLALCHEMY_AVAILABLE", True)
+class TestNL2SQLToolFetchColumns:
+    """Tests that _fetch_all_available_columns validates table names."""
+
+    def _make_tool(self):
+        """Create an NL2SQLTool instance bypassing model_post_init DB calls."""
+        with patch.object(NL2SQLTool, "model_post_init"):
+            tool = NL2SQLTool(
+                db_uri="sqlite:///:memory:",
+                name="NL2SQLTool",
+                description="test",
+            )
+        return tool
+
+    def test_rejects_malicious_table_name(self):
+        tool = self._make_tool()
+        with pytest.raises(ValueError, match="Invalid table_name"):
+            tool._fetch_all_available_columns("users'; DROP TABLE users;--")
+
+    def test_accepts_valid_table_name(self):
+        tool = self._make_tool()
+        with patch.object(NL2SQLTool, "execute_sql", return_value=[]) as mock_exec:
+            result = tool._fetch_all_available_columns("valid_table")
+            mock_exec.assert_called_once()
+            call_sql = mock_exec.call_args[0][0]
+            assert "valid_table" in call_sql
+            assert result == []
diff --git a/lib/crewai-tools/tests/tools/snowflake_search_tool_test.py b/lib/crewai-tools/tests/tools/snowflake_search_tool_test.py
index fe827d5df..82945bc54 100644
--- a/lib/crewai-tools/tests/tools/snowflake_search_tool_test.py
+++ b/lib/crewai-tools/tests/tools/snowflake_search_tool_test.py
@@ -2,6 +2,9 @@ import asyncio
 from unittest.mock import MagicMock, patch
 
 from crewai_tools import SnowflakeConfig, SnowflakeSearchTool
+from crewai_tools.tools.snowflake_search_tool.snowflake_search_tool import (
+    SnowflakeSearchToolInput,
+)
 import pytest
 
 
@@ -100,3 +103,140 @@ def test_config_validation():
     # Test missing authentication
     with pytest.raises(ValueError):
         SnowflakeConfig(account="test_account", user="test_user")
+
+
+# SQL Injection Prevention Tests
+class TestSnowflakeSearchToolInputValidation:
+    """Tests for SQL injection prevention via input schema validation."""
+
+    def test_valid_database_identifier(self):
+        inp = SnowflakeSearchToolInput(query="SELECT 1", database="my_database")
+        assert inp.database == "my_database"
+
+    def test_valid_schema_identifier(self):
+        inp = SnowflakeSearchToolInput(query="SELECT 1", snowflake_schema="public")
+        assert inp.snowflake_schema == "public"
+
+    def test_valid_identifier_with_dollar_sign(self):
+        inp = SnowflakeSearchToolInput(query="SELECT 1", database="my$db")
+        assert inp.database == "my$db"
+
+    def test_database_with_sql_injection_semicolon(self):
+        with pytest.raises(ValueError):
+            SnowflakeSearchToolInput(
+                query="SELECT 1", database="test_db; DROP TABLE users; --"
+            )
+
+    def test_schema_with_sql_injection_semicolon(self):
+        with pytest.raises(ValueError):
+            SnowflakeSearchToolInput(
+                query="SELECT 1", snowflake_schema="public; DROP TABLE users; --"
+            )
+
+    def test_database_with_sql_injection_spaces(self):
+        with pytest.raises(ValueError):
+            SnowflakeSearchToolInput(
+                query="SELECT 1", database="test_db DROP TABLE"
+            )
+
+    def test_schema_with_sql_injection_quotes(self):
+        with pytest.raises(ValueError):
+            SnowflakeSearchToolInput(
+                query="SELECT 1", snowflake_schema="public\"--"
+            )
+
+    def test_database_with_sql_injection_dash_comment(self):
+        with pytest.raises(ValueError):
+            SnowflakeSearchToolInput(
+                query="SELECT 1", database="test--comment"
+            )
+
+    def test_database_starting_with_number(self):
+        with pytest.raises(ValueError):
+            SnowflakeSearchToolInput(query="SELECT 1", database="1invalid")
+
+    def test_none_database_is_allowed(self):
+        inp = SnowflakeSearchToolInput(query="SELECT 1", database=None)
+        assert inp.database is None
+
+    def test_none_schema_is_allowed(self):
+        inp = SnowflakeSearchToolInput(query="SELECT 1", snowflake_schema=None)
+        assert inp.snowflake_schema is None
+
+
+class TestSnowflakeSearchToolValidateIdentifier:
+    """Tests for the _validate_identifier runtime check."""
+
+    def test_valid_identifiers(self):
+        assert SnowflakeSearchTool._validate_identifier("my_db", "database") == "my_db"
+        assert SnowflakeSearchTool._validate_identifier("PROD_DB", "database") == "PROD_DB"
+        assert SnowflakeSearchTool._validate_identifier("schema$1", "schema") == "schema$1"
+        assert SnowflakeSearchTool._validate_identifier("_private", "schema") == "_private"
+
+    def test_rejects_semicolons(self):
+        with pytest.raises(ValueError, match="Invalid database"):
+            SnowflakeSearchTool._validate_identifier("db; DROP TABLE users;--", "database")
+
+    def test_rejects_spaces(self):
+        with pytest.raises(ValueError, match="Invalid schema"):
+            SnowflakeSearchTool._validate_identifier("public schema", "schema")
+
+    def test_rejects_quotes(self):
+        with pytest.raises(ValueError, match="Invalid database"):
+            SnowflakeSearchTool._validate_identifier('db"--', "database")
+
+    def test_rejects_leading_number(self):
+        with pytest.raises(ValueError, match="Invalid database"):
+            SnowflakeSearchTool._validate_identifier("1db", "database")
+
+    def test_rejects_empty_string(self):
+        with pytest.raises(ValueError, match="Invalid database"):
+            SnowflakeSearchTool._validate_identifier("", "database")
+
+    def test_rejects_trailing_newline(self):
+        with pytest.raises(ValueError, match="Invalid database"):
+            SnowflakeSearchTool._validate_identifier("my_db\n", "database")
+
+
+@pytest.mark.asyncio
+async def test_run_uses_validated_unquoted_identifiers(snowflake_tool, mock_snowflake_connection):
+    """Verify that _run emits unquoted (case-insensitive) USE statements."""
+    with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
+        mock_create_conn.return_value = mock_snowflake_connection
+
+        await snowflake_tool._run(
+            query="SELECT 1",
+            database="my_db",
+            snowflake_schema="my_schema",
+        )
+
+        calls = mock_snowflake_connection.cursor().execute.call_args_list
+        sql_statements = [call[0][0] for call in calls]
+        assert "USE DATABASE my_db" in sql_statements
+        assert "USE SCHEMA my_schema" in sql_statements
+
+
+@pytest.mark.asyncio
+async def test_run_rejects_malicious_database(snowflake_tool, mock_snowflake_connection):
+    """Verify that _run raises ValueError for SQL injection attempts in database."""
+    with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
+        mock_create_conn.return_value = mock_snowflake_connection
+
+        with pytest.raises(ValueError, match="Invalid database"):
+            await snowflake_tool._run(
+                query="SELECT 1",
+                database="test_db; DROP TABLE users; --",
+            )
+
+
+@pytest.mark.asyncio
+async def test_run_rejects_malicious_schema(snowflake_tool, mock_snowflake_connection):
+    """Verify that _run raises ValueError for SQL injection attempts in schema."""
+    with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
+        mock_create_conn.return_value = mock_snowflake_connection
+
+        with pytest.raises(ValueError, match="Invalid schema"):
+            await snowflake_tool._run(
+                query="SELECT 1",
+                snowflake_schema="public; DROP TABLE users; --",
+            )