Compare commits

...

3 Commits

Author SHA1 Message Date
Devin AI
a8bf69e05b Fix ruff format for nl2sql_tool.py
Co-Authored-By: João <joao@crewai.com>
2026-06-14 01:54:57 +00:00
github-actions[bot]
29a39cfeef chore: update tool specifications 2026-06-14 01:50:15 +00:00
Devin AI
7575d9b64a Fix #6153: Add input validation hooks for memory/RAG ingestion and approval flag for NL2SQLTool
- Add optional content_filter callable on KnowledgeStorage (save/asave)
- Add optional content_filter callable on CrewAIRagAdapter (add)
- Add optional require_approval flag and approval_handler on NL2SQLTool
- Add comprehensive tests for all three features

Co-Authored-By: João <joao@crewai.com>
2026-06-14 01:48:38 +00:00
7 changed files with 374 additions and 1 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Callable
import hashlib
from typing import TYPE_CHECKING, Any, cast
import uuid
@@ -54,6 +55,7 @@ class CrewAIRagAdapter(Adapter):
similarity_threshold: float = 0.6
limit: int = 5
config: RagConfigType | None = None
content_filter: Callable[[list[str]], list[str]] | None = None
_client: BaseClient | None = PrivateAttr(default=None)
def model_post_init(self, __context: Any) -> None:
@@ -348,6 +350,15 @@ class CrewAIRagAdapter(Adapter):
)
if documents:
if self.content_filter is not None:
filtered_contents = set(
self.content_filter([doc["content"] for doc in documents])
)
documents = [
doc for doc in documents if doc["content"] in filtered_contents
]
if not documents:
return
if self._client is None:
raise ValueError("Client is not initialized")
self._client.add_documents(

View File

@@ -1,4 +1,4 @@
from collections.abc import Iterator
from collections.abc import Callable, Iterator
import logging
import os
import re
@@ -246,6 +246,26 @@ class NL2SQLTool(BaseTool):
"write operations."
),
)
require_approval: bool = Field(
default=False,
title="Require Approval",
description=(
"When True, every query is shown to a human for approval "
"before execution. The approval_handler callable is invoked "
"with the SQL string and must return True to proceed. "
"Defaults to an interactive terminal prompt."
),
)
approval_handler: Callable[[str], bool] | None = Field(
default=None,
exclude=True,
description=(
"Custom callable invoked when require_approval is True. "
"Receives the SQL query string and must return True to "
"allow execution or False to reject it. When None, a "
"built-in interactive terminal prompt is used."
),
)
tables: list[dict[str, Any]] = Field(default_factory=list)
columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict)
args_schema: type[BaseModel] = NL2SQLToolInput
@@ -420,9 +440,31 @@ class NL2SQLTool(BaseTool):
# Core execution
def _request_approval(self, sql_query: str) -> bool:
"""Ask for human approval before executing the query.
Uses ``approval_handler`` if provided, otherwise falls back to an
interactive terminal prompt via ``input()``.
"""
if self.approval_handler is not None:
return self.approval_handler(sql_query)
try:
answer = input(
f"\n[NL2SQLTool] The following query requires approval "
f"before execution:\n\n {sql_query}\n\n"
f"Execute this query? (y/n): "
)
except (EOFError, KeyboardInterrupt):
return False
return answer.strip().lower() in ("y", "yes")
def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
try:
self._validate_query(sql_query)
if self.require_approval and not self._request_approval(sql_query):
return (
f"Query execution was rejected by the human reviewer: {sql_query}"
)
data = self.execute_sql(sql_query)
except ValueError:
raise

View File

