From d5100a54c77ca88b8cf51062b12c080af65206c0 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 14 Mar 2026 18:18:37 +0000 Subject: [PATCH] feat: add GuardrailProvider interface for pre-tool-call authorization (#4877) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add GuardrailRequest dataclass for tool call context - Add GuardrailDecision dataclass for allow/deny verdicts - Add GuardrailProvider runtime-checkable protocol - Add enable_guardrail() adapter wiring providers into BeforeToolCallHook - Add disable() callable returned by enable_guardrail for cleanup - Support fail_closed (default) and fail_open exception handling - Export new types from crewai.hooks - Add 29 comprehensive tests covering all scenarios Co-Authored-By: João --- lib/crewai/src/crewai/hooks/__init__.py | 16 +- .../src/crewai/hooks/guardrail_provider.py | 295 +++++++++ .../tests/hooks/test_guardrail_provider.py | 590 ++++++++++++++++++ 3 files changed, 895 insertions(+), 6 deletions(-) create mode 100644 lib/crewai/src/crewai/hooks/guardrail_provider.py create mode 100644 lib/crewai/tests/hooks/test_guardrail_provider.py diff --git a/lib/crewai/src/crewai/hooks/__init__.py b/lib/crewai/src/crewai/hooks/__init__.py index d3681ffe1..32b75fd60 100644 --- a/lib/crewai/src/crewai/hooks/__init__.py +++ b/lib/crewai/src/crewai/hooks/__init__.py @@ -6,6 +6,12 @@ from crewai.hooks.decorators import ( before_llm_call, before_tool_call, ) +from crewai.hooks.guardrail_provider import ( + GuardrailDecision, + GuardrailProvider, + GuardrailRequest, + enable_guardrail, +) from crewai.hooks.llm_hooks import ( LLMCallHookContext, clear_after_llm_call_hooks, @@ -74,10 +80,11 @@ def clear_all_global_hooks() -> dict[str, tuple[int, int]]: __all__ = [ - # Context classes + "GuardrailDecision", + "GuardrailProvider", + "GuardrailRequest", "LLMCallHookContext", "ToolCallHookContext", - # Decorators "after_llm_call", "after_tool_call", "before_llm_call", @@ -87,19 +94,16 @@ __all__ = [ "clear_all_global_hooks", "clear_all_llm_call_hooks", "clear_all_tool_call_hooks", - # Clear hooks "clear_before_llm_call_hooks", "clear_before_tool_call_hooks", + "enable_guardrail", "get_after_llm_call_hooks", "get_after_tool_call_hooks", - # Get hooks "get_before_llm_call_hooks", "get_before_tool_call_hooks", "register_after_llm_call_hook", "register_after_tool_call_hook", - # LLM Hook registration "register_before_llm_call_hook", - # Tool Hook registration "register_before_tool_call_hook", "unregister_after_llm_call_hook", "unregister_after_tool_call_hook", diff --git a/lib/crewai/src/crewai/hooks/guardrail_provider.py b/lib/crewai/src/crewai/hooks/guardrail_provider.py new file mode 100644 index 000000000..3dab350d8 --- /dev/null +++ b/lib/crewai/src/crewai/hooks/guardrail_provider.py @@ -0,0 +1,295 @@ +"""GuardrailProvider interface for pre-tool-call authorization. + +This module provides a standard protocol for pluggable tool-call authorization +that sits on top of CrewAI's existing BeforeToolCallHook system. + +Usage: + Simple provider that blocks specific tools:: + + from crewai.hooks import ( + GuardrailProvider, + GuardrailRequest, + GuardrailDecision, + enable_guardrail, + ) + + class BlockListProvider: + name = "block_list" + + def __init__(self, blocked_tools: list[str]) -> None: + self.blocked_tools = blocked_tools + + def evaluate(self, request: GuardrailRequest) -> GuardrailDecision: + if request.tool_name in self.blocked_tools: + return GuardrailDecision( + allow=False, + reason=f"Tool '{request.tool_name}' is blocked by policy", + ) + return GuardrailDecision(allow=True) + + def health_check(self) -> bool: + return True + + provider = BlockListProvider(blocked_tools=["ShellTool", "dangerous_op"]) + disable = enable_guardrail(provider) + + # Later, to remove the guardrail: + disable() + + Rate-limiting provider:: + + import time + from collections import defaultdict + + class RateLimitProvider: + name = "rate_limiter" + + def __init__(self, max_calls: int, window_seconds: float = 60.0) -> None: + self.max_calls = max_calls + self.window_seconds = window_seconds + self._calls: dict[str, list[float]] = defaultdict(list) + + def evaluate(self, request: GuardrailRequest) -> GuardrailDecision: + now = time.time() + key = request.tool_name + # Remove expired entries + self._calls[key] = [ + t for t in self._calls[key] + if now - t < self.window_seconds + ] + if len(self._calls[key]) >= self.max_calls: + return GuardrailDecision( + allow=False, + reason=f"Rate limit exceeded for '{key}'", + ) + self._calls[key].append(now) + return GuardrailDecision(allow=True) + + def health_check(self) -> bool: + return True + + Per-agent role restriction:: + + class RoleBasedProvider: + name = "role_based" + + def __init__(self, permissions: dict[str, list[str]]) -> None: + # Maps agent role -> list of allowed tool names + self.permissions = permissions + + def evaluate(self, request: GuardrailRequest) -> GuardrailDecision: + role = request.agent_role + if role is None: + return GuardrailDecision(allow=True) + allowed = self.permissions.get(role) + if allowed is not None and request.tool_name not in allowed: + return GuardrailDecision( + allow=False, + reason=( + f"Agent '{role}' is not permitted " + f"to use '{request.tool_name}'" + ), + ) + return GuardrailDecision(allow=True) + + def health_check(self) -> bool: + return True +""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field +import datetime +import logging +from typing import Protocol, runtime_checkable + +from crewai.hooks.tool_hooks import ( + ToolCallHookContext, + register_before_tool_call_hook, + unregister_before_tool_call_hook, +) + + +logger = logging.getLogger(__name__) + + +@dataclass +class GuardrailRequest: + """Context passed to the provider for each tool call. + + Attributes: + tool_name: Name of the tool being invoked. + tool_input: Dictionary of arguments passed to the tool. + agent_role: Role of the agent executing the tool (may be ``None``). + task_description: Description of the current task (may be ``None``). + crew_id: Identifier for the crew instance (may be ``None``). + timestamp: ISO 8601 timestamp of when the request was created. + """ + + tool_name: str + tool_input: dict[str, object] + agent_role: str | None = None + task_description: str | None = None + crew_id: str | None = None + timestamp: str = "" + + +@dataclass +class GuardrailDecision: + """Provider's allow/deny verdict for a tool call. + + Attributes: + allow: ``True`` to permit execution, ``False`` to block it. + reason: Human-readable explanation (surfaced to the agent when blocked). + metadata: Arbitrary provider-specific data (e.g. policy ID, audit ref). + """ + + allow: bool + reason: str | None = None + metadata: dict[str, object] = field(default_factory=dict) + + +@runtime_checkable +class GuardrailProvider(Protocol): + """Contract for pluggable tool-call authorization. + + Any class that implements this protocol can be wired into CrewAI's + hook system via :func:`enable_guardrail` to authorize or deny + individual tool calls before they execute. + + Attributes: + name: Short identifier for logging / audit purposes. + + Example: + >>> class MyProvider: + ... name = "my_provider" + ... + ... def evaluate(self, request: GuardrailRequest) -> GuardrailDecision: + ... if request.tool_name == "dangerous_tool": + ... return GuardrailDecision(allow=False, reason="Blocked") + ... return GuardrailDecision(allow=True) + ... + ... def health_check(self) -> bool: + ... return True + """ + + name: str + + def evaluate(self, request: GuardrailRequest) -> GuardrailDecision: + """Evaluate whether a tool call should proceed. + + Args: + request: Context about the pending tool call. + + Returns: + A :class:`GuardrailDecision`. If ``allow`` is ``False``, the tool + call is blocked and ``reason`` is surfaced to the agent. + """ + ... + + def health_check(self) -> bool: + """Optional readiness probe. + + Returns: + ``True`` if the provider is healthy and ready, ``False`` otherwise. + The default expectation is ``True``. + """ + ... + + +def _build_guardrail_request(context: ToolCallHookContext) -> GuardrailRequest: + """Build a :class:`GuardrailRequest` from a :class:`ToolCallHookContext`. + + Args: + context: The hook context for the current tool call. + + Returns: + A populated :class:`GuardrailRequest`. + """ + agent_role: str | None = None + if context.agent is not None and hasattr(context.agent, "role"): + agent_role = context.agent.role + + task_description: str | None = None + if context.task is not None and hasattr(context.task, "description"): + task_description = context.task.description + + crew_id: str | None = None + if context.crew is not None and hasattr(context.crew, "id"): + crew_id = str(context.crew.id) + + return GuardrailRequest( + tool_name=context.tool_name, + tool_input=context.tool_input, + agent_role=agent_role, + task_description=task_description, + crew_id=crew_id, + timestamp=datetime.datetime.now(datetime.timezone.utc).isoformat(), + ) + + +def enable_guardrail( + provider: GuardrailProvider, + *, + fail_closed: bool = True, +) -> Callable[[], bool]: + """Wire a :class:`GuardrailProvider` into CrewAI's hook system. + + This registers a ``BeforeToolCallHook`` that delegates authorization + decisions to the given *provider*. The returned callable can be used + to remove the hook later. + + Args: + provider: An object satisfying the :class:`GuardrailProvider` protocol. + fail_closed: When ``True`` (the default), any exception raised by + ``provider.evaluate()`` causes the tool call to be **blocked**. + When ``False``, exceptions are logged and the tool call is + **allowed** to proceed. + + Returns: + A ``disable`` callable. Calling ``disable()`` unregisters the + hook and returns ``True`` if it was still registered, ``False`` + otherwise. + + Example: + >>> disable = enable_guardrail(my_provider, fail_closed=True) + >>> # ... run crews / agents ... + >>> disable() # remove the guardrail + True + """ + + def _hook(context: ToolCallHookContext) -> bool | None: + request = _build_guardrail_request(context) + try: + decision = provider.evaluate(request) + except Exception: + logger.exception( + "GuardrailProvider '%s' raised an exception (fail_closed=%s)", + provider.name, + fail_closed, + ) + return False if fail_closed else None + + if not decision.allow: + logger.info( + "GuardrailProvider '%s' denied tool '%s': %s", + provider.name, + context.tool_name, + decision.reason, + ) + return False # block tool execution + + return None # allow tool execution + + register_before_tool_call_hook(_hook) + + def disable() -> bool: + """Unregister the guardrail hook. + + Returns: + ``True`` if the hook was found and removed, ``False`` otherwise. + """ + return unregister_before_tool_call_hook(_hook) + + return disable diff --git a/lib/crewai/tests/hooks/test_guardrail_provider.py b/lib/crewai/tests/hooks/test_guardrail_provider.py new file mode 100644 index 000000000..36543f2c5 --- /dev/null +++ b/lib/crewai/tests/hooks/test_guardrail_provider.py @@ -0,0 +1,590 @@ +"""Tests for the GuardrailProvider interface and enable_guardrail adapter.""" + +from __future__ import annotations + +from unittest.mock import Mock + +import pytest + +from crewai.hooks.guardrail_provider import ( + GuardrailDecision, + GuardrailProvider, + GuardrailRequest, + _build_guardrail_request, + enable_guardrail, +) +from crewai.hooks.tool_hooks import ( + ToolCallHookContext, + get_before_tool_call_hooks, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_tool(): + """Create a mock tool for testing.""" + tool = Mock() + tool.name = "test_tool" + tool.description = "Test tool description" + return tool + + +@pytest.fixture +def mock_agent(): + """Create a mock agent for testing.""" + agent = Mock() + agent.role = "Researcher" + return agent + + +@pytest.fixture +def mock_task(): + """Create a mock task for testing.""" + task = Mock() + task.description = "Summarize the findings" + return task + + +@pytest.fixture +def mock_crew(): + """Create a mock crew for testing.""" + crew = Mock() + crew.id = "crew-123" + return crew + + +@pytest.fixture(autouse=True) +def clear_hooks(): + """Clear global hooks before and after each test.""" + from crewai.hooks import tool_hooks + + original_before = tool_hooks._before_tool_call_hooks.copy() + original_after = tool_hooks._after_tool_call_hooks.copy() + + tool_hooks._before_tool_call_hooks.clear() + tool_hooks._after_tool_call_hooks.clear() + + yield + + tool_hooks._before_tool_call_hooks.clear() + tool_hooks._after_tool_call_hooks.clear() + tool_hooks._before_tool_call_hooks.extend(original_before) + tool_hooks._after_tool_call_hooks.extend(original_after) + + +# --------------------------------------------------------------------------- +# Concrete provider used across tests +# --------------------------------------------------------------------------- + +class AllowAllProvider: + """A provider that allows every tool call.""" + + name = "allow_all" + + def evaluate(self, request: GuardrailRequest) -> GuardrailDecision: + return GuardrailDecision(allow=True) + + def health_check(self) -> bool: + return True + + +class BlockListProvider: + """A provider that blocks specific tools by name.""" + + name = "block_list" + + def __init__(self, blocked_tools: list[str]) -> None: + self.blocked_tools = blocked_tools + + def evaluate(self, request: GuardrailRequest) -> GuardrailDecision: + if request.tool_name in self.blocked_tools: + return GuardrailDecision( + allow=False, + reason=f"Tool '{request.tool_name}' is blocked by policy", + metadata={"policy": "block_list"}, + ) + return GuardrailDecision(allow=True) + + def health_check(self) -> bool: + return True + + +class ExplodingProvider: + """A provider that always raises an exception during evaluate.""" + + name = "exploding" + + def evaluate(self, request: GuardrailRequest) -> GuardrailDecision: + raise RuntimeError("Provider failure!") + + def health_check(self) -> bool: + return False + + +class RoleBasedProvider: + """A provider that restricts tool access based on agent role.""" + + name = "role_based" + + def __init__(self, permissions: dict[str, list[str]]) -> None: + self.permissions = permissions + + def evaluate(self, request: GuardrailRequest) -> GuardrailDecision: + role = request.agent_role + if role is None: + return GuardrailDecision(allow=True) + allowed = self.permissions.get(role) + if allowed is not None and request.tool_name not in allowed: + return GuardrailDecision( + allow=False, + reason=f"Agent '{role}' is not permitted to use '{request.tool_name}'", + ) + return GuardrailDecision(allow=True) + + def health_check(self) -> bool: + return True + + +# --------------------------------------------------------------------------- +# GuardrailRequest tests +# --------------------------------------------------------------------------- + +class TestGuardrailRequest: + """Test GuardrailRequest construction and defaults.""" + + def test_required_fields(self): + req = GuardrailRequest(tool_name="search", tool_input={"q": "hello"}) + assert req.tool_name == "search" + assert req.tool_input == {"q": "hello"} + + def test_optional_fields_default_to_none_or_empty(self): + req = GuardrailRequest(tool_name="search", tool_input={}) + assert req.agent_role is None + assert req.task_description is None + assert req.crew_id is None + assert req.timestamp == "" + + def test_all_fields_populated(self): + req = GuardrailRequest( + tool_name="write_file", + tool_input={"path": "/tmp/x"}, + agent_role="Developer", + task_description="Write config", + crew_id="crew-42", + timestamp="2025-01-01T00:00:00+00:00", + ) + assert req.tool_name == "write_file" + assert req.tool_input == {"path": "/tmp/x"} + assert req.agent_role == "Developer" + assert req.task_description == "Write config" + assert req.crew_id == "crew-42" + assert req.timestamp == "2025-01-01T00:00:00+00:00" + + +# --------------------------------------------------------------------------- +# GuardrailDecision tests +# --------------------------------------------------------------------------- + +class TestGuardrailDecision: + """Test GuardrailDecision construction and defaults.""" + + def test_allow_decision(self): + dec = GuardrailDecision(allow=True) + assert dec.allow is True + assert dec.reason is None + assert dec.metadata == {} + + def test_deny_decision_with_reason(self): + dec = GuardrailDecision(allow=False, reason="Blocked by policy") + assert dec.allow is False + assert dec.reason == "Blocked by policy" + + def test_decision_with_metadata(self): + dec = GuardrailDecision( + allow=False, + reason="Denied", + metadata={"policy_id": "P-001", "audit": True}, + ) + assert dec.metadata == {"policy_id": "P-001", "audit": True} + + +# --------------------------------------------------------------------------- +# GuardrailProvider protocol tests +# --------------------------------------------------------------------------- + +class TestGuardrailProviderProtocol: + """Test that the runtime_checkable protocol works correctly.""" + + def test_allow_all_provider_is_guardrail_provider(self): + assert isinstance(AllowAllProvider(), GuardrailProvider) + + def test_block_list_provider_is_guardrail_provider(self): + assert isinstance(BlockListProvider(blocked_tools=[]), GuardrailProvider) + + def test_exploding_provider_is_guardrail_provider(self): + assert isinstance(ExplodingProvider(), GuardrailProvider) + + def test_role_based_provider_is_guardrail_provider(self): + assert isinstance( + RoleBasedProvider(permissions={}), GuardrailProvider + ) + + def test_plain_object_is_not_guardrail_provider(self): + """An object without evaluate/health_check is not a GuardrailProvider.""" + assert not isinstance(object(), GuardrailProvider) + + def test_partial_implementation_is_not_guardrail_provider(self): + """An object with only evaluate but no name/health_check is not a provider.""" + + class Incomplete: + def evaluate(self, request: GuardrailRequest) -> GuardrailDecision: + return GuardrailDecision(allow=True) + + assert not isinstance(Incomplete(), GuardrailProvider) + + +# --------------------------------------------------------------------------- +# _build_guardrail_request tests +# --------------------------------------------------------------------------- + +class TestBuildGuardrailRequest: + """Test the internal helper that converts ToolCallHookContext to GuardrailRequest.""" + + def test_full_context(self, mock_tool, mock_agent, mock_task, mock_crew): + context = ToolCallHookContext( + tool_name="search", + tool_input={"query": "AI"}, + tool=mock_tool, + agent=mock_agent, + task=mock_task, + crew=mock_crew, + ) + req = _build_guardrail_request(context) + + assert req.tool_name == "search" + assert req.tool_input == {"query": "AI"} + assert req.agent_role == "Researcher" + assert req.task_description == "Summarize the findings" + assert req.crew_id == "crew-123" + assert req.timestamp != "" # should be populated + + def test_minimal_context(self, mock_tool): + context = ToolCallHookContext( + tool_name="noop", + tool_input={}, + tool=mock_tool, + ) + req = _build_guardrail_request(context) + + assert req.tool_name == "noop" + assert req.tool_input == {} + assert req.agent_role is None + assert req.task_description is None + assert req.crew_id is None + assert req.timestamp != "" + + def test_agent_without_role_attribute(self, mock_tool): + """Agent-like objects without a role attribute should yield None.""" + agent_no_role = Mock(spec=[]) # no attributes at all + context = ToolCallHookContext( + tool_name="tool", + tool_input={}, + tool=mock_tool, + agent=agent_no_role, + ) + req = _build_guardrail_request(context) + assert req.agent_role is None + + +# --------------------------------------------------------------------------- +# enable_guardrail tests +# --------------------------------------------------------------------------- + +class TestEnableGuardrail: + """Test the enable_guardrail adapter function.""" + + def test_enable_registers_a_before_hook(self): + provider = AllowAllProvider() + disable = enable_guardrail(provider) + + hooks = get_before_tool_call_hooks() + assert len(hooks) == 1 + + disable() + + def test_disable_removes_the_hook(self): + provider = AllowAllProvider() + disable = enable_guardrail(provider) + + assert len(get_before_tool_call_hooks()) == 1 + + result = disable() + assert result is True + assert len(get_before_tool_call_hooks()) == 0 + + def test_disable_returns_false_when_already_removed(self): + provider = AllowAllProvider() + disable = enable_guardrail(provider) + + disable() # first removal + result = disable() # second removal – already gone + assert result is False + + def test_allow_all_provider_permits_tool_call(self, mock_tool): + provider = AllowAllProvider() + disable = enable_guardrail(provider) + + context = ToolCallHookContext( + tool_name="any_tool", + tool_input={"x": 1}, + tool=mock_tool, + ) + + hooks = get_before_tool_call_hooks() + result = hooks[0](context) + assert result is None # None means allow + + disable() + + def test_block_list_provider_denies_blocked_tool(self, mock_tool): + provider = BlockListProvider(blocked_tools=["ShellTool"]) + disable = enable_guardrail(provider) + + context = ToolCallHookContext( + tool_name="ShellTool", + tool_input={"cmd": "rm -rf /"}, + tool=mock_tool, + ) + + hooks = get_before_tool_call_hooks() + result = hooks[0](context) + assert result is False # blocked + + disable() + + def test_block_list_provider_allows_non_blocked_tool(self, mock_tool): + provider = BlockListProvider(blocked_tools=["ShellTool"]) + disable = enable_guardrail(provider) + + context = ToolCallHookContext( + tool_name="SearchTool", + tool_input={"q": "hello"}, + tool=mock_tool, + ) + + hooks = get_before_tool_call_hooks() + result = hooks[0](context) + assert result is None # allowed + + disable() + + def test_role_based_provider_blocks_unauthorized_agent( + self, mock_tool, mock_agent + ): + provider = RoleBasedProvider( + permissions={"Researcher": ["SearchTool", "ReadFileTool"]} + ) + disable = enable_guardrail(provider) + + context = ToolCallHookContext( + tool_name="ShellTool", + tool_input={}, + tool=mock_tool, + agent=mock_agent, # role = "Researcher" + ) + + hooks = get_before_tool_call_hooks() + result = hooks[0](context) + assert result is False # Researcher can't use ShellTool + + disable() + + def test_role_based_provider_allows_authorized_agent( + self, mock_tool, mock_agent + ): + provider = RoleBasedProvider( + permissions={"Researcher": ["SearchTool"]} + ) + disable = enable_guardrail(provider) + + context = ToolCallHookContext( + tool_name="SearchTool", + tool_input={}, + tool=mock_tool, + agent=mock_agent, + ) + + hooks = get_before_tool_call_hooks() + result = hooks[0](context) + assert result is None # allowed + + disable() + + def test_fail_closed_blocks_on_exception(self, mock_tool): + """When fail_closed=True (default), provider exceptions block the tool.""" + provider = ExplodingProvider() + disable = enable_guardrail(provider, fail_closed=True) + + context = ToolCallHookContext( + tool_name="any_tool", + tool_input={}, + tool=mock_tool, + ) + + hooks = get_before_tool_call_hooks() + result = hooks[0](context) + assert result is False # blocked due to exception + + disable() + + def test_fail_open_allows_on_exception(self, mock_tool): + """When fail_closed=False, provider exceptions allow the tool.""" + provider = ExplodingProvider() + disable = enable_guardrail(provider, fail_closed=False) + + context = ToolCallHookContext( + tool_name="any_tool", + tool_input={}, + tool=mock_tool, + ) + + hooks = get_before_tool_call_hooks() + result = hooks[0](context) + assert result is None # allowed despite exception + + disable() + + def test_multiple_providers_all_must_allow(self, mock_tool): + """When multiple providers are enabled, all must allow for the tool to proceed.""" + provider1 = AllowAllProvider() + provider2 = BlockListProvider(blocked_tools=["DangerousTool"]) + + disable1 = enable_guardrail(provider1) + disable2 = enable_guardrail(provider2) + + hooks = get_before_tool_call_hooks() + assert len(hooks) == 2 + + # Safe tool – both allow + context_safe = ToolCallHookContext( + tool_name="SafeTool", + tool_input={}, + tool=mock_tool, + ) + results = [h(context_safe) for h in hooks] + assert all(r is None for r in results) + + # Dangerous tool – first allows, second blocks + context_danger = ToolCallHookContext( + tool_name="DangerousTool", + tool_input={}, + tool=mock_tool, + ) + blocked = False + for hook in hooks: + result = hook(context_danger) + if result is False: + blocked = True + break + assert blocked is True + + disable1() + disable2() + + def test_guardrail_request_timestamp_is_set(self, mock_tool): + """The hook should populate the timestamp in the GuardrailRequest.""" + received_requests: list[GuardrailRequest] = [] + + class SpyProvider: + name = "spy" + + def evaluate(self, request: GuardrailRequest) -> GuardrailDecision: + received_requests.append(request) + return GuardrailDecision(allow=True) + + def health_check(self) -> bool: + return True + + provider = SpyProvider() + disable = enable_guardrail(provider) + + context = ToolCallHookContext( + tool_name="tool", + tool_input={"key": "val"}, + tool=mock_tool, + ) + + hooks = get_before_tool_call_hooks() + hooks[0](context) + + assert len(received_requests) == 1 + assert received_requests[0].timestamp != "" + # Should be a valid ISO 8601 string + assert "T" in received_requests[0].timestamp + + disable() + + def test_guardrail_context_fields_passed_through( + self, mock_tool, mock_agent, mock_task, mock_crew + ): + """Verify that agent_role, task_description, crew_id are forwarded.""" + received_requests: list[GuardrailRequest] = [] + + class SpyProvider: + name = "spy" + + def evaluate(self, request: GuardrailRequest) -> GuardrailDecision: + received_requests.append(request) + return GuardrailDecision(allow=True) + + def health_check(self) -> bool: + return True + + provider = SpyProvider() + disable = enable_guardrail(provider) + + context = ToolCallHookContext( + tool_name="search", + tool_input={"q": "test"}, + tool=mock_tool, + agent=mock_agent, + task=mock_task, + crew=mock_crew, + ) + + hooks = get_before_tool_call_hooks() + hooks[0](context) + + req = received_requests[0] + assert req.tool_name == "search" + assert req.tool_input == {"q": "test"} + assert req.agent_role == "Researcher" + assert req.task_description == "Summarize the findings" + assert req.crew_id == "crew-123" + + disable() + + def test_decision_metadata_is_accessible(self, mock_tool): + """Provider metadata in the decision can be used for auditing.""" + + class AuditProvider: + name = "audit" + + def evaluate(self, request: GuardrailRequest) -> GuardrailDecision: + return GuardrailDecision( + allow=True, + metadata={"trace_id": "abc-123", "evaluated_at": request.timestamp}, + ) + + def health_check(self) -> bool: + return True + + provider = AuditProvider() + # Just verify the provider works; metadata is returned but + # not directly exposed by the hook (it's for provider-side use) + req = GuardrailRequest(tool_name="tool", tool_input={}) + decision = provider.evaluate(req) + assert decision.metadata["trace_id"] == "abc-123"