Fix MySQL search table name validation

This commit is contained in:
Rip&Tear
2026-06-26 10:43:35 +08:00
parent 5d4851eac7
commit e3f7d1a12e
2 changed files with 133 additions and 1 deletions

View File

@@ -1,3 +1,4 @@
import re
from typing import Any
from pydantic import BaseModel, Field
@@ -6,6 +7,26 @@ from crewai_tools.rag.data_types import DataType
from crewai_tools.tools.rag.rag_tool import RagTool
# One unquoted identifier part: an ASCII letter or underscore first, followed
# by ASCII letters, digits, underscores, or dollar signs.
_MYSQL_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*$")

# MySQL limits database and table names to 64 characters each.
_MYSQL_MAX_IDENTIFIER_LENGTH = 64


def _quote_mysql_table_name(table_name: str) -> str:
    """Validate a MySQL table name and return it wrapped in backticks.

    The name may be a bare ``table`` or a ``schema.table`` pair. Every
    dot-separated part must match ``_MYSQL_IDENTIFIER_PATTERN`` and be at most
    ``_MYSQL_MAX_IDENTIFIER_LENGTH`` characters long, so the result can be
    interpolated into a SQL statement without letting the caller inject
    additional SQL.

    Args:
        table_name: The table identifier, optionally prefixed with a schema
            name and a single dot.

    Returns:
        The identifier with each part enclosed in backticks, for example
        ``"`schema`.`users`"``.

    Raises:
        ValueError: If ``table_name`` is not a string, has more than two
            dot-separated parts, or any part is empty, too long, or contains
            characters outside the allowed set.
    """
    # A non-string can never be a valid identifier; treat it as "no parts" so
    # it is rejected with the same ValueError instead of an AttributeError.
    identifier_parts = table_name.split(".") if isinstance(table_name, str) else []
    if (
        not identifier_parts
        or len(identifier_parts) > 2
        or any(
            len(part) > _MYSQL_MAX_IDENTIFIER_LENGTH
            or not _MYSQL_IDENTIFIER_PATTERN.fullmatch(part)
            for part in identifier_parts
        )
    ):
        raise ValueError(
            "MySQL table_name must be a valid table identifier or schema.table "
            "identifier"
        )
    # Backtick-quote each part separately so `schema`.`table` stays two names.
    return ".".join(f"`{part}`" for part in identifier_parts)
