mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-01 13:18:10 +00:00
fix mysql search table name validation
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -6,6 +7,26 @@ from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
_MYSQL_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*$")
|
||||
|
||||
|
||||
def _quote_mysql_table_name(table_name: str) -> str:
|
||||
identifier_parts = table_name.split(".")
|
||||
if (
|
||||
not identifier_parts
|
||||
or len(identifier_parts) > 2
|
||||
or any(
|
||||
not _MYSQL_IDENTIFIER_PATTERN.fullmatch(part) for part in identifier_parts
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"MySQL table_name must be a valid table identifier or schema.table "
|
||||
"identifier"
|
||||
)
|
||||
|
||||
return ".".join(f"`{part}`" for part in identifier_parts)
|
||||
|
||||
|
||||
class MySQLSearchToolSchema(BaseModel):
|
||||
"""Input for MySQLSearchTool."""
|
||||
|
||||
@@ -32,7 +53,8 @@ class MySQLSearchTool(RagTool):
|
||||
table_name: str,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().add(f"SELECT * FROM {table_name};", **kwargs) # noqa: S608
|
||||
quoted_table_name = _quote_mysql_table_name(table_name)
|
||||
super().add(f"SELECT * FROM {quoted_table_name};", **kwargs) # noqa: S608
|
||||
|
||||
def _run( # type: ignore[override]
|
||||
self,
|
||||
|
||||
110
lib/crewai-tools/tests/tools/test_mysql_search_tool.py
Normal file
110
lib/crewai-tools/tests/tools/test_mysql_search_tool.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.rag.data_types import DataType
|
||||
from crewai_tools.tools.mysql_search_tool.mysql_search_tool import MySQLSearchTool
|
||||
from crewai_tools.tools.rag.rag_tool import RagTool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_client() -> MagicMock:
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection = MagicMock(return_value=None)
|
||||
mock_client.add_documents = MagicMock(return_value=None)
|
||||
mock_client.search = MagicMock(return_value=[])
|
||||
return mock_client
|
||||
|
||||
|
||||
def create_mysql_search_tool(
|
||||
mock_rag_client: MagicMock, table_name: str
|
||||
) -> MySQLSearchTool:
|
||||
with (
|
||||
patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.get_rag_client",
|
||||
return_value=mock_rag_client,
|
||||
),
|
||||
patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.create_client",
|
||||
return_value=mock_rag_client,
|
||||
),
|
||||
):
|
||||
return MySQLSearchTool(
|
||||
db_uri="mysql://user:password@localhost:3306/test_database",
|
||||
table_name=table_name,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("table_name", "expected_query"),
|
||||
[
|
||||
("users", "SELECT * FROM `users`;"),
|
||||
("user_profiles_2026", "SELECT * FROM `user_profiles_2026`;"),
|
||||
("schema_name.users", "SELECT * FROM `schema_name`.`users`;"),
|
||||
("information_schema.tables", "SELECT * FROM `information_schema`.`tables`;"),
|
||||
],
|
||||
)
|
||||
def test_mysql_search_tool_quotes_valid_table_identifiers(
|
||||
mock_rag_client: MagicMock, table_name: str, expected_query: str
|
||||
) -> None:
|
||||
with patch.object(RagTool, "add", return_value=None) as mock_add:
|
||||
create_mysql_search_tool(mock_rag_client, table_name)
|
||||
|
||||
mock_add.assert_called_once_with(
|
||||
expected_query,
|
||||
data_type=DataType.MYSQL,
|
||||
metadata={"db_uri": "mysql://user:password@localhost:3306/test_database"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"table_name",
|
||||
[
|
||||
"users where 1=1",
|
||||
"users; drop table users;--",
|
||||
"users -- comment",
|
||||
"users/*comment*/",
|
||||
"`users`",
|
||||
"schema.users.extra",
|
||||
"schema.",
|
||||
".users",
|
||||
"123users",
|
||||
],
|
||||
)
|
||||
def test_mysql_search_tool_rejects_invalid_table_identifiers(
|
||||
mock_rag_client: MagicMock, table_name: str
|
||||
) -> None:
|
||||
with (
|
||||
patch.object(RagTool, "add", return_value=None) as mock_add,
|
||||
pytest.raises(ValueError, match="MySQL table_name must be a valid"),
|
||||
):
|
||||
create_mysql_search_tool(mock_rag_client, table_name)
|
||||
|
||||
mock_add.assert_not_called()
|
||||
|
||||
|
||||
def test_mysql_search_tool_still_runs_search_queries(
|
||||
mock_rag_client: MagicMock,
|
||||
) -> None:
|
||||
with patch.object(RagTool, "add", return_value=None):
|
||||
tool = create_mysql_search_tool(mock_rag_client, "users")
|
||||
|
||||
with patch.object(RagTool, "_run", return_value="Alice") as mock_run:
|
||||
result = tool._run("alice")
|
||||
|
||||
assert "Alice" in result
|
||||
mock_run.assert_called_once_with(
|
||||
query="alice", similarity_threshold=None, limit=None
|
||||
)
|
||||
|
||||
|
||||
def test_mysql_search_tool_uses_mysql_data_type_metadata(
|
||||
mock_rag_client: MagicMock,
|
||||
) -> None:
|
||||
with patch.object(RagTool, "add", return_value=None) as mock_add:
|
||||
create_mysql_search_tool(mock_rag_client, "users")
|
||||
|
||||
assert mock_add.call_args.kwargs == {
|
||||
"data_type": DataType.MYSQL,
|
||||
"metadata": {"db_uri": "mysql://user:password@localhost:3306/test_database"},
|
||||
}
|
||||
Reference in New Issue
Block a user