mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-12 20:48:15 +00:00
feat: add GuardrailProvider interface for pre-tool-call authorization (#4877)
- Add GuardrailRequest dataclass for tool call context - Add GuardrailDecision dataclass for allow/deny verdicts - Add GuardrailProvider runtime-checkable protocol - Add enable_guardrail() adapter wiring providers into BeforeToolCallHook - Add disable() callable returned by enable_guardrail for cleanup - Support fail_closed (default) and fail_open exception handling - Export new types from crewai.hooks - Add 29 comprehensive tests covering all scenarios Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -6,6 +6,12 @@ from crewai.hooks.decorators import (
|
|||||||
before_llm_call,
|
before_llm_call,
|
||||||
before_tool_call,
|
before_tool_call,
|
||||||
)
|
)
|
||||||
|
from crewai.hooks.guardrail_provider import (
|
||||||
|
GuardrailDecision,
|
||||||
|
GuardrailProvider,
|
||||||
|
GuardrailRequest,
|
||||||
|
enable_guardrail,
|
||||||
|
)
|
||||||
from crewai.hooks.llm_hooks import (
|
from crewai.hooks.llm_hooks import (
|
||||||
LLMCallHookContext,
|
LLMCallHookContext,
|
||||||
clear_after_llm_call_hooks,
|
clear_after_llm_call_hooks,
|
||||||
@@ -74,10 +80,11 @@ def clear_all_global_hooks() -> dict[str, tuple[int, int]]:
|
|||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Context classes
|
"GuardrailDecision",
|
||||||
|
"GuardrailProvider",
|
||||||
|
"GuardrailRequest",
|
||||||
"LLMCallHookContext",
|
"LLMCallHookContext",
|
||||||
"ToolCallHookContext",
|
"ToolCallHookContext",
|
||||||
# Decorators
|
|
||||||
"after_llm_call",
|
"after_llm_call",
|
||||||
"after_tool_call",
|
"after_tool_call",
|
||||||
"before_llm_call",
|
"before_llm_call",
|
||||||
@@ -87,19 +94,16 @@ __all__ = [
|
|||||||
"clear_all_global_hooks",
|
"clear_all_global_hooks",
|
||||||
"clear_all_llm_call_hooks",
|
"clear_all_llm_call_hooks",
|
||||||
"clear_all_tool_call_hooks",
|
"clear_all_tool_call_hooks",
|
||||||
# Clear hooks
|
|
||||||
"clear_before_llm_call_hooks",
|
"clear_before_llm_call_hooks",
|
||||||
"clear_before_tool_call_hooks",
|
"clear_before_tool_call_hooks",
|
||||||
|
"enable_guardrail",
|
||||||
"get_after_llm_call_hooks",
|
"get_after_llm_call_hooks",
|
||||||
"get_after_tool_call_hooks",
|
"get_after_tool_call_hooks",
|
||||||
# Get hooks
|
|
||||||
"get_before_llm_call_hooks",
|
"get_before_llm_call_hooks",
|
||||||
"get_before_tool_call_hooks",
|
"get_before_tool_call_hooks",
|
||||||
"register_after_llm_call_hook",
|
"register_after_llm_call_hook",
|
||||||
"register_after_tool_call_hook",
|
"register_after_tool_call_hook",
|
||||||
# LLM Hook registration
|
|
||||||
"register_before_llm_call_hook",
|
"register_before_llm_call_hook",
|
||||||
# Tool Hook registration
|
|
||||||
"register_before_tool_call_hook",
|
"register_before_tool_call_hook",
|
||||||
"unregister_after_llm_call_hook",
|
"unregister_after_llm_call_hook",
|
||||||
"unregister_after_tool_call_hook",
|
"unregister_after_tool_call_hook",
|
||||||
|
|||||||
295
lib/crewai/src/crewai/hooks/guardrail_provider.py
Normal file
295
lib/crewai/src/crewai/hooks/guardrail_provider.py
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
"""GuardrailProvider interface for pre-tool-call authorization.
|
||||||
|
|
||||||
|
This module provides a standard protocol for pluggable tool-call authorization
|
||||||
|
that sits on top of CrewAI's existing BeforeToolCallHook system.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
Simple provider that blocks specific tools::
|
||||||
|
|
||||||
|
from crewai.hooks import (
|
||||||
|
GuardrailProvider,
|
||||||
|
GuardrailRequest,
|
||||||
|
GuardrailDecision,
|
||||||
|
enable_guardrail,
|
||||||
|
)
|
||||||
|
|
||||||
|
class BlockListProvider:
|
||||||
|
name = "block_list"
|
||||||
|
|
||||||
|
def __init__(self, blocked_tools: list[str]) -> None:
|
||||||
|
self.blocked_tools = blocked_tools
|
||||||
|
|
||||||
|
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||||
|
if request.tool_name in self.blocked_tools:
|
||||||
|
return GuardrailDecision(
|
||||||
|
allow=False,
|
||||||
|
reason=f"Tool '{request.tool_name}' is blocked by policy",
|
||||||
|
)
|
||||||
|
return GuardrailDecision(allow=True)
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
provider = BlockListProvider(blocked_tools=["ShellTool", "dangerous_op"])
|
||||||
|
disable = enable_guardrail(provider)
|
||||||
|
|
||||||
|
# Later, to remove the guardrail:
|
||||||
|
disable()
|
||||||
|
|
||||||
|
Rate-limiting provider::
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
class RateLimitProvider:
|
||||||
|
name = "rate_limiter"
|
||||||
|
|
||||||
|
def __init__(self, max_calls: int, window_seconds: float = 60.0) -> None:
|
||||||
|
self.max_calls = max_calls
|
||||||
|
self.window_seconds = window_seconds
|
||||||
|
self._calls: dict[str, list[float]] = defaultdict(list)
|
||||||
|
|
||||||
|
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||||
|
now = time.time()
|
||||||
|
key = request.tool_name
|
||||||
|
# Remove expired entries
|
||||||
|
self._calls[key] = [
|
||||||
|
t for t in self._calls[key]
|
||||||
|
if now - t < self.window_seconds
|
||||||
|
]
|
||||||
|
if len(self._calls[key]) >= self.max_calls:
|
||||||
|
return GuardrailDecision(
|
||||||
|
allow=False,
|
||||||
|
reason=f"Rate limit exceeded for '{key}'",
|
||||||
|
)
|
||||||
|
self._calls[key].append(now)
|
||||||
|
return GuardrailDecision(allow=True)
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
Per-agent role restriction::
|
||||||
|
|
||||||
|
class RoleBasedProvider:
|
||||||
|
name = "role_based"
|
||||||
|
|
||||||
|
def __init__(self, permissions: dict[str, list[str]]) -> None:
|
||||||
|
# Maps agent role -> list of allowed tool names
|
||||||
|
self.permissions = permissions
|
||||||
|
|
||||||
|
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||||
|
role = request.agent_role
|
||||||
|
if role is None:
|
||||||
|
return GuardrailDecision(allow=True)
|
||||||
|
allowed = self.permissions.get(role)
|
||||||
|
if allowed is not None and request.tool_name not in allowed:
|
||||||
|
return GuardrailDecision(
|
||||||
|
allow=False,
|
||||||
|
reason=(
|
||||||
|
f"Agent '{role}' is not permitted "
|
||||||
|
f"to use '{request.tool_name}'"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return GuardrailDecision(allow=True)
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
return True
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from crewai.hooks.tool_hooks import (
|
||||||
|
ToolCallHookContext,
|
||||||
|
register_before_tool_call_hook,
|
||||||
|
unregister_before_tool_call_hook,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GuardrailRequest:
|
||||||
|
"""Context passed to the provider for each tool call.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
tool_name: Name of the tool being invoked.
|
||||||
|
tool_input: Dictionary of arguments passed to the tool.
|
||||||
|
agent_role: Role of the agent executing the tool (may be ``None``).
|
||||||
|
task_description: Description of the current task (may be ``None``).
|
||||||
|
crew_id: Identifier for the crew instance (may be ``None``).
|
||||||
|
timestamp: ISO 8601 timestamp of when the request was created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tool_name: str
|
||||||
|
tool_input: dict[str, object]
|
||||||
|
agent_role: str | None = None
|
||||||
|
task_description: str | None = None
|
||||||
|
crew_id: str | None = None
|
||||||
|
timestamp: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GuardrailDecision:
|
||||||
|
"""Provider's allow/deny verdict for a tool call.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
allow: ``True`` to permit execution, ``False`` to block it.
|
||||||
|
reason: Human-readable explanation (surfaced to the agent when blocked).
|
||||||
|
metadata: Arbitrary provider-specific data (e.g. policy ID, audit ref).
|
||||||
|
"""
|
||||||
|
|
||||||
|
allow: bool
|
||||||
|
reason: str | None = None
|
||||||
|
metadata: dict[str, object] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class GuardrailProvider(Protocol):
|
||||||
|
"""Contract for pluggable tool-call authorization.
|
||||||
|
|
||||||
|
Any class that implements this protocol can be wired into CrewAI's
|
||||||
|
hook system via :func:`enable_guardrail` to authorize or deny
|
||||||
|
individual tool calls before they execute.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name: Short identifier for logging / audit purposes.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> class MyProvider:
|
||||||
|
... name = "my_provider"
|
||||||
|
...
|
||||||
|
... def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||||
|
... if request.tool_name == "dangerous_tool":
|
||||||
|
... return GuardrailDecision(allow=False, reason="Blocked")
|
||||||
|
... return GuardrailDecision(allow=True)
|
||||||
|
...
|
||||||
|
... def health_check(self) -> bool:
|
||||||
|
... return True
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
|
||||||
|
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||||
|
"""Evaluate whether a tool call should proceed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Context about the pending tool call.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A :class:`GuardrailDecision`. If ``allow`` is ``False``, the tool
|
||||||
|
call is blocked and ``reason`` is surfaced to the agent.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
"""Optional readiness probe.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``True`` if the provider is healthy and ready, ``False`` otherwise.
|
||||||
|
The default expectation is ``True``.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def _build_guardrail_request(context: ToolCallHookContext) -> GuardrailRequest:
|
||||||
|
"""Build a :class:`GuardrailRequest` from a :class:`ToolCallHookContext`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: The hook context for the current tool call.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A populated :class:`GuardrailRequest`.
|
||||||
|
"""
|
||||||
|
agent_role: str | None = None
|
||||||
|
if context.agent is not None and hasattr(context.agent, "role"):
|
||||||
|
agent_role = context.agent.role
|
||||||
|
|
||||||
|
task_description: str | None = None
|
||||||
|
if context.task is not None and hasattr(context.task, "description"):
|
||||||
|
task_description = context.task.description
|
||||||
|
|
||||||
|
crew_id: str | None = None
|
||||||
|
if context.crew is not None and hasattr(context.crew, "id"):
|
||||||
|
crew_id = str(context.crew.id)
|
||||||
|
|
||||||
|
return GuardrailRequest(
|
||||||
|
tool_name=context.tool_name,
|
||||||
|
tool_input=context.tool_input,
|
||||||
|
agent_role=agent_role,
|
||||||
|
task_description=task_description,
|
||||||
|
crew_id=crew_id,
|
||||||
|
timestamp=datetime.datetime.now(datetime.timezone.utc).isoformat(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def enable_guardrail(
|
||||||
|
provider: GuardrailProvider,
|
||||||
|
*,
|
||||||
|
fail_closed: bool = True,
|
||||||
|
) -> Callable[[], bool]:
|
||||||
|
"""Wire a :class:`GuardrailProvider` into CrewAI's hook system.
|
||||||
|
|
||||||
|
This registers a ``BeforeToolCallHook`` that delegates authorization
|
||||||
|
decisions to the given *provider*. The returned callable can be used
|
||||||
|
to remove the hook later.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: An object satisfying the :class:`GuardrailProvider` protocol.
|
||||||
|
fail_closed: When ``True`` (the default), any exception raised by
|
||||||
|
``provider.evaluate()`` causes the tool call to be **blocked**.
|
||||||
|
When ``False``, exceptions are logged and the tool call is
|
||||||
|
**allowed** to proceed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ``disable`` callable. Calling ``disable()`` unregisters the
|
||||||
|
hook and returns ``True`` if it was still registered, ``False``
|
||||||
|
otherwise.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> disable = enable_guardrail(my_provider, fail_closed=True)
|
||||||
|
>>> # ... run crews / agents ...
|
||||||
|
>>> disable() # remove the guardrail
|
||||||
|
True
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _hook(context: ToolCallHookContext) -> bool | None:
|
||||||
|
request = _build_guardrail_request(context)
|
||||||
|
try:
|
||||||
|
decision = provider.evaluate(request)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"GuardrailProvider '%s' raised an exception (fail_closed=%s)",
|
||||||
|
provider.name,
|
||||||
|
fail_closed,
|
||||||
|
)
|
||||||
|
return False if fail_closed else None
|
||||||
|
|
||||||
|
if not decision.allow:
|
||||||
|
logger.info(
|
||||||
|
"GuardrailProvider '%s' denied tool '%s': %s",
|
||||||
|
provider.name,
|
||||||
|
context.tool_name,
|
||||||
|
decision.reason,
|
||||||
|
)
|
||||||
|
return False # block tool execution
|
||||||
|
|
||||||
|
return None # allow tool execution
|
||||||
|
|
||||||
|
register_before_tool_call_hook(_hook)
|
||||||
|
|
||||||
|
def disable() -> bool:
|
||||||
|
"""Unregister the guardrail hook.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
``True`` if the hook was found and removed, ``False`` otherwise.
|
||||||
|
"""
|
||||||
|
return unregister_before_tool_call_hook(_hook)
|
||||||
|
|
||||||
|
return disable
|
||||||
590
lib/crewai/tests/hooks/test_guardrail_provider.py
Normal file
590
lib/crewai/tests/hooks/test_guardrail_provider.py
Normal file
@@ -0,0 +1,590 @@
|
|||||||
|
"""Tests for the GuardrailProvider interface and enable_guardrail adapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.hooks.guardrail_provider import (
|
||||||
|
GuardrailDecision,
|
||||||
|
GuardrailProvider,
|
||||||
|
GuardrailRequest,
|
||||||
|
_build_guardrail_request,
|
||||||
|
enable_guardrail,
|
||||||
|
)
|
||||||
|
from crewai.hooks.tool_hooks import (
|
||||||
|
ToolCallHookContext,
|
||||||
|
get_before_tool_call_hooks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_tool():
|
||||||
|
"""Create a mock tool for testing."""
|
||||||
|
tool = Mock()
|
||||||
|
tool.name = "test_tool"
|
||||||
|
tool.description = "Test tool description"
|
||||||
|
return tool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_agent():
|
||||||
|
"""Create a mock agent for testing."""
|
||||||
|
agent = Mock()
|
||||||
|
agent.role = "Researcher"
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_task():
|
||||||
|
"""Create a mock task for testing."""
|
||||||
|
task = Mock()
|
||||||
|
task.description = "Summarize the findings"
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_crew():
|
||||||
|
"""Create a mock crew for testing."""
|
||||||
|
crew = Mock()
|
||||||
|
crew.id = "crew-123"
|
||||||
|
return crew
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def clear_hooks():
|
||||||
|
"""Clear global hooks before and after each test."""
|
||||||
|
from crewai.hooks import tool_hooks
|
||||||
|
|
||||||
|
original_before = tool_hooks._before_tool_call_hooks.copy()
|
||||||
|
original_after = tool_hooks._after_tool_call_hooks.copy()
|
||||||
|
|
||||||
|
tool_hooks._before_tool_call_hooks.clear()
|
||||||
|
tool_hooks._after_tool_call_hooks.clear()
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
tool_hooks._before_tool_call_hooks.clear()
|
||||||
|
tool_hooks._after_tool_call_hooks.clear()
|
||||||
|
tool_hooks._before_tool_call_hooks.extend(original_before)
|
||||||
|
tool_hooks._after_tool_call_hooks.extend(original_after)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Concrete provider used across tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class AllowAllProvider:
|
||||||
|
"""A provider that allows every tool call."""
|
||||||
|
|
||||||
|
name = "allow_all"
|
||||||
|
|
||||||
|
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||||
|
return GuardrailDecision(allow=True)
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class BlockListProvider:
|
||||||
|
"""A provider that blocks specific tools by name."""
|
||||||
|
|
||||||
|
name = "block_list"
|
||||||
|
|
||||||
|
def __init__(self, blocked_tools: list[str]) -> None:
|
||||||
|
self.blocked_tools = blocked_tools
|
||||||
|
|
||||||
|
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||||
|
if request.tool_name in self.blocked_tools:
|
||||||
|
return GuardrailDecision(
|
||||||
|
allow=False,
|
||||||
|
reason=f"Tool '{request.tool_name}' is blocked by policy",
|
||||||
|
metadata={"policy": "block_list"},
|
||||||
|
)
|
||||||
|
return GuardrailDecision(allow=True)
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class ExplodingProvider:
|
||||||
|
"""A provider that always raises an exception during evaluate."""
|
||||||
|
|
||||||
|
name = "exploding"
|
||||||
|
|
||||||
|
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||||
|
raise RuntimeError("Provider failure!")
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class RoleBasedProvider:
|
||||||
|
"""A provider that restricts tool access based on agent role."""
|
||||||
|
|
||||||
|
name = "role_based"
|
||||||
|
|
||||||
|
def __init__(self, permissions: dict[str, list[str]]) -> None:
|
||||||
|
self.permissions = permissions
|
||||||
|
|
||||||
|
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||||
|
role = request.agent_role
|
||||||
|
if role is None:
|
||||||
|
return GuardrailDecision(allow=True)
|
||||||
|
allowed = self.permissions.get(role)
|
||||||
|
if allowed is not None and request.tool_name not in allowed:
|
||||||
|
return GuardrailDecision(
|
||||||
|
allow=False,
|
||||||
|
reason=f"Agent '{role}' is not permitted to use '{request.tool_name}'",
|
||||||
|
)
|
||||||
|
return GuardrailDecision(allow=True)
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GuardrailRequest tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGuardrailRequest:
|
||||||
|
"""Test GuardrailRequest construction and defaults."""
|
||||||
|
|
||||||
|
def test_required_fields(self):
|
||||||
|
req = GuardrailRequest(tool_name="search", tool_input={"q": "hello"})
|
||||||
|
assert req.tool_name == "search"
|
||||||
|
assert req.tool_input == {"q": "hello"}
|
||||||
|
|
||||||
|
def test_optional_fields_default_to_none_or_empty(self):
|
||||||
|
req = GuardrailRequest(tool_name="search", tool_input={})
|
||||||
|
assert req.agent_role is None
|
||||||
|
assert req.task_description is None
|
||||||
|
assert req.crew_id is None
|
||||||
|
assert req.timestamp == ""
|
||||||
|
|
||||||
|
def test_all_fields_populated(self):
|
||||||
|
req = GuardrailRequest(
|
||||||
|
tool_name="write_file",
|
||||||
|
tool_input={"path": "/tmp/x"},
|
||||||
|
agent_role="Developer",
|
||||||
|
task_description="Write config",
|
||||||
|
crew_id="crew-42",
|
||||||
|
timestamp="2025-01-01T00:00:00+00:00",
|
||||||
|
)
|
||||||
|
assert req.tool_name == "write_file"
|
||||||
|
assert req.tool_input == {"path": "/tmp/x"}
|
||||||
|
assert req.agent_role == "Developer"
|
||||||
|
assert req.task_description == "Write config"
|
||||||
|
assert req.crew_id == "crew-42"
|
||||||
|
assert req.timestamp == "2025-01-01T00:00:00+00:00"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GuardrailDecision tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGuardrailDecision:
|
||||||
|
"""Test GuardrailDecision construction and defaults."""
|
||||||
|
|
||||||
|
def test_allow_decision(self):
|
||||||
|
dec = GuardrailDecision(allow=True)
|
||||||
|
assert dec.allow is True
|
||||||
|
assert dec.reason is None
|
||||||
|
assert dec.metadata == {}
|
||||||
|
|
||||||
|
def test_deny_decision_with_reason(self):
|
||||||
|
dec = GuardrailDecision(allow=False, reason="Blocked by policy")
|
||||||
|
assert dec.allow is False
|
||||||
|
assert dec.reason == "Blocked by policy"
|
||||||
|
|
||||||
|
def test_decision_with_metadata(self):
|
||||||
|
dec = GuardrailDecision(
|
||||||
|
allow=False,
|
||||||
|
reason="Denied",
|
||||||
|
metadata={"policy_id": "P-001", "audit": True},
|
||||||
|
)
|
||||||
|
assert dec.metadata == {"policy_id": "P-001", "audit": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GuardrailProvider protocol tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestGuardrailProviderProtocol:
|
||||||
|
"""Test that the runtime_checkable protocol works correctly."""
|
||||||
|
|
||||||
|
def test_allow_all_provider_is_guardrail_provider(self):
|
||||||
|
assert isinstance(AllowAllProvider(), GuardrailProvider)
|
||||||
|
|
||||||
|
def test_block_list_provider_is_guardrail_provider(self):
|
||||||
|
assert isinstance(BlockListProvider(blocked_tools=[]), GuardrailProvider)
|
||||||
|
|
||||||
|
def test_exploding_provider_is_guardrail_provider(self):
|
||||||
|
assert isinstance(ExplodingProvider(), GuardrailProvider)
|
||||||
|
|
||||||
|
def test_role_based_provider_is_guardrail_provider(self):
|
||||||
|
assert isinstance(
|
||||||
|
RoleBasedProvider(permissions={}), GuardrailProvider
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_plain_object_is_not_guardrail_provider(self):
|
||||||
|
"""An object without evaluate/health_check is not a GuardrailProvider."""
|
||||||
|
assert not isinstance(object(), GuardrailProvider)
|
||||||
|
|
||||||
|
def test_partial_implementation_is_not_guardrail_provider(self):
|
||||||
|
"""An object with only evaluate but no name/health_check is not a provider."""
|
||||||
|
|
||||||
|
class Incomplete:
|
||||||
|
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||||
|
return GuardrailDecision(allow=True)
|
||||||
|
|
||||||
|
assert not isinstance(Incomplete(), GuardrailProvider)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _build_guardrail_request tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestBuildGuardrailRequest:
|
||||||
|
"""Test the internal helper that converts ToolCallHookContext to GuardrailRequest."""
|
||||||
|
|
||||||
|
def test_full_context(self, mock_tool, mock_agent, mock_task, mock_crew):
|
||||||
|
context = ToolCallHookContext(
|
||||||
|
tool_name="search",
|
||||||
|
tool_input={"query": "AI"},
|
||||||
|
tool=mock_tool,
|
||||||
|
agent=mock_agent,
|
||||||
|
task=mock_task,
|
||||||
|
crew=mock_crew,
|
||||||
|
)
|
||||||
|
req = _build_guardrail_request(context)
|
||||||
|
|
||||||
|
assert req.tool_name == "search"
|
||||||
|
assert req.tool_input == {"query": "AI"}
|
||||||
|
assert req.agent_role == "Researcher"
|
||||||
|
assert req.task_description == "Summarize the findings"
|
||||||
|
assert req.crew_id == "crew-123"
|
||||||
|
assert req.timestamp != "" # should be populated
|
||||||
|
|
||||||
|
def test_minimal_context(self, mock_tool):
|
||||||
|
context = ToolCallHookContext(
|
||||||
|
tool_name="noop",
|
||||||
|
tool_input={},
|
||||||
|
tool=mock_tool,
|
||||||
|
)
|
||||||
|
req = _build_guardrail_request(context)
|
||||||
|
|
||||||
|
assert req.tool_name == "noop"
|
||||||
|
assert req.tool_input == {}
|
||||||
|
assert req.agent_role is None
|
||||||
|
assert req.task_description is None
|
||||||
|
assert req.crew_id is None
|
||||||
|
assert req.timestamp != ""
|
||||||
|
|
||||||
|
def test_agent_without_role_attribute(self, mock_tool):
|
||||||
|
"""Agent-like objects without a role attribute should yield None."""
|
||||||
|
agent_no_role = Mock(spec=[]) # no attributes at all
|
||||||
|
context = ToolCallHookContext(
|
||||||
|
tool_name="tool",
|
||||||
|
tool_input={},
|
||||||
|
tool=mock_tool,
|
||||||
|
agent=agent_no_role,
|
||||||
|
)
|
||||||
|
req = _build_guardrail_request(context)
|
||||||
|
assert req.agent_role is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# enable_guardrail tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestEnableGuardrail:
|
||||||
|
"""Test the enable_guardrail adapter function."""
|
||||||
|
|
||||||
|
def test_enable_registers_a_before_hook(self):
|
||||||
|
provider = AllowAllProvider()
|
||||||
|
disable = enable_guardrail(provider)
|
||||||
|
|
||||||
|
hooks = get_before_tool_call_hooks()
|
||||||
|
assert len(hooks) == 1
|
||||||
|
|
||||||
|
disable()
|
||||||
|
|
||||||
|
def test_disable_removes_the_hook(self):
|
||||||
|
provider = AllowAllProvider()
|
||||||
|
disable = enable_guardrail(provider)
|
||||||
|
|
||||||
|
assert len(get_before_tool_call_hooks()) == 1
|
||||||
|
|
||||||
|
result = disable()
|
||||||
|
assert result is True
|
||||||
|
assert len(get_before_tool_call_hooks()) == 0
|
||||||
|
|
||||||
|
def test_disable_returns_false_when_already_removed(self):
|
||||||
|
provider = AllowAllProvider()
|
||||||
|
disable = enable_guardrail(provider)
|
||||||
|
|
||||||
|
disable() # first removal
|
||||||
|
result = disable() # second removal – already gone
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
def test_allow_all_provider_permits_tool_call(self, mock_tool):
|
||||||
|
provider = AllowAllProvider()
|
||||||
|
disable = enable_guardrail(provider)
|
||||||
|
|
||||||
|
context = ToolCallHookContext(
|
||||||
|
tool_name="any_tool",
|
||||||
|
tool_input={"x": 1},
|
||||||
|
tool=mock_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
hooks = get_before_tool_call_hooks()
|
||||||
|
result = hooks[0](context)
|
||||||
|
assert result is None # None means allow
|
||||||
|
|
||||||
|
disable()
|
||||||
|
|
||||||
|
def test_block_list_provider_denies_blocked_tool(self, mock_tool):
|
||||||
|
provider = BlockListProvider(blocked_tools=["ShellTool"])
|
||||||
|
disable = enable_guardrail(provider)
|
||||||
|
|
||||||
|
context = ToolCallHookContext(
|
||||||
|
tool_name="ShellTool",
|
||||||
|
tool_input={"cmd": "rm -rf /"},
|
||||||
|
tool=mock_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
hooks = get_before_tool_call_hooks()
|
||||||
|
result = hooks[0](context)
|
||||||
|
assert result is False # blocked
|
||||||
|
|
||||||
|
disable()
|
||||||
|
|
||||||
|
def test_block_list_provider_allows_non_blocked_tool(self, mock_tool):
|
||||||
|
provider = BlockListProvider(blocked_tools=["ShellTool"])
|
||||||
|
disable = enable_guardrail(provider)
|
||||||
|
|
||||||
|
context = ToolCallHookContext(
|
||||||
|
tool_name="SearchTool",
|
||||||
|
tool_input={"q": "hello"},
|
||||||
|
tool=mock_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
hooks = get_before_tool_call_hooks()
|
||||||
|
result = hooks[0](context)
|
||||||
|
assert result is None # allowed
|
||||||
|
|
||||||
|
disable()
|
||||||
|
|
||||||
|
def test_role_based_provider_blocks_unauthorized_agent(
|
||||||
|
self, mock_tool, mock_agent
|
||||||
|
):
|
||||||
|
provider = RoleBasedProvider(
|
||||||
|
permissions={"Researcher": ["SearchTool", "ReadFileTool"]}
|
||||||
|
)
|
||||||
|
disable = enable_guardrail(provider)
|
||||||
|
|
||||||
|
context = ToolCallHookContext(
|
||||||
|
tool_name="ShellTool",
|
||||||
|
tool_input={},
|
||||||
|
tool=mock_tool,
|
||||||
|
agent=mock_agent, # role = "Researcher"
|
||||||
|
)
|
||||||
|
|
||||||
|
hooks = get_before_tool_call_hooks()
|
||||||
|
result = hooks[0](context)
|
||||||
|
assert result is False # Researcher can't use ShellTool
|
||||||
|
|
||||||
|
disable()
|
||||||
|
|
||||||
|
def test_role_based_provider_allows_authorized_agent(
|
||||||
|
self, mock_tool, mock_agent
|
||||||
|
):
|
||||||
|
provider = RoleBasedProvider(
|
||||||
|
permissions={"Researcher": ["SearchTool"]}
|
||||||
|
)
|
||||||
|
disable = enable_guardrail(provider)
|
||||||
|
|
||||||
|
context = ToolCallHookContext(
|
||||||
|
tool_name="SearchTool",
|
||||||
|
tool_input={},
|
||||||
|
tool=mock_tool,
|
||||||
|
agent=mock_agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
hooks = get_before_tool_call_hooks()
|
||||||
|
result = hooks[0](context)
|
||||||
|
assert result is None # allowed
|
||||||
|
|
||||||
|
disable()
|
||||||
|
|
||||||
|
def test_fail_closed_blocks_on_exception(self, mock_tool):
|
||||||
|
"""When fail_closed=True (default), provider exceptions block the tool."""
|
||||||
|
provider = ExplodingProvider()
|
||||||
|
disable = enable_guardrail(provider, fail_closed=True)
|
||||||
|
|
||||||
|
context = ToolCallHookContext(
|
||||||
|
tool_name="any_tool",
|
||||||
|
tool_input={},
|
||||||
|
tool=mock_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
hooks = get_before_tool_call_hooks()
|
||||||
|
result = hooks[0](context)
|
||||||
|
assert result is False # blocked due to exception
|
||||||
|
|
||||||
|
disable()
|
||||||
|
|
||||||
|
def test_fail_open_allows_on_exception(self, mock_tool):
|
||||||
|
"""When fail_closed=False, provider exceptions allow the tool."""
|
||||||
|
provider = ExplodingProvider()
|
||||||
|
disable = enable_guardrail(provider, fail_closed=False)
|
||||||
|
|
||||||
|
context = ToolCallHookContext(
|
||||||
|
tool_name="any_tool",
|
||||||
|
tool_input={},
|
||||||
|
tool=mock_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
hooks = get_before_tool_call_hooks()
|
||||||
|
result = hooks[0](context)
|
||||||
|
assert result is None # allowed despite exception
|
||||||
|
|
||||||
|
disable()
|
||||||
|
|
||||||
|
def test_multiple_providers_all_must_allow(self, mock_tool):
|
||||||
|
"""When multiple providers are enabled, all must allow for the tool to proceed."""
|
||||||
|
provider1 = AllowAllProvider()
|
||||||
|
provider2 = BlockListProvider(blocked_tools=["DangerousTool"])
|
||||||
|
|
||||||
|
disable1 = enable_guardrail(provider1)
|
||||||
|
disable2 = enable_guardrail(provider2)
|
||||||
|
|
||||||
|
hooks = get_before_tool_call_hooks()
|
||||||
|
assert len(hooks) == 2
|
||||||
|
|
||||||
|
# Safe tool – both allow
|
||||||
|
context_safe = ToolCallHookContext(
|
||||||
|
tool_name="SafeTool",
|
||||||
|
tool_input={},
|
||||||
|
tool=mock_tool,
|
||||||
|
)
|
||||||
|
results = [h(context_safe) for h in hooks]
|
||||||
|
assert all(r is None for r in results)
|
||||||
|
|
||||||
|
# Dangerous tool – first allows, second blocks
|
||||||
|
context_danger = ToolCallHookContext(
|
||||||
|
tool_name="DangerousTool",
|
||||||
|
tool_input={},
|
||||||
|
tool=mock_tool,
|
||||||
|
)
|
||||||
|
blocked = False
|
||||||
|
for hook in hooks:
|
||||||
|
result = hook(context_danger)
|
||||||
|
if result is False:
|
||||||
|
blocked = True
|
||||||
|
break
|
||||||
|
assert blocked is True
|
||||||
|
|
||||||
|
disable1()
|
||||||
|
disable2()
|
||||||
|
|
||||||
|
def test_guardrail_request_timestamp_is_set(self, mock_tool):
|
||||||
|
"""The hook should populate the timestamp in the GuardrailRequest."""
|
||||||
|
received_requests: list[GuardrailRequest] = []
|
||||||
|
|
||||||
|
class SpyProvider:
|
||||||
|
name = "spy"
|
||||||
|
|
||||||
|
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||||
|
received_requests.append(request)
|
||||||
|
return GuardrailDecision(allow=True)
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
provider = SpyProvider()
|
||||||
|
disable = enable_guardrail(provider)
|
||||||
|
|
||||||
|
context = ToolCallHookContext(
|
||||||
|
tool_name="tool",
|
||||||
|
tool_input={"key": "val"},
|
||||||
|
tool=mock_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
hooks = get_before_tool_call_hooks()
|
||||||
|
hooks[0](context)
|
||||||
|
|
||||||
|
assert len(received_requests) == 1
|
||||||
|
assert received_requests[0].timestamp != ""
|
||||||
|
# Should be a valid ISO 8601 string
|
||||||
|
assert "T" in received_requests[0].timestamp
|
||||||
|
|
||||||
|
disable()
|
||||||
|
|
||||||
|
def test_guardrail_context_fields_passed_through(
|
||||||
|
self, mock_tool, mock_agent, mock_task, mock_crew
|
||||||
|
):
|
||||||
|
"""Verify that agent_role, task_description, crew_id are forwarded."""
|
||||||
|
received_requests: list[GuardrailRequest] = []
|
||||||
|
|
||||||
|
class SpyProvider:
|
||||||
|
name = "spy"
|
||||||
|
|
||||||
|
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||||
|
received_requests.append(request)
|
||||||
|
return GuardrailDecision(allow=True)
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
provider = SpyProvider()
|
||||||
|
disable = enable_guardrail(provider)
|
||||||
|
|
||||||
|
context = ToolCallHookContext(
|
||||||
|
tool_name="search",
|
||||||
|
tool_input={"q": "test"},
|
||||||
|
tool=mock_tool,
|
||||||
|
agent=mock_agent,
|
||||||
|
task=mock_task,
|
||||||
|
crew=mock_crew,
|
||||||
|
)
|
||||||
|
|
||||||
|
hooks = get_before_tool_call_hooks()
|
||||||
|
hooks[0](context)
|
||||||
|
|
||||||
|
req = received_requests[0]
|
||||||
|
assert req.tool_name == "search"
|
||||||
|
assert req.tool_input == {"q": "test"}
|
||||||
|
assert req.agent_role == "Researcher"
|
||||||
|
assert req.task_description == "Summarize the findings"
|
||||||
|
assert req.crew_id == "crew-123"
|
||||||
|
|
||||||
|
disable()
|
||||||
|
|
||||||
|
def test_decision_metadata_is_accessible(self, mock_tool):
|
||||||
|
"""Provider metadata in the decision can be used for auditing."""
|
||||||
|
|
||||||
|
class AuditProvider:
|
||||||
|
name = "audit"
|
||||||
|
|
||||||
|
def evaluate(self, request: GuardrailRequest) -> GuardrailDecision:
|
||||||
|
return GuardrailDecision(
|
||||||
|
allow=True,
|
||||||
|
metadata={"trace_id": "abc-123", "evaluated_at": request.timestamp},
|
||||||
|
)
|
||||||
|
|
||||||
|
def health_check(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
provider = AuditProvider()
|
||||||
|
# Just verify the provider works; metadata is returned but
|
||||||
|
# not directly exposed by the hook (it's for provider-side use)
|
||||||
|
req = GuardrailRequest(tool_name="tool", tool_input={})
|
||||||
|
decision = provider.evaluate(req)
|
||||||
|
assert decision.metadata["trace_id"] == "abc-123"
|
||||||
Reference in New Issue
Block a user