@@ -0,0 +1,96 @@
"""Tests for CrewAIRagAdapter.content_filter."""
from unittest.mock import MagicMock, patch
import pytest
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
def _make_adapter(
content_filter=None,
collection_name: str = "test_collection",
) -> CrewAIRagAdapter:
"""Build a CrewAIRagAdapter with a mocked RAG client."""
mock_client = MagicMock()
with patch(
"crewai_tools.adapters.crewai_rag_adapter.get_rag_client",
return_value=mock_client,
):
adapter = CrewAIRagAdapter(
collection_name=collection_name,
content_filter=content_filter,
)
return adapter
class TestContentFilterOnAdd:
def test_filter_removes_documents(self) -> None:
"""Documents whose content is rejected by the filter are not indexed."""
def drop_secrets(contents: list[str]) -> list[str]:
return [c for c in contents if "SECRET" not in c]
adapter = _make_adapter(content_filter=drop_secrets)
mock_client = adapter._client
assert mock_client is not None
adapter.add(
"safe text",
data_type="text",
)
# The add method processes the text into BaseRecord documents.
# With the filter, only safe ones should pass.
if mock_client.add_documents.called:
docs = mock_client.add_documents.call_args.kwargs["documents"]
for doc in docs:
assert "SECRET" not in doc["content"]
def test_filter_drops_all_skips_add(self) -> None:
"""When the filter removes every document, add_documents is not called."""
adapter = _make_adapter(content_filter=lambda contents: [])
mock_client = adapter._client
assert mock_client is not None
adapter.add("anything", data_type="text")
mock_client.add_documents.assert_not_called()
def test_filter_exception_propagates(self) -> None:
"""An exception from content_filter aborts the add."""
def exploding_filter(contents: list[str]) -> list[str]:
raise ValueError("Policy violation")
adapter = _make_adapter(content_filter=exploding_filter)
with pytest.raises(ValueError, match="Policy violation"):
adapter.add("content", data_type="text")
def test_no_filter_is_noop(self) -> None:
"""When content_filter is None, documents are persisted normally."""
adapter = _make_adapter(content_filter=None)
assert adapter.content_filter is None
mock_client = adapter._client
assert mock_client is not None
adapter.add("hello world", data_type="text")
mock_client.add_documents.assert_called_once()
docs = mock_client.add_documents.call_args.kwargs["documents"]
assert len(docs) >= 1
def test_filter_receives_all_content_strings(self) -> None:
"""The filter callable receives the full list of content strings."""
received: list[list[str]] = []
def capturing_filter(contents: list[str]) -> list[str]:
received.append(contents)
return contents
adapter = _make_adapter(content_filter=capturing_filter)
adapter.add("some text content", data_type="text")
assert len(received) == 1
assert all(isinstance(c, str) for c in received[0])

View File

@@ -598,3 +598,85 @@ class TestCTEUnknownCommand:
tool = _make_tool(allow_dml=False)
with pytest.raises(ValueError, match="unrecognised"):
tool._validate_query("WITH cte AS (SELECT 1) FOOBAR")
# --- require_approval tests ---
class TestRequireApproval:
def test_approval_granted_executes_query(self):
"""When the approval handler returns True, the query runs normally."""
tool = _make_tool(
require_approval=True,
approval_handler=lambda sql: True,
)
result = tool._run("SELECT 1 AS val")
assert result == [{"val": 1}]
def test_approval_rejected_blocks_query(self):
"""When the approval handler returns False, execution is blocked."""
tool = _make_tool(
require_approval=True,
approval_handler=lambda sql: False,
)
result = tool._run("SELECT 1 AS val")
assert "rejected" in result.lower()
def test_approval_handler_receives_sql_string(self):
"""The approval_handler receives the exact SQL query string."""
received: list[str] = []
def spy(sql: str) -> bool:
received.append(sql)
return True
tool = _make_tool(require_approval=True, approval_handler=spy)
tool._run("SELECT 42 AS answer")
assert received == ["SELECT 42 AS answer"]
def test_no_approval_when_flag_is_false(self):
"""require_approval=False never invokes the handler."""
handler = MagicMock(return_value=True)
tool = _make_tool(require_approval=False, approval_handler=handler)
tool._run("SELECT 1")
handler.assert_not_called()
def test_default_prompt_on_eof(self):
"""The built-in prompt returns False when input() raises EOFError."""
tool = _make_tool(require_approval=True)
with patch("builtins.input", side_effect=EOFError):
result = tool._run("SELECT 1")
assert "rejected" in result.lower()
def test_default_prompt_yes(self):
"""The built-in prompt allows execution when user types 'y'."""
tool = _make_tool(require_approval=True)
with patch("builtins.input", return_value="y"):
result = tool._run("SELECT 1 AS val")
assert result == [{"val": 1}]
def test_default_prompt_no(self):
"""The built-in prompt blocks execution when user types 'n'."""
tool = _make_tool(require_approval=True)
with patch("builtins.input", return_value="n"):
result = tool._run("SELECT 1")
assert "rejected" in result.lower()
def test_approval_checked_after_validation(self):
"""Validation runs before approval — blocked queries never reach the handler."""
handler = MagicMock(return_value=True)
tool = _make_tool(
allow_dml=False,
require_approval=True,
approval_handler=handler,
)
with pytest.raises(ValueError, match="read-only mode"):
tool._run("DROP TABLE users")
handler.assert_not_called()
def test_approval_with_keyboard_interrupt(self):
"""KeyboardInterrupt during input() rejects the query gracefully."""
tool = _make_tool(require_approval=True)
with patch("builtins.input", side_effect=KeyboardInterrupt):
result = tool._run("SELECT 1")
assert "rejected" in result.lower()

View File

