mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
fix: harden NL2SQLTool — read-only default, query validation, parameterized queries (#5311)
Some checks failed
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
Some checks failed
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Check Documentation Broken Links / Check broken links (push) Has been cancelled
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
* fix: harden NL2SQLTool — read-only by default, parameterized queries, query validation Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: address CI lint failures and remove unused import - Remove unused `sessionmaker` import from test_nl2sql_security.py - Use `Self` return type on `_apply_env_override` (fixes UP037/F821) - Fix ruff errors auto-fixed in lib/crewai (UP007, etc.) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: expand _WRITE_COMMANDS and block multi-statement semicolon injection - Add missing write commands: UPSERT, LOAD, COPY, VACUUM, ANALYZE, ANALYSE, REINDEX, CLUSTER, REFRESH, COMMENT, SET, RESET - _validate_query() now splits on ';' and validates each statement independently; multi-statement queries are rejected outright in read-only mode to prevent 'SELECT 1; DROP TABLE users' bypass - Extract single-statement logic into _validate_statement() helper - Add TestSemicolonInjection and TestExtendedWriteCommands test classes Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * ci: retrigger * fix: use typing_extensions.Self for Python 3.10 compat * chore: update tool specifications * docs: document NL2SQLTool read-only default and DML configuration Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: close three NL2SQLTool security gaps (writable CTEs, EXPLAIN ANALYZE, multi-stmt commit) - Remove WITH from _READ_ONLY_COMMANDS; scan CTE body for write keywords so writable CTEs like `WITH d AS (DELETE …) SELECT …` are blocked in read-only mode. - EXPLAIN ANALYZE/ANALYSE now resolves the underlying command; EXPLAIN ANALYZE DELETE is treated as a write and blocked in read-only mode. - execute_sql commit decision now checks ALL semicolon-separated statements so a SELECT-first batch like `SELECT 1; DROP TABLE t` still triggers a commit when allow_dml=True. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix: handle parenthesized EXPLAIN options syntax; remove unused _seed_db _validate_statement now strips parenthesized options from EXPLAIN (e.g. EXPLAIN (ANALYZE) DELETE, EXPLAIN (ANALYZE, VERBOSE) DELETE) before checking whether ANALYZE/ANALYSE is present — closing the bypass where the options-list form was silently allowed in read-only mode. Adds three new tests: - EXPLAIN (ANALYZE) DELETE → blocked - EXPLAIN (ANALYZE, VERBOSE) DELETE → blocked - EXPLAIN (VERBOSE) SELECT → allowed Also removes the unused _seed_db helper from test_nl2sql_security.py. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * chore: update tool specifications * fix: smarter CTE write detection, fix commit logic for writable CTEs - Replace naive token-set matching with positional AS() body inspection to avoid false positives on column names like 'comment', 'set', 'reset' - Fix execute_sql commit logic to detect writable CTEs (WITH + DELETE/INSERT) not just top-level write commands - Add tests for false positive cases and writable CTE commit behavior - Format nl2sql_tool.py to pass ruff format check * fix: catch write commands in CTE main query + handle whitespace in AS() - WITH cte AS (SELECT 1) DELETE FROM users now correctly blocked - AS followed by newline/tab/multi-space before ( now detected - execute_sql commit logic updated for both cases - 4 new tests * fix: EXPLAIN ANALYZE VERBOSE handling, string literal paren bypass, commit logic for EXPLAIN ANALYZE - EXPLAIN handler now consumes all known options (ANALYZE, ANALYSE, VERBOSE) before extracting the real command, fixing 'EXPLAIN ANALYZE VERBOSE SELECT' being blocked - Paren walker in _extract_main_query_after_cte now skips string literals, preventing 'WITH cte AS (SELECT '\''('\'' FROM t) DELETE FROM users' from bypassing detection - _is_write_stmt in execute_sql now resolves EXPLAIN ANALYZE to underlying command via _resolve_explain_command, ensuring session.commit() fires for write operations - 10 new tests covering all three fixes * fix: deduplicate EXPLAIN parsing, fix AS( regex in strings, block unknown CTE commands, bump langchain-core - Refactor _validate_statement to use _resolve_explain_command (single source of truth) - _iter_as_paren_matches skips string literals so 'AS (' in data doesn't confuse CTE detection - Unknown commands after CTE definitions now blocked in read-only mode - Bump langchain-core override to >=1.2.28 (GHSA-926x-3r5x-gfhw) * fix: add return type annotation to _iter_as_paren_matches --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,17 @@
|
||||
from collections.abc import Iterator
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
|
||||
try:
|
||||
from typing import Self
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
try:
|
||||
@@ -12,6 +22,186 @@ try:
|
||||
except ImportError:
|
||||
SQLALCHEMY_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Commands allowed in read-only mode
|
||||
# NOTE: WITH is intentionally excluded — writable CTEs start with WITH, so the
|
||||
# CTE body must be inspected separately (see _validate_statement).
|
||||
_READ_ONLY_COMMANDS = {"SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"}
|
||||
|
||||
# Commands that mutate state and are blocked by default
|
||||
_WRITE_COMMANDS = {
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"DELETE",
|
||||
"DROP",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"TRUNCATE",
|
||||
"GRANT",
|
||||
"REVOKE",
|
||||
"EXEC",
|
||||
"EXECUTE",
|
||||
"CALL",
|
||||
"MERGE",
|
||||
"REPLACE",
|
||||
"UPSERT",
|
||||
"LOAD",
|
||||
"COPY",
|
||||
"VACUUM",
|
||||
"ANALYZE",
|
||||
"ANALYSE",
|
||||
"REINDEX",
|
||||
"CLUSTER",
|
||||
"REFRESH",
|
||||
"COMMENT",
|
||||
"SET",
|
||||
"RESET",
|
||||
}
|
||||
|
||||
|
||||
# Subset of write commands that can realistically appear *inside* a CTE body.
|
||||
# Narrower than _WRITE_COMMANDS to avoid false positives on identifiers like
|
||||
# ``comment``, ``set``, or ``reset`` which are common column/table names.
|
||||
_CTE_WRITE_INDICATORS = {
|
||||
"INSERT",
|
||||
"UPDATE",
|
||||
"DELETE",
|
||||
"DROP",
|
||||
"ALTER",
|
||||
"CREATE",
|
||||
"TRUNCATE",
|
||||
"MERGE",
|
||||
}
|
||||
|
||||
|
||||
_AS_PAREN_RE = re.compile(r"\bAS\s*\(", re.IGNORECASE)
|
||||
|
||||
|
||||
def _iter_as_paren_matches(stmt: str) -> Iterator[re.Match[str]]:
|
||||
"""Yield regex matches for ``AS\\s*(`` outside of string literals."""
|
||||
# Build a set of character positions that are inside string literals.
|
||||
in_string: set[int] = set()
|
||||
i = 0
|
||||
while i < len(stmt):
|
||||
if stmt[i] == "'":
|
||||
start = i
|
||||
end = _skip_string_literal(stmt, i)
|
||||
in_string.update(range(start, end))
|
||||
i = end
|
||||
else:
|
||||
i += 1
|
||||
|
||||
for m in _AS_PAREN_RE.finditer(stmt):
|
||||
if m.start() not in in_string:
|
||||
yield m
|
||||
|
||||
|
||||
def _detect_writable_cte(stmt: str) -> str | None:
|
||||
"""Return the first write command inside a CTE body, or None.
|
||||
|
||||
Instead of tokenizing the whole statement (which falsely matches column
|
||||
names like ``comment``), this walks through parenthesized CTE bodies and
|
||||
checks only the *first keyword after* an opening ``AS (`` for a write
|
||||
command. Uses a regex to handle any whitespace (spaces, tabs, newlines)
|
||||
between ``AS`` and ``(``. Skips matches inside string literals.
|
||||
"""
|
||||
for m in _iter_as_paren_matches(stmt):
|
||||
body = stmt[m.end() :].lstrip()
|
||||
first_word = body.split()[0].upper().strip("()") if body.split() else ""
|
||||
if first_word in _CTE_WRITE_INDICATORS:
|
||||
return first_word
|
||||
return None
|
||||
|
||||
|
||||
def _skip_string_literal(stmt: str, pos: int) -> int:
|
||||
"""Skip past a string literal starting at pos (single-quoted).
|
||||
|
||||
Handles escaped quotes ('') inside the literal.
|
||||
Returns the index after the closing quote.
|
||||
"""
|
||||
quote_char = stmt[pos]
|
||||
i = pos + 1
|
||||
while i < len(stmt):
|
||||
if stmt[i] == quote_char:
|
||||
# Check for escaped quote ('')
|
||||
if i + 1 < len(stmt) and stmt[i + 1] == quote_char:
|
||||
i += 2
|
||||
continue
|
||||
return i + 1
|
||||
i += 1
|
||||
return i # Unterminated literal — return end
|
||||
|
||||
|
||||
def _find_matching_close_paren(stmt: str, start: int) -> int:
|
||||
"""Find the matching close paren, skipping string literals."""
|
||||
depth = 1
|
||||
i = start
|
||||
while i < len(stmt) and depth > 0:
|
||||
ch = stmt[i]
|
||||
if ch == "'":
|
||||
i = _skip_string_literal(stmt, i)
|
||||
continue
|
||||
if ch == "(":
|
||||
depth += 1
|
||||
elif ch == ")":
|
||||
depth -= 1
|
||||
i += 1
|
||||
return i
|
||||
|
||||
|
||||
def _extract_main_query_after_cte(stmt: str) -> str | None:
|
||||
"""Extract the main (outer) query that follows all CTE definitions.
|
||||
|
||||
For ``WITH cte AS (SELECT 1) DELETE FROM users``, returns ``DELETE FROM users``.
|
||||
Returns None if no main query is found after the last CTE body.
|
||||
Handles parentheses inside string literals (e.g., ``SELECT '(' FROM t``).
|
||||
"""
|
||||
last_cte_end = 0
|
||||
for m in _iter_as_paren_matches(stmt):
|
||||
last_cte_end = _find_matching_close_paren(stmt, m.end())
|
||||
|
||||
if last_cte_end > 0:
|
||||
remainder = stmt[last_cte_end:].strip().lstrip(",").strip()
|
||||
if remainder:
|
||||
return remainder
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_explain_command(stmt: str) -> str | None:
|
||||
"""Resolve the underlying command from an EXPLAIN [ANALYZE] [VERBOSE] statement.
|
||||
|
||||
Returns the real command (e.g., 'DELETE') if ANALYZE is present, else None.
|
||||
Handles both space-separated and parenthesized syntax.
|
||||
"""
|
||||
rest = stmt.strip()[len("EXPLAIN") :].strip()
|
||||
if not rest:
|
||||
return None
|
||||
|
||||
analyze_found = False
|
||||
explain_opts = {"ANALYZE", "ANALYSE", "VERBOSE"}
|
||||
|
||||
if rest.startswith("("):
|
||||
close = rest.find(")")
|
||||
if close != -1:
|
||||
options_str = rest[1:close].upper()
|
||||
analyze_found = any(
|
||||
opt.strip() in ("ANALYZE", "ANALYSE") for opt in options_str.split(",")
|
||||
)
|
||||
rest = rest[close + 1 :].strip()
|
||||
else:
|
||||
while rest:
|
||||
first_opt = rest.split()[0].upper().rstrip(";") if rest.split() else ""
|
||||
if first_opt in ("ANALYZE", "ANALYSE"):
|
||||
analyze_found = True
|
||||
if first_opt not in explain_opts:
|
||||
break
|
||||
rest = rest[len(first_opt) :].strip()
|
||||
|
||||
if analyze_found and rest:
|
||||
return rest.split()[0].upper().rstrip(";")
|
||||
return None
|
||||
|
||||
|
||||
class NL2SQLToolInput(BaseModel):
|
||||
sql_query: str = Field(
|
||||
@@ -21,20 +211,70 @@ class NL2SQLToolInput(BaseModel):
|
||||
|
||||
|
||||
class NL2SQLTool(BaseTool):
|
||||
"""Tool that converts natural language to SQL and executes it against a database.
|
||||
|
||||
By default the tool operates in **read-only mode**: only SELECT, SHOW,
|
||||
DESCRIBE, EXPLAIN, and read-only CTEs (WITH … SELECT) are permitted. Write
|
||||
operations (INSERT, UPDATE, DELETE, DROP, ALTER, CREATE, TRUNCATE, …) are
|
||||
blocked unless ``allow_dml=True`` is set explicitly or the environment
|
||||
variable ``CREWAI_NL2SQL_ALLOW_DML=true`` is present.
|
||||
|
||||
Writable CTEs (``WITH d AS (DELETE …) SELECT …``) and
|
||||
``EXPLAIN ANALYZE <write-stmt>`` are treated as write operations and are
|
||||
blocked in read-only mode.
|
||||
|
||||
The ``_fetch_all_available_columns`` helper uses parameterised queries so
|
||||
that table names coming from the database catalogue cannot be used as an
|
||||
injection vector.
|
||||
"""
|
||||
|
||||
name: str = "NL2SQLTool"
|
||||
description: str = "Converts natural language to SQL queries and executes them."
|
||||
description: str = (
|
||||
"Converts natural language to SQL queries and executes them against a "
|
||||
"database. Read-only by default — only SELECT/SHOW/DESCRIBE/EXPLAIN "
|
||||
"queries (and read-only CTEs) are allowed unless configured with "
|
||||
"allow_dml=True."
|
||||
)
|
||||
db_uri: str = Field(
|
||||
title="Database URI",
|
||||
description="The URI of the database to connect to.",
|
||||
)
|
||||
allow_dml: bool = Field(
|
||||
default=False,
|
||||
title="Allow DML",
|
||||
description=(
|
||||
"When False (default) only read statements are permitted. "
|
||||
"Set to True to allow INSERT/UPDATE/DELETE/DROP and other "
|
||||
"write operations."
|
||||
),
|
||||
)
|
||||
tables: list[dict[str, Any]] = Field(default_factory=list)
|
||||
columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict)
|
||||
args_schema: type[BaseModel] = NL2SQLToolInput
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _apply_env_override(self) -> Self:
|
||||
"""Allow CREWAI_NL2SQL_ALLOW_DML=true to override allow_dml at runtime."""
|
||||
if os.environ.get("CREWAI_NL2SQL_ALLOW_DML", "").strip().lower() == "true":
|
||||
if not self.allow_dml:
|
||||
logger.warning(
|
||||
"NL2SQLTool: CREWAI_NL2SQL_ALLOW_DML env var is set — "
|
||||
"DML/DDL operations are enabled. Ensure this is intentional."
|
||||
)
|
||||
self.allow_dml = True
|
||||
return self
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if not SQLALCHEMY_AVAILABLE:
|
||||
raise ImportError(
|
||||
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"
|
||||
"sqlalchemy is not installed. Please install it with "
|
||||
"`pip install crewai-tools[sqlalchemy]`"
|
||||
)
|
||||
|
||||
if self.allow_dml:
|
||||
logger.warning(
|
||||
"NL2SQLTool: allow_dml=True — write operations (INSERT/UPDATE/"
|
||||
"DELETE/DROP/…) are permitted. Use with caution."
|
||||
)
|
||||
|
||||
data: dict[str, list[dict[str, Any]] | str] = {}
|
||||
@@ -50,42 +290,216 @@ class NL2SQLTool(BaseTool):
|
||||
self.tables = tables
|
||||
self.columns = data
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _validate_query(self, sql_query: str) -> None:
|
||||
"""Raise ValueError if *sql_query* is not permitted under the current config.
|
||||
|
||||
Splits the query on semicolons and validates each statement
|
||||
independently. When ``allow_dml=False`` (the default), multi-statement
|
||||
queries are rejected outright to prevent ``SELECT 1; DROP TABLE users``
|
||||
style bypasses. When ``allow_dml=True`` every statement is checked and
|
||||
a warning is emitted for write operations.
|
||||
"""
|
||||
statements = [s.strip() for s in sql_query.split(";") if s.strip()]
|
||||
|
||||
if not statements:
|
||||
raise ValueError("NL2SQLTool received an empty SQL query.")
|
||||
|
||||
if not self.allow_dml and len(statements) > 1:
|
||||
raise ValueError(
|
||||
"NL2SQLTool blocked a multi-statement query in read-only mode. "
|
||||
"Semicolons are not permitted when allow_dml=False."
|
||||
)
|
||||
|
||||
for stmt in statements:
|
||||
self._validate_statement(stmt)
|
||||
|
||||
def _validate_statement(self, stmt: str) -> None:
|
||||
"""Validate a single SQL statement (no semicolons)."""
|
||||
command = self._extract_command(stmt)
|
||||
|
||||
# EXPLAIN ANALYZE / EXPLAIN ANALYSE actually *executes* the underlying
|
||||
# query. Resolve the real command so write operations are caught.
|
||||
# Handles both space-separated ("EXPLAIN ANALYZE DELETE …") and
|
||||
# parenthesized ("EXPLAIN (ANALYZE) DELETE …", "EXPLAIN (ANALYZE, VERBOSE) DELETE …").
|
||||
# EXPLAIN ANALYZE actually executes the underlying query — resolve the
|
||||
# real command so write operations are caught.
|
||||
if command == "EXPLAIN":
|
||||
resolved = _resolve_explain_command(stmt)
|
||||
if resolved:
|
||||
command = resolved
|
||||
|
||||
# WITH starts a CTE. Read-only CTEs are fine; writable CTEs
|
||||
# (e.g. WITH d AS (DELETE …) SELECT …) must be blocked in read-only mode.
|
||||
if command == "WITH":
|
||||
# Check for write commands inside CTE bodies.
|
||||
write_found = _detect_writable_cte(stmt)
|
||||
if write_found:
|
||||
found = write_found
|
||||
if not self.allow_dml:
|
||||
raise ValueError(
|
||||
f"NL2SQLTool is configured in read-only mode and blocked a "
|
||||
f"writable CTE containing a '{found}' statement. To allow "
|
||||
f"write operations set allow_dml=True or "
|
||||
f"CREWAI_NL2SQL_ALLOW_DML=true."
|
||||
)
|
||||
logger.warning(
|
||||
"NL2SQLTool: executing writable CTE with '%s' because allow_dml=True.",
|
||||
found,
|
||||
)
|
||||
return
|
||||
|
||||
# Check the main query after the CTE definitions.
|
||||
main_query = _extract_main_query_after_cte(stmt)
|
||||
if main_query:
|
||||
main_cmd = main_query.split()[0].upper().rstrip(";")
|
||||
if main_cmd in _WRITE_COMMANDS:
|
||||
if not self.allow_dml:
|
||||
raise ValueError(
|
||||
f"NL2SQLTool is configured in read-only mode and blocked a "
|
||||
f"'{main_cmd}' statement after a CTE. To allow write "
|
||||
f"operations set allow_dml=True or "
|
||||
f"CREWAI_NL2SQL_ALLOW_DML=true."
|
||||
)
|
||||
logger.warning(
|
||||
"NL2SQLTool: executing '%s' after CTE because allow_dml=True.",
|
||||
main_cmd,
|
||||
)
|
||||
elif main_cmd not in _READ_ONLY_COMMANDS:
|
||||
if not self.allow_dml:
|
||||
raise ValueError(
|
||||
f"NL2SQLTool blocked an unrecognised SQL command '{main_cmd}' "
|
||||
f"after a CTE. Only {sorted(_READ_ONLY_COMMANDS)} are allowed "
|
||||
f"in read-only mode."
|
||||
)
|
||||
return
|
||||
|
||||
if command in _WRITE_COMMANDS:
|
||||
if not self.allow_dml:
|
||||
raise ValueError(
|
||||
f"NL2SQLTool is configured in read-only mode and blocked a "
|
||||
f"'{command}' statement. To allow write operations set "
|
||||
f"allow_dml=True or CREWAI_NL2SQL_ALLOW_DML=true."
|
||||
)
|
||||
logger.warning(
|
||||
"NL2SQLTool: executing write statement '%s' because allow_dml=True.",
|
||||
command,
|
||||
)
|
||||
elif command not in _READ_ONLY_COMMANDS:
|
||||
# Unknown command — block by default unless DML is explicitly enabled
|
||||
if not self.allow_dml:
|
||||
raise ValueError(
|
||||
f"NL2SQLTool blocked an unrecognised SQL command '{command}'. "
|
||||
f"Only {sorted(_READ_ONLY_COMMANDS)} are allowed in read-only "
|
||||
f"mode."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_command(sql_query: str) -> str:
|
||||
"""Return the uppercased first keyword of *sql_query*."""
|
||||
stripped = sql_query.strip().lstrip("(")
|
||||
first_token = stripped.split()[0] if stripped.split() else ""
|
||||
return first_token.upper().rstrip(";")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Schema introspection helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _fetch_available_tables(self) -> list[dict[str, Any]] | str:
|
||||
return self.execute_sql(
|
||||
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';"
|
||||
"SELECT table_name FROM information_schema.tables "
|
||||
"WHERE table_schema = 'public';"
|
||||
)
|
||||
|
||||
def _fetch_all_available_columns(
|
||||
self, table_name: str
|
||||
) -> list[dict[str, Any]] | str:
|
||||
"""Fetch columns for *table_name* using a parameterised query.
|
||||
|
||||
The table name is bound via SQLAlchemy's ``:param`` syntax to prevent
|
||||
SQL injection from catalogue values.
|
||||
"""
|
||||
return self.execute_sql(
|
||||
f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';" # noqa: S608
|
||||
"SELECT column_name, data_type FROM information_schema.columns "
|
||||
"WHERE table_name = :table_name",
|
||||
params={"table_name": table_name},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
|
||||
try:
|
||||
self._validate_query(sql_query)
|
||||
data = self.execute_sql(sql_query)
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
data = (
|
||||
f"Based on these tables {self.tables} and columns {self.columns}, "
|
||||
"you can create SQL queries to retrieve data from the database."
|
||||
f"Get the original request {sql_query} and the error {exc} and create the correct SQL query."
|
||||
"you can create SQL queries to retrieve data from the database. "
|
||||
f"Get the original request {sql_query} and the error {exc} and "
|
||||
"create the correct SQL query."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def execute_sql(self, sql_query: str) -> list[dict[str, Any]] | str:
|
||||
def execute_sql(
|
||||
self,
|
||||
sql_query: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
) -> list[dict[str, Any]] | str:
|
||||
"""Execute *sql_query* and return the results as a list of dicts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sql_query:
|
||||
The SQL statement to run.
|
||||
params:
|
||||
Optional mapping of bind parameters (e.g. ``{"table_name": "users"}``).
|
||||
"""
|
||||
if not SQLALCHEMY_AVAILABLE:
|
||||
raise ImportError(
|
||||
"sqlalchemy is not installed. Please install it with `pip install crewai-tools[sqlalchemy]`"
|
||||
"sqlalchemy is not installed. Please install it with "
|
||||
"`pip install crewai-tools[sqlalchemy]`"
|
||||
)
|
||||
|
||||
# Check ALL statements so that e.g. "SELECT 1; DROP TABLE t" triggers a
|
||||
# commit when allow_dml=True, regardless of statement order.
|
||||
_stmts = [s.strip() for s in sql_query.split(";") if s.strip()]
|
||||
|
||||
def _is_write_stmt(s: str) -> bool:
|
||||
cmd = self._extract_command(s)
|
||||
if cmd in _WRITE_COMMANDS:
|
||||
return True
|
||||
if cmd == "EXPLAIN":
|
||||
# Resolve the underlying command for EXPLAIN ANALYZE
|
||||
resolved = _resolve_explain_command(s)
|
||||
if resolved and resolved in _WRITE_COMMANDS:
|
||||
return True
|
||||
if cmd == "WITH":
|
||||
if _detect_writable_cte(s):
|
||||
return True
|
||||
main_q = _extract_main_query_after_cte(s)
|
||||
if main_q:
|
||||
return main_q.split()[0].upper().rstrip(";") in _WRITE_COMMANDS
|
||||
return False
|
||||
|
||||
is_write = any(_is_write_stmt(s) for s in _stmts)
|
||||
|
||||
engine = create_engine(self.db_uri)
|
||||
Session = sessionmaker(bind=engine) # noqa: N806
|
||||
session = Session()
|
||||
try:
|
||||
result = session.execute(text(sql_query))
|
||||
session.commit()
|
||||
result = session.execute(text(sql_query), params or {})
|
||||
|
||||
# Only commit when the operation actually mutates state
|
||||
if self.allow_dml and is_write:
|
||||
session.commit()
|
||||
|
||||
if result.returns_rows: # type: ignore[attr-defined]
|
||||
columns = result.keys()
|
||||
|
||||
Reference in New Issue
Block a user