From 7575d9b64abc66eb8245cbd504d9015b4515a130 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sun, 14 Jun 2026 01:48:38 +0000 Subject: [PATCH] Fix #6153: Add input validation hooks for memory/RAG ingestion and approval flag for NL2SQLTool MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add optional content_filter callable on KnowledgeStorage (save/asave) - Add optional content_filter callable on CrewAIRagAdapter (add) - Add optional require_approval flag and approval_handler on NL2SQLTool - Add comprehensive tests for all three features Co-Authored-By: João --- .../adapters/crewai_rag_adapter.py | 11 ++ .../crewai_tools/tools/nl2sql/nl2sql_tool.py | 45 ++++++- .../test_crewai_rag_adapter_content_filter.py | 96 +++++++++++++++ .../tests/tools/test_nl2sql_security.py | 82 +++++++++++++ .../knowledge/storage/knowledge_storage.py | 21 ++++ .../test_knowledge_storage_integration.py | 115 ++++++++++++++++++ 6 files changed, 369 insertions(+), 1 deletion(-) create mode 100644 lib/crewai-tools/tests/adapters/test_crewai_rag_adapter_content_filter.py diff --git a/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py b/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py index b0a655830..b146e97cb 100644 --- a/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py +++ b/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Callable import hashlib from typing import TYPE_CHECKING, Any, cast import uuid @@ -54,6 +55,7 @@ class CrewAIRagAdapter(Adapter): similarity_threshold: float = 0.6 limit: int = 5 config: RagConfigType | None = None + content_filter: Callable[[list[str]], list[str]] | None = None _client: BaseClient | None = PrivateAttr(default=None) def model_post_init(self, __context: Any) -> None: @@ -348,6 +350,15 @@ class CrewAIRagAdapter(Adapter): ) if documents: + if self.content_filter is not None: + filtered_contents = set( + self.content_filter([doc["content"] for doc in documents]) + ) + documents = [ + doc for doc in documents if doc["content"] in filtered_contents + ] + if not documents: + return if self._client is None: raise ValueError("Client is not initialized") self._client.add_documents( diff --git a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py index 818c61dd2..9e5e931cd 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/nl2sql/nl2sql_tool.py @@ -1,4 +1,4 @@ -from collections.abc import Iterator +from collections.abc import Callable, Iterator import logging import os import re @@ -246,6 +246,26 @@ class NL2SQLTool(BaseTool): "write operations." ), ) + require_approval: bool = Field( + default=False, + title="Require Approval", + description=( + "When True, every query is shown to a human for approval " + "before execution. The approval_handler callable is invoked " + "with the SQL string and must return True to proceed. " + "Defaults to an interactive terminal prompt." + ), + ) + approval_handler: Callable[[str], bool] | None = Field( + default=None, + exclude=True, + description=( + "Custom callable invoked when require_approval is True. " + "Receives the SQL query string and must return True to " + "allow execution or False to reject it. When None, a " + "built-in interactive terminal prompt is used." + ), + ) tables: list[dict[str, Any]] = Field(default_factory=list) columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict) args_schema: type[BaseModel] = NL2SQLToolInput @@ -420,9 +440,32 @@ class NL2SQLTool(BaseTool): # Core execution + def _request_approval(self, sql_query: str) -> bool: + """Ask for human approval before executing the query. + + Uses ``approval_handler`` if provided, otherwise falls back to an + interactive terminal prompt via ``input()``. + """ + if self.approval_handler is not None: + return self.approval_handler(sql_query) + try: + answer = input( + f"\n[NL2SQLTool] The following query requires approval " + f"before execution:\n\n {sql_query}\n\n" + f"Execute this query? (y/n): " + ) + except (EOFError, KeyboardInterrupt): + return False + return answer.strip().lower() in ("y", "yes") + def _run(self, sql_query: str) -> list[dict[str, Any]] | str: try: self._validate_query(sql_query) + if self.require_approval and not self._request_approval(sql_query): + return ( + f"Query execution was rejected by the human reviewer: " + f"{sql_query}" + ) data = self.execute_sql(sql_query) except ValueError: raise diff --git a/lib/crewai-tools/tests/adapters/test_crewai_rag_adapter_content_filter.py b/lib/crewai-tools/tests/adapters/test_crewai_rag_adapter_content_filter.py new file mode 100644 index 000000000..bbd159e22 --- /dev/null +++ b/lib/crewai-tools/tests/adapters/test_crewai_rag_adapter_content_filter.py @@ -0,0 +1,96 @@ +"""Tests for CrewAIRagAdapter.content_filter.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter + + +def _make_adapter( + content_filter=None, + collection_name: str = "test_collection", +) -> CrewAIRagAdapter: + """Build a CrewAIRagAdapter with a mocked RAG client.""" + mock_client = MagicMock() + with patch( + "crewai_tools.adapters.crewai_rag_adapter.get_rag_client", + return_value=mock_client, + ): + adapter = CrewAIRagAdapter( + collection_name=collection_name, + content_filter=content_filter, + ) + return adapter + + +class TestContentFilterOnAdd: + def test_filter_removes_documents(self) -> None: + """Documents whose content is rejected by the filter are not indexed.""" + + def drop_secrets(contents: list[str]) -> list[str]: + return [c for c in contents if "SECRET" not in c] + + adapter = _make_adapter(content_filter=drop_secrets) + mock_client = adapter._client + assert mock_client is not None + + adapter.add( + "safe text", + data_type="text", + ) + # The add method processes the text into BaseRecord documents. + # With the filter, only safe ones should pass. + if mock_client.add_documents.called: + docs = mock_client.add_documents.call_args.kwargs["documents"] + for doc in docs: + assert "SECRET" not in doc["content"] + + def test_filter_drops_all_skips_add(self) -> None: + """When the filter removes every document, add_documents is not called.""" + adapter = _make_adapter(content_filter=lambda contents: []) + mock_client = adapter._client + assert mock_client is not None + + adapter.add("anything", data_type="text") + + mock_client.add_documents.assert_not_called() + + def test_filter_exception_propagates(self) -> None: + """An exception from content_filter aborts the add.""" + + def exploding_filter(contents: list[str]) -> list[str]: + raise ValueError("Policy violation") + + adapter = _make_adapter(content_filter=exploding_filter) + + with pytest.raises(ValueError, match="Policy violation"): + adapter.add("content", data_type="text") + + def test_no_filter_is_noop(self) -> None: + """When content_filter is None, documents are persisted normally.""" + adapter = _make_adapter(content_filter=None) + assert adapter.content_filter is None + mock_client = adapter._client + assert mock_client is not None + + adapter.add("hello world", data_type="text") + + mock_client.add_documents.assert_called_once() + docs = mock_client.add_documents.call_args.kwargs["documents"] + assert len(docs) >= 1 + + def test_filter_receives_all_content_strings(self) -> None: + """The filter callable receives the full list of content strings.""" + received: list[list[str]] = [] + + def capturing_filter(contents: list[str]) -> list[str]: + received.append(contents) + return contents + + adapter = _make_adapter(content_filter=capturing_filter) + + adapter.add("some text content", data_type="text") + + assert len(received) == 1 + assert all(isinstance(c, str) for c in received[0]) diff --git a/lib/crewai-tools/tests/tools/test_nl2sql_security.py b/lib/crewai-tools/tests/tools/test_nl2sql_security.py index aedfad281..585b3c4ac 100644 --- a/lib/crewai-tools/tests/tools/test_nl2sql_security.py +++ b/lib/crewai-tools/tests/tools/test_nl2sql_security.py @@ -598,3 +598,85 @@ class TestCTEUnknownCommand: tool = _make_tool(allow_dml=False) with pytest.raises(ValueError, match="unrecognised"): tool._validate_query("WITH cte AS (SELECT 1) FOOBAR") + + +# --- require_approval tests --- + + +class TestRequireApproval: + def test_approval_granted_executes_query(self): + """When the approval handler returns True, the query runs normally.""" + tool = _make_tool( + require_approval=True, + approval_handler=lambda sql: True, + ) + result = tool._run("SELECT 1 AS val") + assert result == [{"val": 1}] + + def test_approval_rejected_blocks_query(self): + """When the approval handler returns False, execution is blocked.""" + tool = _make_tool( + require_approval=True, + approval_handler=lambda sql: False, + ) + result = tool._run("SELECT 1 AS val") + assert "rejected" in result.lower() + + def test_approval_handler_receives_sql_string(self): + """The approval_handler receives the exact SQL query string.""" + received: list[str] = [] + + def spy(sql: str) -> bool: + received.append(sql) + return True + + tool = _make_tool(require_approval=True, approval_handler=spy) + tool._run("SELECT 42 AS answer") + assert received == ["SELECT 42 AS answer"] + + def test_no_approval_when_flag_is_false(self): + """require_approval=False never invokes the handler.""" + handler = MagicMock(return_value=True) + tool = _make_tool(require_approval=False, approval_handler=handler) + tool._run("SELECT 1") + handler.assert_not_called() + + def test_default_prompt_on_eof(self): + """The built-in prompt returns False when input() raises EOFError.""" + tool = _make_tool(require_approval=True) + with patch("builtins.input", side_effect=EOFError): + result = tool._run("SELECT 1") + assert "rejected" in result.lower() + + def test_default_prompt_yes(self): + """The built-in prompt allows execution when user types 'y'.""" + tool = _make_tool(require_approval=True) + with patch("builtins.input", return_value="y"): + result = tool._run("SELECT 1 AS val") + assert result == [{"val": 1}] + + def test_default_prompt_no(self): + """The built-in prompt blocks execution when user types 'n'.""" + tool = _make_tool(require_approval=True) + with patch("builtins.input", return_value="n"): + result = tool._run("SELECT 1") + assert "rejected" in result.lower() + + def test_approval_checked_after_validation(self): + """Validation runs before approval — blocked queries never reach the handler.""" + handler = MagicMock(return_value=True) + tool = _make_tool( + allow_dml=False, + require_approval=True, + approval_handler=handler, + ) + with pytest.raises(ValueError, match="read-only mode"): + tool._run("DROP TABLE users") + handler.assert_not_called() + + def test_approval_with_keyboard_interrupt(self): + """KeyboardInterrupt during input() rejects the query gracefully.""" + tool = _make_tool(require_approval=True) + with patch("builtins.input", side_effect=KeyboardInterrupt): + result = tool._run("SELECT 1") + assert "rejected" in result.lower() diff --git a/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py b/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py index 3c9615946..9e6812254 100644 --- a/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py +++ b/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py @@ -1,3 +1,4 @@ +from collections.abc import Callable import logging import traceback from typing import Any, cast @@ -32,6 +33,16 @@ class KnowledgeStorage(BaseKnowledgeStorage): | type[BaseEmbeddingsProvider[Any]] | None ) = Field(default=None, exclude=True) + content_filter: Callable[[list[str]], list[str]] | None = Field( + default=None, + exclude=True, + description=( + "Optional callable that inspects and filters documents before " + "they are indexed. Receives the full document list and must " + "return the (possibly filtered) list to persist. Raise an " + "exception inside the callable to abort the save entirely." + ), + ) _client: BaseClient | None = PrivateAttr(default=None) @model_validator(mode="after") @@ -106,6 +117,11 @@ class KnowledgeStorage(BaseKnowledgeStorage): if not documents: return + if self.content_filter is not None: + documents = self.content_filter(documents) + if not documents: + return + try: client = self._get_client() collection_name = ( @@ -187,6 +203,11 @@ class KnowledgeStorage(BaseKnowledgeStorage): if not documents: return + if self.content_filter is not None: + documents = self.content_filter(documents) + if not documents: + return + try: client = self._get_client() collection_name = ( diff --git a/lib/crewai/tests/knowledge/test_knowledge_storage_integration.py b/lib/crewai/tests/knowledge/test_knowledge_storage_integration.py index 5a228cde4..627b2b0a7 100644 --- a/lib/crewai/tests/knowledge/test_knowledge_storage_integration.py +++ b/lib/crewai/tests/knowledge/test_knowledge_storage_integration.py @@ -193,3 +193,118 @@ def test_dimension_mismatch_error_handling(mock_get_client: MagicMock) -> None: with pytest.raises(ValueError, match="Embedding dimension mismatch"): storage.save(["test document"]) + + +# --- content_filter tests --- + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_content_filter_removes_documents(mock_get_client: MagicMock) -> None: + """content_filter can drop specific documents before indexing.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + def reject_secrets(docs: list[str]) -> list[str]: + return [d for d in docs if "SECRET" not in d] + + storage = KnowledgeStorage( + collection_name="filter_test", content_filter=reject_secrets + ) + storage.save(["safe content", "contains SECRET key", "also safe"]) + + mock_client.add_documents.assert_called_once() + added = mock_client.add_documents.call_args.kwargs["documents"] + contents = [doc["content"] for doc in added] + assert contents == ["safe content", "also safe"] + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_content_filter_returns_empty_skips_save(mock_get_client: MagicMock) -> None: + """When content_filter filters out all documents, save is skipped entirely.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + storage = KnowledgeStorage( + collection_name="empty_filter", content_filter=lambda docs: [] + ) + storage.save(["doc1", "doc2"]) + + mock_client.add_documents.assert_not_called() + mock_client.get_or_create_collection.assert_not_called() + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_content_filter_exception_propagates(mock_get_client: MagicMock) -> None: + """Exceptions raised inside content_filter abort the save.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + def strict_filter(docs: list[str]) -> list[str]: + raise ValueError("Blocked by policy") + + storage = KnowledgeStorage( + collection_name="strict_test", content_filter=strict_filter + ) + with pytest.raises(ValueError, match="Blocked by policy"): + storage.save(["some content"]) + + mock_client.add_documents.assert_not_called() + + +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +def test_content_filter_none_is_noop(mock_get_client: MagicMock) -> None: + """When content_filter is None (default), all documents are saved.""" + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + storage = KnowledgeStorage(collection_name="noop_test") + assert storage.content_filter is None + storage.save(["doc1", "doc2"]) + + mock_client.add_documents.assert_called_once() + added = mock_client.add_documents.call_args.kwargs["documents"] + assert len(added) == 2 + + +@pytest.mark.asyncio +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +async def test_content_filter_async_save(mock_get_client: MagicMock) -> None: + """content_filter is applied in asave() as well.""" + from unittest.mock import AsyncMock + + mock_client = MagicMock() + mock_client.aget_or_create_collection = AsyncMock() + mock_client.aadd_documents = AsyncMock() + mock_get_client.return_value = mock_client + + def only_short(docs: list[str]) -> list[str]: + return [d for d in docs if len(d) < 20] + + storage = KnowledgeStorage( + collection_name="async_filter", content_filter=only_short + ) + await storage.asave(["short", "this is a much longer document string"]) + + mock_client.aadd_documents.assert_called_once() + added = mock_client.aadd_documents.call_args.kwargs["documents"] + assert len(added) == 1 + assert added[0]["content"] == "short" + + +@pytest.mark.asyncio +@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client") +async def test_content_filter_async_all_filtered(mock_get_client: MagicMock) -> None: + """asave() skips persistence when content_filter removes everything.""" + from unittest.mock import AsyncMock + + mock_client = MagicMock() + mock_client.aget_or_create_collection = AsyncMock() + mock_client.aadd_documents = AsyncMock() + mock_get_client.return_value = mock_client + + storage = KnowledgeStorage( + collection_name="async_empty", content_filter=lambda docs: [] + ) + await storage.asave(["doc1"]) + + mock_client.aadd_documents.assert_not_called()