@@ -15870,6 +15870,12 @@
"title": "Database URI",
"type": "string"
},
"require_approval": {
"default": false,
"description": "When True, every query is shown to a human for approval before execution. The approval_handler callable is invoked with the SQL string and must return True to proceed. Defaults to an interactive terminal prompt.",
"title": "Require Approval",
"type": "boolean"
},
"tables": {
"items": {
"additionalProperties": true,

View File

@@ -1,3 +1,4 @@
from collections.abc import Callable
import logging
import traceback
from typing import Any, cast
@@ -32,6 +33,16 @@ class KnowledgeStorage(BaseKnowledgeStorage):
| type[BaseEmbeddingsProvider[Any]]
| None
) = Field(default=None, exclude=True)
content_filter: Callable[[list[str]], list[str]] | None = Field(
default=None,
exclude=True,
description=(
"Optional callable that inspects and filters documents before "
"they are indexed. Receives the full document list and must "
"return the (possibly filtered) list to persist. Raise an "
"exception inside the callable to abort the save entirely."
),
)
_client: BaseClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
@@ -106,6 +117,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
if not documents:
return
if self.content_filter is not None:
documents = self.content_filter(documents)
if not documents:
return
try:
client = self._get_client()
collection_name = (
@@ -187,6 +203,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
if not documents:
return
if self.content_filter is not None:
documents = self.content_filter(documents)
if not documents:
return
try:
client = self._get_client()
collection_name = (

View File

@@ -193,3 +193,118 @@ def test_dimension_mismatch_error_handling(mock_get_client: MagicMock) -> None:
with pytest.raises(ValueError, match="Embedding dimension mismatch"):
storage.save(["test document"])
# --- content_filter tests ---
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_content_filter_removes_documents(mock_get_client: MagicMock) -> None:
"""content_filter can drop specific documents before indexing."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
def reject_secrets(docs: list[str]) -> list[str]:
return [d for d in docs if "SECRET" not in d]
storage = KnowledgeStorage(
collection_name="filter_test", content_filter=reject_secrets
)
storage.save(["safe content", "contains SECRET key", "also safe"])
mock_client.add_documents.assert_called_once()
added = mock_client.add_documents.call_args.kwargs["documents"]
contents = [doc["content"] for doc in added]
assert contents == ["safe content", "also safe"]
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_content_filter_returns_empty_skips_save(mock_get_client: MagicMock) -> None:
"""When content_filter filters out all documents, save is skipped entirely."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
storage = KnowledgeStorage(
collection_name="empty_filter", content_filter=lambda docs: []
)
storage.save(["doc1", "doc2"])
mock_client.add_documents.assert_not_called()
mock_client.get_or_create_collection.assert_not_called()
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_content_filter_exception_propagates(mock_get_client: MagicMock) -> None:
"""Exceptions raised inside content_filter abort the save."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
def strict_filter(docs: list[str]) -> list[str]:
raise ValueError("Blocked by policy")
storage = KnowledgeStorage(
collection_name="strict_test", content_filter=strict_filter
)
with pytest.raises(ValueError, match="Blocked by policy"):
storage.save(["some content"])
mock_client.add_documents.assert_not_called()
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
def test_content_filter_none_is_noop(mock_get_client: MagicMock) -> None:
"""When content_filter is None (default), all documents are saved."""
mock_client = MagicMock()
mock_get_client.return_value = mock_client
storage = KnowledgeStorage(collection_name="noop_test")
assert storage.content_filter is None
storage.save(["doc1", "doc2"])
mock_client.add_documents.assert_called_once()
added = mock_client.add_documents.call_args.kwargs["documents"]
assert len(added) == 2
@pytest.mark.asyncio
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
async def test_content_filter_async_save(mock_get_client: MagicMock) -> None:
"""content_filter is applied in asave() as well."""
from unittest.mock import AsyncMock
mock_client = MagicMock()
mock_client.aget_or_create_collection = AsyncMock()
mock_client.aadd_documents = AsyncMock()
mock_get_client.return_value = mock_client
def only_short(docs: list[str]) -> list[str]:
return [d for d in docs if len(d) < 20]
storage = KnowledgeStorage(
collection_name="async_filter", content_filter=only_short
)
await storage.asave(["short", "this is a much longer document string"])
mock_client.aadd_documents.assert_called_once()
added = mock_client.aadd_documents.call_args.kwargs["documents"]
assert len(added) == 1
assert added[0]["content"] == "short"
@pytest.mark.asyncio
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
async def test_content_filter_async_all_filtered(mock_get_client: MagicMock) -> None:
"""asave() skips persistence when content_filter removes everything."""
from unittest.mock import AsyncMock
mock_client = MagicMock()
mock_client.aget_or_create_collection = AsyncMock()
mock_client.aadd_documents = AsyncMock()
mock_get_client.return_value = mock_client
storage = KnowledgeStorage(
collection_name="async_empty", content_filter=lambda docs: []
)
await storage.asave(["doc1"])
mock_client.aadd_documents.assert_not_called()