mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-06-15 13:18:09 +00:00
Compare commits
3 Commits
flow-itera
...
devin/1781
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a8bf69e05b | ||
|
|
29a39cfeef | ||
|
|
7575d9b64a |
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
import uuid
|
||||
@@ -54,6 +55,7 @@ class CrewAIRagAdapter(Adapter):
|
||||
similarity_threshold: float = 0.6
|
||||
limit: int = 5
|
||||
config: RagConfigType | None = None
|
||||
content_filter: Callable[[list[str]], list[str]] | None = None
|
||||
_client: BaseClient | None = PrivateAttr(default=None)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
@@ -348,6 +350,15 @@ class CrewAIRagAdapter(Adapter):
|
||||
)
|
||||
|
||||
if documents:
|
||||
if self.content_filter is not None:
|
||||
filtered_contents = set(
|
||||
self.content_filter([doc["content"] for doc in documents])
|
||||
)
|
||||
documents = [
|
||||
doc for doc in documents if doc["content"] in filtered_contents
|
||||
]
|
||||
if not documents:
|
||||
return
|
||||
if self._client is None:
|
||||
raise ValueError("Client is not initialized")
|
||||
self._client.add_documents(
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Callable, Iterator
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -246,6 +246,26 @@ class NL2SQLTool(BaseTool):
|
||||
"write operations."
|
||||
),
|
||||
)
|
||||
require_approval: bool = Field(
|
||||
default=False,
|
||||
title="Require Approval",
|
||||
description=(
|
||||
"When True, every query is shown to a human for approval "
|
||||
"before execution. The approval_handler callable is invoked "
|
||||
"with the SQL string and must return True to proceed. "
|
||||
"Defaults to an interactive terminal prompt."
|
||||
),
|
||||
)
|
||||
approval_handler: Callable[[str], bool] | None = Field(
|
||||
default=None,
|
||||
exclude=True,
|
||||
description=(
|
||||
"Custom callable invoked when require_approval is True. "
|
||||
"Receives the SQL query string and must return True to "
|
||||
"allow execution or False to reject it. When None, a "
|
||||
"built-in interactive terminal prompt is used."
|
||||
),
|
||||
)
|
||||
tables: list[dict[str, Any]] = Field(default_factory=list)
|
||||
columns: dict[str, list[dict[str, Any]] | str] = Field(default_factory=dict)
|
||||
args_schema: type[BaseModel] = NL2SQLToolInput
|
||||
@@ -420,9 +440,31 @@ class NL2SQLTool(BaseTool):
|
||||
|
||||
# Core execution
|
||||
|
||||
def _request_approval(self, sql_query: str) -> bool:
|
||||
"""Ask for human approval before executing the query.
|
||||
|
||||
Uses ``approval_handler`` if provided, otherwise falls back to an
|
||||
interactive terminal prompt via ``input()``.
|
||||
"""
|
||||
if self.approval_handler is not None:
|
||||
return self.approval_handler(sql_query)
|
||||
try:
|
||||
answer = input(
|
||||
f"\n[NL2SQLTool] The following query requires approval "
|
||||
f"before execution:\n\n {sql_query}\n\n"
|
||||
f"Execute this query? (y/n): "
|
||||
)
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return False
|
||||
return answer.strip().lower() in ("y", "yes")
|
||||
|
||||
def _run(self, sql_query: str) -> list[dict[str, Any]] | str:
|
||||
try:
|
||||
self._validate_query(sql_query)
|
||||
if self.require_approval and not self._request_approval(sql_query):
|
||||
return (
|
||||
f"Query execution was rejected by the human reviewer: {sql_query}"
|
||||
)
|
||||
data = self.execute_sql(sql_query)
|
||||
except ValueError:
|
||||
raise
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Tests for CrewAIRagAdapter.content_filter."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter
|
||||
|
||||
|
||||
def _make_adapter(
|
||||
content_filter=None,
|
||||
collection_name: str = "test_collection",
|
||||
) -> CrewAIRagAdapter:
|
||||
"""Build a CrewAIRagAdapter with a mocked RAG client."""
|
||||
mock_client = MagicMock()
|
||||
with patch(
|
||||
"crewai_tools.adapters.crewai_rag_adapter.get_rag_client",
|
||||
return_value=mock_client,
|
||||
):
|
||||
adapter = CrewAIRagAdapter(
|
||||
collection_name=collection_name,
|
||||
content_filter=content_filter,
|
||||
)
|
||||
return adapter
|
||||
|
||||
|
||||
class TestContentFilterOnAdd:
|
||||
def test_filter_removes_documents(self) -> None:
|
||||
"""Documents whose content is rejected by the filter are not indexed."""
|
||||
|
||||
def drop_secrets(contents: list[str]) -> list[str]:
|
||||
return [c for c in contents if "SECRET" not in c]
|
||||
|
||||
adapter = _make_adapter(content_filter=drop_secrets)
|
||||
mock_client = adapter._client
|
||||
assert mock_client is not None
|
||||
|
||||
adapter.add(
|
||||
"safe text",
|
||||
data_type="text",
|
||||
)
|
||||
# The add method processes the text into BaseRecord documents.
|
||||
# With the filter, only safe ones should pass.
|
||||
if mock_client.add_documents.called:
|
||||
docs = mock_client.add_documents.call_args.kwargs["documents"]
|
||||
for doc in docs:
|
||||
assert "SECRET" not in doc["content"]
|
||||
|
||||
def test_filter_drops_all_skips_add(self) -> None:
|
||||
"""When the filter removes every document, add_documents is not called."""
|
||||
adapter = _make_adapter(content_filter=lambda contents: [])
|
||||
mock_client = adapter._client
|
||||
assert mock_client is not None
|
||||
|
||||
adapter.add("anything", data_type="text")
|
||||
|
||||
mock_client.add_documents.assert_not_called()
|
||||
|
||||
def test_filter_exception_propagates(self) -> None:
|
||||
"""An exception from content_filter aborts the add."""
|
||||
|
||||
def exploding_filter(contents: list[str]) -> list[str]:
|
||||
raise ValueError("Policy violation")
|
||||
|
||||
adapter = _make_adapter(content_filter=exploding_filter)
|
||||
|
||||
with pytest.raises(ValueError, match="Policy violation"):
|
||||
adapter.add("content", data_type="text")
|
||||
|
||||
def test_no_filter_is_noop(self) -> None:
|
||||
"""When content_filter is None, documents are persisted normally."""
|
||||
adapter = _make_adapter(content_filter=None)
|
||||
assert adapter.content_filter is None
|
||||
mock_client = adapter._client
|
||||
assert mock_client is not None
|
||||
|
||||
adapter.add("hello world", data_type="text")
|
||||
|
||||
mock_client.add_documents.assert_called_once()
|
||||
docs = mock_client.add_documents.call_args.kwargs["documents"]
|
||||
assert len(docs) >= 1
|
||||
|
||||
def test_filter_receives_all_content_strings(self) -> None:
|
||||
"""The filter callable receives the full list of content strings."""
|
||||
received: list[list[str]] = []
|
||||
|
||||
def capturing_filter(contents: list[str]) -> list[str]:
|
||||
received.append(contents)
|
||||
return contents
|
||||
|
||||
adapter = _make_adapter(content_filter=capturing_filter)
|
||||
|
||||
adapter.add("some text content", data_type="text")
|
||||
|
||||
assert len(received) == 1
|
||||
assert all(isinstance(c, str) for c in received[0])
|
||||
@@ -598,3 +598,85 @@ class TestCTEUnknownCommand:
|
||||
tool = _make_tool(allow_dml=False)
|
||||
with pytest.raises(ValueError, match="unrecognised"):
|
||||
tool._validate_query("WITH cte AS (SELECT 1) FOOBAR")
|
||||
|
||||
|
||||
# --- require_approval tests ---
|
||||
|
||||
|
||||
class TestRequireApproval:
|
||||
def test_approval_granted_executes_query(self):
|
||||
"""When the approval handler returns True, the query runs normally."""
|
||||
tool = _make_tool(
|
||||
require_approval=True,
|
||||
approval_handler=lambda sql: True,
|
||||
)
|
||||
result = tool._run("SELECT 1 AS val")
|
||||
assert result == [{"val": 1}]
|
||||
|
||||
def test_approval_rejected_blocks_query(self):
|
||||
"""When the approval handler returns False, execution is blocked."""
|
||||
tool = _make_tool(
|
||||
require_approval=True,
|
||||
approval_handler=lambda sql: False,
|
||||
)
|
||||
result = tool._run("SELECT 1 AS val")
|
||||
assert "rejected" in result.lower()
|
||||
|
||||
def test_approval_handler_receives_sql_string(self):
|
||||
"""The approval_handler receives the exact SQL query string."""
|
||||
received: list[str] = []
|
||||
|
||||
def spy(sql: str) -> bool:
|
||||
received.append(sql)
|
||||
return True
|
||||
|
||||
tool = _make_tool(require_approval=True, approval_handler=spy)
|
||||
tool._run("SELECT 42 AS answer")
|
||||
assert received == ["SELECT 42 AS answer"]
|
||||
|
||||
def test_no_approval_when_flag_is_false(self):
|
||||
"""require_approval=False never invokes the handler."""
|
||||
handler = MagicMock(return_value=True)
|
||||
tool = _make_tool(require_approval=False, approval_handler=handler)
|
||||
tool._run("SELECT 1")
|
||||
handler.assert_not_called()
|
||||
|
||||
def test_default_prompt_on_eof(self):
|
||||
"""The built-in prompt returns False when input() raises EOFError."""
|
||||
tool = _make_tool(require_approval=True)
|
||||
with patch("builtins.input", side_effect=EOFError):
|
||||
result = tool._run("SELECT 1")
|
||||
assert "rejected" in result.lower()
|
||||
|
||||
def test_default_prompt_yes(self):
|
||||
"""The built-in prompt allows execution when user types 'y'."""
|
||||
tool = _make_tool(require_approval=True)
|
||||
with patch("builtins.input", return_value="y"):
|
||||
result = tool._run("SELECT 1 AS val")
|
||||
assert result == [{"val": 1}]
|
||||
|
||||
def test_default_prompt_no(self):
|
||||
"""The built-in prompt blocks execution when user types 'n'."""
|
||||
tool = _make_tool(require_approval=True)
|
||||
with patch("builtins.input", return_value="n"):
|
||||
result = tool._run("SELECT 1")
|
||||
assert "rejected" in result.lower()
|
||||
|
||||
def test_approval_checked_after_validation(self):
|
||||
"""Validation runs before approval — blocked queries never reach the handler."""
|
||||
handler = MagicMock(return_value=True)
|
||||
tool = _make_tool(
|
||||
allow_dml=False,
|
||||
require_approval=True,
|
||||
approval_handler=handler,
|
||||
)
|
||||
with pytest.raises(ValueError, match="read-only mode"):
|
||||
tool._run("DROP TABLE users")
|
||||
handler.assert_not_called()
|
||||
|
||||
def test_approval_with_keyboard_interrupt(self):
|
||||
"""KeyboardInterrupt during input() rejects the query gracefully."""
|
||||
tool = _make_tool(require_approval=True)
|
||||
with patch("builtins.input", side_effect=KeyboardInterrupt):
|
||||
result = tool._run("SELECT 1")
|
||||
assert "rejected" in result.lower()
|
||||
|
||||
@@ -15870,6 +15870,12 @@
|
||||
"title": "Database URI",
|
||||
"type": "string"
|
||||
},
|
||||
"require_approval": {
|
||||
"default": false,
|
||||
"description": "When True, every query is shown to a human for approval before execution. The approval_handler callable is invoked with the SQL string and must return True to proceed. Defaults to an interactive terminal prompt.",
|
||||
"title": "Require Approval",
|
||||
"type": "boolean"
|
||||
},
|
||||
"tables": {
|
||||
"items": {
|
||||
"additionalProperties": true,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
import traceback
|
||||
from typing import Any, cast
|
||||
@@ -32,6 +33,16 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
| type[BaseEmbeddingsProvider[Any]]
|
||||
| None
|
||||
) = Field(default=None, exclude=True)
|
||||
content_filter: Callable[[list[str]], list[str]] | None = Field(
|
||||
default=None,
|
||||
exclude=True,
|
||||
description=(
|
||||
"Optional callable that inspects and filters documents before "
|
||||
"they are indexed. Receives the full document list and must "
|
||||
"return the (possibly filtered) list to persist. Raise an "
|
||||
"exception inside the callable to abort the save entirely."
|
||||
),
|
||||
)
|
||||
_client: BaseClient | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
@@ -106,6 +117,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
if not documents:
|
||||
return
|
||||
|
||||
if self.content_filter is not None:
|
||||
documents = self.content_filter(documents)
|
||||
if not documents:
|
||||
return
|
||||
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
@@ -187,6 +203,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
if not documents:
|
||||
return
|
||||
|
||||
if self.content_filter is not None:
|
||||
documents = self.content_filter(documents)
|
||||
if not documents:
|
||||
return
|
||||
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
|
||||
@@ -193,3 +193,118 @@ def test_dimension_mismatch_error_handling(mock_get_client: MagicMock) -> None:
|
||||
|
||||
with pytest.raises(ValueError, match="Embedding dimension mismatch"):
|
||||
storage.save(["test document"])
|
||||
|
||||
|
||||
# --- content_filter tests ---
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_content_filter_removes_documents(mock_get_client: MagicMock) -> None:
|
||||
"""content_filter can drop specific documents before indexing."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
def reject_secrets(docs: list[str]) -> list[str]:
|
||||
return [d for d in docs if "SECRET" not in d]
|
||||
|
||||
storage = KnowledgeStorage(
|
||||
collection_name="filter_test", content_filter=reject_secrets
|
||||
)
|
||||
storage.save(["safe content", "contains SECRET key", "also safe"])
|
||||
|
||||
mock_client.add_documents.assert_called_once()
|
||||
added = mock_client.add_documents.call_args.kwargs["documents"]
|
||||
contents = [doc["content"] for doc in added]
|
||||
assert contents == ["safe content", "also safe"]
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_content_filter_returns_empty_skips_save(mock_get_client: MagicMock) -> None:
|
||||
"""When content_filter filters out all documents, save is skipped entirely."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
storage = KnowledgeStorage(
|
||||
collection_name="empty_filter", content_filter=lambda docs: []
|
||||
)
|
||||
storage.save(["doc1", "doc2"])
|
||||
|
||||
mock_client.add_documents.assert_not_called()
|
||||
mock_client.get_or_create_collection.assert_not_called()
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_content_filter_exception_propagates(mock_get_client: MagicMock) -> None:
|
||||
"""Exceptions raised inside content_filter abort the save."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
def strict_filter(docs: list[str]) -> list[str]:
|
||||
raise ValueError("Blocked by policy")
|
||||
|
||||
storage = KnowledgeStorage(
|
||||
collection_name="strict_test", content_filter=strict_filter
|
||||
)
|
||||
with pytest.raises(ValueError, match="Blocked by policy"):
|
||||
storage.save(["some content"])
|
||||
|
||||
mock_client.add_documents.assert_not_called()
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
def test_content_filter_none_is_noop(mock_get_client: MagicMock) -> None:
|
||||
"""When content_filter is None (default), all documents are saved."""
|
||||
mock_client = MagicMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
storage = KnowledgeStorage(collection_name="noop_test")
|
||||
assert storage.content_filter is None
|
||||
storage.save(["doc1", "doc2"])
|
||||
|
||||
mock_client.add_documents.assert_called_once()
|
||||
added = mock_client.add_documents.call_args.kwargs["documents"]
|
||||
assert len(added) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
async def test_content_filter_async_save(mock_get_client: MagicMock) -> None:
|
||||
"""content_filter is applied in asave() as well."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.aget_or_create_collection = AsyncMock()
|
||||
mock_client.aadd_documents = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
def only_short(docs: list[str]) -> list[str]:
|
||||
return [d for d in docs if len(d) < 20]
|
||||
|
||||
storage = KnowledgeStorage(
|
||||
collection_name="async_filter", content_filter=only_short
|
||||
)
|
||||
await storage.asave(["short", "this is a much longer document string"])
|
||||
|
||||
mock_client.aadd_documents.assert_called_once()
|
||||
added = mock_client.aadd_documents.call_args.kwargs["documents"]
|
||||
assert len(added) == 1
|
||||
assert added[0]["content"] == "short"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
async def test_content_filter_async_all_filtered(mock_get_client: MagicMock) -> None:
|
||||
"""asave() skips persistence when content_filter removes everything."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.aget_or_create_collection = AsyncMock()
|
||||
mock_client.aadd_documents = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
storage = KnowledgeStorage(
|
||||
collection_name="async_empty", content_filter=lambda docs: []
|
||||
)
|
||||
await storage.asave(["doc1"])
|
||||
|
||||
mock_client.aadd_documents.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user