class MySQLSearchToolSchema(BaseModel):
"""Input for MySQLSearchTool."""
@@ -32,7 +53,8 @@ class MySQLSearchTool(RagTool):
table_name: str,
**kwargs: Any,
) -> None:
super().add(f"SELECT * FROM {table_name};", **kwargs) # noqa: S608
quoted_table_name = _quote_mysql_table_name(table_name)
super().add(f"SELECT * FROM {quoted_table_name};", **kwargs) # noqa: S608
def _run( # type: ignore[override]
self,

View File

@@ -0,0 +1,110 @@
from unittest.mock import MagicMock, patch
import pytest
from crewai_tools.rag.data_types import DataType
from crewai_tools.tools.mysql_search_tool.mysql_search_tool import MySQLSearchTool
from crewai_tools.tools.rag.rag_tool import RagTool
@pytest.fixture
def mock_rag_client() -> MagicMock:
    """Provide a stand-in RAG client whose collection methods are stubbed.

    ``get_or_create_collection`` and ``add_documents`` return ``None``;
    ``search`` returns an empty list.
    """
    client = MagicMock()
    stubbed_results = {
        "get_or_create_collection": None,
        "add_documents": None,
        "search": [],
    }
    for method_name, result in stubbed_results.items():
        setattr(client, method_name, MagicMock(return_value=result))
    return client
def create_mysql_search_tool(
    mock_rag_client: MagicMock, table_name: str
) -> MySQLSearchTool:
    """Build a ``MySQLSearchTool`` for ``table_name`` with the RAG client mocked.

    While the tool is constructed, ``get_rag_client`` and ``create_client`` in
    ``crewai_tools.adapters.crewai_rag_adapter`` are patched to return
    ``mock_rag_client``.
    """
    get_client_patch = patch(
        "crewai_tools.adapters.crewai_rag_adapter.get_rag_client",
        return_value=mock_rag_client,
    )
    create_client_patch = patch(
        "crewai_tools.adapters.crewai_rag_adapter.create_client",
        return_value=mock_rag_client,
    )
    with get_client_patch, create_client_patch:
        tool = MySQLSearchTool(
            db_uri="mysql://user:password@localhost:3306/test_database",
            table_name=table_name,
        )
        return tool
@pytest.mark.parametrize(
    ("table_name", "expected_query"),
    [
        ("users", "SELECT * FROM `users`;"),
        ("user_profiles_2026", "SELECT * FROM `user_profiles_2026`;"),
        ("schema_name.users", "SELECT * FROM `schema_name`.`users`;"),
        ("information_schema.tables", "SELECT * FROM `information_schema`.`tables`;"),
    ],
)
def test_mysql_search_tool_quotes_valid_table_identifiers(
    mock_rag_client: MagicMock, table_name: str, expected_query: str
) -> None:
    """Valid table names reach ``RagTool.add`` as a backtick-quoted SELECT."""
    expected_metadata = {
        "db_uri": "mysql://user:password@localhost:3306/test_database"
    }
    with patch.object(RagTool, "add", return_value=None) as add_spy:
        create_mysql_search_tool(mock_rag_client, table_name)
        add_spy.assert_called_once_with(
            expected_query,
            data_type=DataType.MYSQL,
            metadata=expected_metadata,
        )
@pytest.mark.parametrize(
    "table_name",
    [
        "users where 1=1",
        "users; drop table users;--",
        "users -- comment",
        "users/*comment*/",
        "`users`",
        "schema.users.extra",
        "schema.",
        ".users",
        "123users",
    ],
)
def test_mysql_search_tool_rejects_invalid_table_identifiers(
    mock_rag_client: MagicMock, table_name: str
) -> None:
    """Malformed table names raise ``ValueError`` before ``RagTool.add`` runs."""
    with patch.object(RagTool, "add", return_value=None) as add_spy:
        with pytest.raises(ValueError, match="MySQL table_name must be a valid"):
            create_mysql_search_tool(mock_rag_client, table_name)
    # Checked after the contexts exit: a statement placed after the raising
    # call inside ``pytest.raises`` would never execute.
    add_spy.assert_not_called()
def test_mysql_search_tool_still_runs_search_queries(
    mock_rag_client: MagicMock,
) -> None:
    """``_run`` on the tool forwards the query to ``RagTool._run``."""
    # Stub ``add`` only while the tool is being constructed.
    with patch.object(RagTool, "add", return_value=None):
        search_tool = create_mysql_search_tool(mock_rag_client, "users")

    with patch.object(RagTool, "_run", return_value="Alice") as run_spy:
        answer = search_tool._run("alice")

    assert "Alice" in answer
    expected_call_kwargs = {
        "query": "alice",
        "similarity_threshold": None,
        "limit": None,
    }
    run_spy.assert_called_once_with(**expected_call_kwargs)
def test_mysql_search_tool_uses_mysql_data_type_metadata(
    mock_rag_client: MagicMock,
) -> None:
    """``RagTool.add`` receives the MySQL data type and the ``db_uri`` metadata."""
    expected_kwargs = {
        "data_type": DataType.MYSQL,
        "metadata": {"db_uri": "mysql://user:password@localhost:3306/test_database"},
    }
    with patch.object(RagTool, "add", return_value=None) as add_spy:
        create_mysql_search_tool(mock_rag_client, "users")
        recorded_kwargs = add_spy.call_args.kwargs
    assert recorded_kwargs == expected_kwargs