mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-28 17:48:13 +00:00
* feat: implement before and after tool call hooks in CrewAgentExecutor and AgentExecutor - Added support for before and after tool call hooks in both CrewAgentExecutor and AgentExecutor classes. - Introduced ToolCallHookContext to manage context for hooks, allowing for enhanced control over tool execution. - Implemented logic to block tool execution based on before hooks and to modify results based on after hooks. - Added integration tests to validate the functionality of the new hooks, ensuring they work as expected in various scenarios. - Enhanced the overall flexibility and extensibility of tool interactions within the CrewAI framework. * Potential fix for pull request finding 'Unused local variable' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> * Potential fix for pull request finding 'Unused local variable' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> * test: add integration test for before hook blocking tool execution in Crew - Implemented a new test to verify that the before hook can successfully block the execution of a tool within a crew. - The test checks that the tool is not executed when the before hook returns False, ensuring proper control over tool interactions. - Enhanced the validation of hook calls to confirm that both before and after hooks are triggered appropriately, even when execution is blocked. - This addition strengthens the testing coverage for tool call hooks in the CrewAI framework. * drop unused * refactor(tests): remove OPENAI_API_KEY check from tool hook tests - Eliminated the check for the OPENAI_API_KEY environment variable in the test cases for tool hooks. - This change simplifies the test setup and allows for running tests without requiring the API key to be set, improving test accessibility and flexibility. 
--------- Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
823 lines
27 KiB
Python
823 lines
27 KiB
Python
from __future__ import annotations
|
|
|
|
from unittest.mock import Mock
|
|
|
|
from crewai.hooks import clear_all_tool_call_hooks, unregister_after_tool_call_hook, unregister_before_tool_call_hook
|
|
import pytest
|
|
|
|
from crewai.hooks.tool_hooks import (
|
|
ToolCallHookContext,
|
|
get_after_tool_call_hooks,
|
|
get_before_tool_call_hooks,
|
|
register_after_tool_call_hook,
|
|
register_before_tool_call_hook,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def mock_tool():
    """Provide a ``Mock`` stand-in for a tool with a name and a description."""
    fake_tool = Mock()
    # ``name`` is reserved by the Mock constructor, so it is set after creation.
    fake_tool.configure_mock(
        name="test_tool",
        description="Test tool description",
    )
    return fake_tool
|
|
|
|
|
|
@pytest.fixture
def mock_agent():
    """Provide a ``Mock`` stand-in for an agent whose role is "Test Agent"."""
    return Mock(role="Test Agent")
|
|
|
|
|
|
@pytest.fixture
def mock_task():
    """Provide a ``Mock`` stand-in for a task with a fixed description."""
    return Mock(description="Test task")
|
|
|
|
|
|
@pytest.fixture
def mock_crew():
    """Provide a bare ``Mock`` stand-in for a crew."""
    return Mock()
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def clear_hooks():
    """Run every test against empty hook registries, then put the old ones back.

    The module-level before/after registries are snapshotted and emptied
    before the test body runs; after it finishes they are emptied again and
    refilled with the snapshot.
    """
    from crewai.hooks import tool_hooks

    # Snapshot whatever was registered before this test started.
    saved_before = list(tool_hooks._before_tool_call_hooks)
    saved_after = list(tool_hooks._after_tool_call_hooks)

    # Start the test from a clean slate.
    tool_hooks._before_tool_call_hooks.clear()
    tool_hooks._after_tool_call_hooks.clear()

    yield

    # Drop anything the test registered and reinstate the snapshot.
    tool_hooks._before_tool_call_hooks.clear()
    tool_hooks._before_tool_call_hooks.extend(saved_before)
    tool_hooks._after_tool_call_hooks.clear()
    tool_hooks._after_tool_call_hooks.extend(saved_after)
|
|
|
|
|
|
class TestToolCallHookContext:
    """Checks on what a freshly built ToolCallHookContext exposes."""

    def test_context_initialization(self, mock_tool, mock_agent, mock_task, mock_crew):
        """Constructor arguments are readable back; tool_result is None if omitted."""
        payload = {"arg1": "value1", "arg2": "value2"}

        ctx = ToolCallHookContext(
            tool_name="test_tool",
            tool_input=payload,
            tool=mock_tool,
            agent=mock_agent,
            task=mock_task,
            crew=mock_crew,
        )

        expected_attributes = {
            "tool_name": "test_tool",
            "tool_input": payload,
            "tool": mock_tool,
            "agent": mock_agent,
            "task": mock_task,
            "crew": mock_crew,
        }
        for attribute, expected in expected_attributes.items():
            assert getattr(ctx, attribute) == expected
        assert ctx.tool_result is None

    def test_context_with_result(self, mock_tool):
        """A tool_result passed to the constructor is stored on the context."""
        expected_result = "Test tool result"

        ctx = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={"arg1": "value1"},
            tool=mock_tool,
            tool_result=expected_result,
        )

        assert ctx.tool_result == expected_result

    def test_tool_input_is_mutable_reference(self, mock_tool):
        """Writing through context.tool_input is visible in the caller's dict."""
        caller_dict = {"arg1": "value1"}
        ctx = ToolCallHookContext(
            tool_name="test_tool",
            tool_input=caller_dict,
            tool=mock_tool,
        )

        # Mutate via the context only ...
        ctx.tool_input["arg2"] = "value2"

        # ... and observe the change on the dict that was handed in.
        assert "arg2" in caller_dict
        assert caller_dict["arg2"] == "value2"
|
|
|
|
|
|
class TestBeforeToolCallHooks:
    """Registration and return-value checks for before_tool_call hooks."""

    def test_register_before_hook(self):
        """A single registered hook is the only entry the getter reports."""

        def noop_hook(context):
            return None

        register_before_tool_call_hook(noop_hook)

        registered = get_before_tool_call_hooks()
        assert len(registered) == 1
        assert registered[0] == noop_hook

    def test_multiple_before_hooks(self):
        """Two registered hooks are both reported by the getter."""

        def first_hook(context):
            return None

        def second_hook(context):
            return None

        for candidate in (first_hook, second_hook):
            register_before_tool_call_hook(candidate)

        registered = get_before_tool_call_hooks()
        assert len(registered) == 2
        assert first_hook in registered
        assert second_hook in registered

    def test_before_hook_can_block_execution(self, mock_tool):
        """A hook returns False for the tool name it is written to block."""

        def block_hook(context):
            # False blocks execution; None allows it.
            return False if context.tool_name == "dangerous_tool" else None

        ctx = ToolCallHookContext(
            tool_name="dangerous_tool",
            tool_input={},
            tool=mock_tool,
        )

        assert block_hook(ctx) is False

    def test_before_hook_can_allow_execution(self, mock_tool):
        """A hook that returns None signals that execution may proceed."""

        def allow_hook(context):
            return None  # Allow execution

        ctx = ToolCallHookContext(
            tool_name="safe_tool",
            tool_input={},
            tool=mock_tool,
        )

        assert allow_hook(ctx) is None

    def test_before_hook_can_modify_input(self, mock_tool):
        """A hook can add keys to context.tool_input in place."""

        def modify_input_hook(context):
            context.tool_input["modified_by_hook"] = True
            return None

        ctx = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={"arg1": "value1"},
            tool=mock_tool,
        )

        modify_input_hook(ctx)

        assert "modified_by_hook" in ctx.tool_input
        assert ctx.tool_input["modified_by_hook"] is True

    def test_get_before_hooks_returns_copy(self):
        """Two getter calls yield equal but distinct objects."""

        def noop_hook(context):
            return None

        register_before_tool_call_hook(noop_hook)

        first_view = get_before_tool_call_hooks()
        second_view = get_before_tool_call_hooks()

        # Same contents, different container objects.
        assert first_view == second_view
        assert first_view is not second_view
|
|
|
|
|
|
class TestAfterToolCallHooks:
    """Registration and return-value checks for after_tool_call hooks."""

    def test_register_after_hook(self):
        """A single registered hook is the only entry the getter reports."""

        def noop_hook(context):
            return None

        register_after_tool_call_hook(noop_hook)

        registered = get_after_tool_call_hooks()
        assert len(registered) == 1
        assert registered[0] == noop_hook

    def test_multiple_after_hooks(self):
        """Two registered hooks are both reported by the getter."""

        def first_hook(context):
            return None

        def second_hook(context):
            return None

        for candidate in (first_hook, second_hook):
            register_after_tool_call_hook(candidate)

        registered = get_after_tool_call_hooks()
        assert len(registered) == 2
        assert first_hook in registered
        assert second_hook in registered

    def test_after_hook_can_modify_result(self, mock_tool):
        """A hook can return a rewritten version of context.tool_result."""

        def modify_result_hook(context):
            # Nothing to rewrite when the result is empty or missing.
            if not context.tool_result:
                return None
            return context.tool_result.replace("Original", "Modified")

        ctx = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={},
            tool=mock_tool,
            tool_result="Original result",
        )

        assert modify_result_hook(ctx) == "Modified result"

    def test_after_hook_returns_none_keeps_original(self, mock_tool):
        """A hook returning None leaves context.tool_result untouched."""
        untouched = "Original result"

        def no_change_hook(context):
            return None

        ctx = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={},
            tool=mock_tool,
            tool_result=untouched,
        )

        outcome = no_change_hook(ctx)

        assert outcome is None
        assert ctx.tool_result == untouched

    def test_get_after_hooks_returns_copy(self):
        """Two getter calls yield equal but distinct objects."""

        def noop_hook(context):
            return None

        register_after_tool_call_hook(noop_hook)

        first_view = get_after_tool_call_hooks()
        second_view = get_after_tool_call_hooks()

        # Same contents, different container objects.
        assert first_view == second_view
        assert first_view is not second_view
|
|
|
|
|
|
class TestToolHooksIntegration:
    """Scenarios combining several hooks, registry management, and a LiteAgent run."""

    def test_multiple_before_hooks_execute_in_order(self, mock_tool):
        """Hooks returned by the getter run in the order they were registered."""
        execution_order = []

        def hook1(context):
            execution_order.append(1)
            return None

        def hook2(context):
            execution_order.append(2)
            return None

        def hook3(context):
            execution_order.append(3)
            return None

        register_before_tool_call_hook(hook1)
        register_before_tool_call_hook(hook2)
        register_before_tool_call_hook(hook3)

        context = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={},
            tool=mock_tool,
        )

        for hook in get_before_tool_call_hooks():
            hook(context)

        assert execution_order == [1, 2, 3]

    def test_first_blocking_hook_stops_execution(self, mock_tool):
        """A loop that stops at the first False never reaches later hooks."""
        execution_order = []

        def hook1(context):
            execution_order.append(1)
            return None  # Allow

        def hook2(context):
            execution_order.append(2)
            return False  # Block

        def hook3(context):
            execution_order.append(3)
            return None  # This shouldn't run

        register_before_tool_call_hook(hook1)
        register_before_tool_call_hook(hook2)
        register_before_tool_call_hook(hook3)

        context = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={},
            tool=mock_tool,
        )

        # The test drives the hooks itself and stops on the first veto.
        blocked = False
        for hook in get_before_tool_call_hooks():
            if hook(context) is False:
                blocked = True
                break

        assert blocked is True
        assert execution_order == [1, 2]  # hook3 didn't run

    def test_multiple_after_hooks_chain_modifications(self, mock_tool):
        """Each after hook sees the previous hook's output via context.tool_result."""

        def hook1(context):
            if context.tool_result:
                return context.tool_result + " [hook1]"
            return None

        def hook2(context):
            if context.tool_result:
                return context.tool_result + " [hook2]"
            return None

        register_after_tool_call_hook(hook1)
        register_after_tool_call_hook(hook2)

        context = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={},
            tool=mock_tool,
            tool_result="Original",
        )

        # Simulate chaining: feed the running result back into the context
        # before each hook, and keep a hook's output only when it is not None.
        result = context.tool_result
        for hook in get_after_tool_call_hooks():
            context.tool_result = result
            modified = hook(context)
            if modified is not None:
                result = modified

        assert result == "Original [hook1] [hook2]"

    def test_hooks_with_validation_and_sanitization(self, mock_tool):
        """A before hook vetoes a sensitive path; an after hook redacts a secret."""

        # Validation hook (before)
        def validate_file_path(context):
            if context.tool_name == "write_file":
                file_path = context.tool_input.get("file_path", "")
                if ".env" in file_path:
                    return False  # Block sensitive files
            return None

        # Sanitization hook (after)
        def sanitize_secrets(context):
            if context.tool_result and "SECRET_KEY" in context.tool_result:
                return context.tool_result.replace("SECRET_KEY=abc123", "SECRET_KEY=[REDACTED]")
            return None

        register_before_tool_call_hook(validate_file_path)
        register_after_tool_call_hook(sanitize_secrets)

        # Test blocking
        blocked_context = ToolCallHookContext(
            tool_name="write_file",
            tool_input={"file_path": ".env"},
            tool=mock_tool,
        )

        blocked = False
        for hook in get_before_tool_call_hooks():
            if hook(blocked_context) is False:
                blocked = True
                break

        assert blocked is True

        # Test sanitization
        sanitize_context = ToolCallHookContext(
            tool_name="read_file",
            tool_input={"file_path": "config.txt"},
            tool=mock_tool,
            tool_result="Content: SECRET_KEY=abc123",
        )

        result = sanitize_context.tool_result
        for hook in get_after_tool_call_hooks():
            sanitize_context.tool_result = result
            modified = hook(sanitize_context)
            if modified is not None:
                result = modified

        assert "SECRET_KEY=[REDACTED]" in result
        assert "abc123" not in result

    def test_unregister_before_hook(self):
        """A before hook that is registered then unregistered is no longer listed."""

        def test_hook(context):
            pass

        register_before_tool_call_hook(test_hook)
        unregister_before_tool_call_hook(test_hook)

        assert len(get_before_tool_call_hooks()) == 0

    def test_unregister_after_hook(self):
        """An after hook that is registered then unregistered is no longer listed."""

        def test_hook(context):
            return None

        register_after_tool_call_hook(test_hook)
        unregister_after_tool_call_hook(test_hook)

        assert len(get_after_tool_call_hooks()) == 0

    def test_clear_all_tool_call_hooks(self):
        """clear_all_tool_call_hooks empties both the before and after registries."""

        def test_hook(context):
            pass

        register_before_tool_call_hook(test_hook)
        register_after_tool_call_hook(test_hook)

        clear_all_tool_call_hooks()

        # Both registries were populated above, so both must be checked;
        # asserting only one would let a half-working clear pass.
        assert len(get_before_tool_call_hooks()) == 0
        assert len(get_after_tool_call_hooks()) == 0

    @pytest.mark.vcr()
    def test_lite_agent_hooks_integration_with_real_tool(self):
        """Test that LiteAgent executes before/after tool call hooks with real tool calls."""
        import os

        from crewai.lite_agent import LiteAgent
        from crewai.tools import tool

        # Skip if no API key available
        if not os.environ.get("OPENAI_API_KEY"):
            pytest.skip("OPENAI_API_KEY not set - skipping real tool test")

        # Track hook invocations
        hook_calls = {"before": [], "after": []}

        # Create a simple test tool
        @tool("calculate_sum")
        def calculate_sum(a: int, b: int) -> int:
            """Add two numbers together."""
            return a + b

        def before_tool_call_hook(context: ToolCallHookContext) -> bool:
            """Log and verify before hook execution."""
            print(f"\n[BEFORE HOOK] Tool: {context.tool_name}")
            print(f"[BEFORE HOOK] Tool input: {context.tool_input}")
            print(f"[BEFORE HOOK] Agent: {context.agent.role if context.agent else 'None'}")
            print(f"[BEFORE HOOK] Task: {context.task}")
            print(f"[BEFORE HOOK] Crew: {context.crew}")

            # Track the call. The input is snapshotted so that the record
            # reflects what the hook saw even if the dict is mutated later.
            hook_calls["before"].append({
                "tool_name": context.tool_name,
                "tool_input": dict(context.tool_input),
                "has_agent": context.agent is not None,
                "has_task": context.task is not None,
                "has_crew": context.crew is not None,
            })

            return True  # Allow execution

        def after_tool_call_hook(context: ToolCallHookContext) -> str | None:
            """Log and verify after hook execution."""
            print(f"\n[AFTER HOOK] Tool: {context.tool_name}")
            print(f"[AFTER HOOK] Tool result: {context.tool_result}")
            print(f"[AFTER HOOK] Agent: {context.agent.role if context.agent else 'None'}")

            # Track the call
            hook_calls["after"].append({
                "tool_name": context.tool_name,
                "tool_result": context.tool_result,
                "has_result": context.tool_result is not None,
            })

            return None  # Don't modify result

        # Register hooks
        register_before_tool_call_hook(before_tool_call_hook)
        register_after_tool_call_hook(after_tool_call_hook)

        try:
            # Create LiteAgent with the tool
            lite_agent = LiteAgent(
                role="Calculator Assistant",
                goal="Help with math calculations",
                backstory="You are a helpful calculator assistant",
                tools=[calculate_sum],
                verbose=True,
            )

            # Execute with a prompt that should trigger tool usage. The return
            # value is not inspected; only the hook records are asserted on.
            lite_agent.kickoff("What is 5 + 3? Use the calculate_sum tool.")

            # Verify hooks were called
            assert len(hook_calls["before"]) > 0, "Before hook was never called"
            assert len(hook_calls["after"]) > 0, "After hook was never called"

            # Verify context had correct attributes for LiteAgent (used in flows)
            # LiteAgent doesn't have task/crew context, unlike agents in CrewBase
            before_call = hook_calls["before"][0]
            assert before_call["tool_name"] == "calculate_sum", "Tool name should be 'calculate_sum'"
            assert "a" in before_call["tool_input"], "Tool input should have 'a' parameter"
            assert "b" in before_call["tool_input"], "Tool input should have 'b' parameter"

            # Verify after hook received result
            after_call = hook_calls["after"][0]
            assert after_call["has_result"] is True, "After hook should have tool result"
            assert after_call["tool_name"] == "calculate_sum", "Tool name should match"
            # The result should contain the sum (8)
            assert "8" in str(after_call["tool_result"]), "Tool result should contain the sum"

        finally:
            # Clean up hooks
            unregister_before_tool_call_hook(before_tool_call_hook)
            unregister_after_tool_call_hook(after_tool_call_hook)
|
|
|
|
|
|
class TestNativeToolCallingHooksIntegration:
    """Integration tests for hooks with native function calling (Agent and Crew)."""

    @pytest.mark.vcr()
    def test_agent_native_tool_hooks_before_and_after(self):
        """Test that Agent with native tool calling executes before/after hooks."""
        from crewai import Agent
        from crewai.tools import tool

        hook_calls = {"before": [], "after": []}

        @tool("multiply_numbers")
        def multiply_numbers(a: int, b: int) -> int:
            """Multiply two numbers together."""
            return a * b

        def before_hook(context: ToolCallHookContext) -> bool | None:
            """Record the call details and allow execution."""
            hook_calls["before"].append({
                "tool_name": context.tool_name,
                "tool_input": dict(context.tool_input),
                "has_agent": context.agent is not None,
            })
            return None

        def after_hook(context: ToolCallHookContext) -> str | None:
            """Record the tool result without modifying it."""
            hook_calls["after"].append({
                "tool_name": context.tool_name,
                "tool_result": context.tool_result,
                "has_agent": context.agent is not None,
            })
            return None

        register_before_tool_call_hook(before_hook)
        register_after_tool_call_hook(after_hook)

        try:
            agent = Agent(
                role="Calculator",
                goal="Perform calculations",
                backstory="You are a calculator assistant",
                tools=[multiply_numbers],
                verbose=True,
            )

            agent.kickoff(
                messages="What is 7 times 6? Use the multiply_numbers tool."
            )

            # Verify before hook was called
            assert len(hook_calls["before"]) > 0, "Before hook was never called"
            before_call = hook_calls["before"][0]
            assert before_call["tool_name"] == "multiply_numbers"
            assert "a" in before_call["tool_input"]
            assert "b" in before_call["tool_input"]
            assert before_call["has_agent"] is True

            # Verify after hook was called
            assert len(hook_calls["after"]) > 0, "After hook was never called"
            after_call = hook_calls["after"][0]
            assert after_call["tool_name"] == "multiply_numbers"
            assert "42" in str(after_call["tool_result"])
            assert after_call["has_agent"] is True

        finally:
            unregister_before_tool_call_hook(before_hook)
            unregister_after_tool_call_hook(after_hook)

    @pytest.mark.vcr()
    def test_crew_native_tool_hooks_before_and_after(self):
        """Test that Crew with Agent executes before/after hooks with full context."""
        from crewai import Agent, Crew, Task
        from crewai.tools import tool

        hook_calls = {"before": [], "after": []}

        @tool("divide_numbers")
        def divide_numbers(a: int, b: int) -> float:
            """Divide first number by second number."""
            return a / b

        def before_hook(context: ToolCallHookContext) -> bool | None:
            """Record the call details, including agent/task/crew presence."""
            hook_calls["before"].append({
                "tool_name": context.tool_name,
                "tool_input": dict(context.tool_input),
                "has_agent": context.agent is not None,
                "has_task": context.task is not None,
                "has_crew": context.crew is not None,
                "agent_role": context.agent.role if context.agent else None,
            })
            return None

        def after_hook(context: ToolCallHookContext) -> str | None:
            """Record the tool result and agent/task/crew presence."""
            hook_calls["after"].append({
                "tool_name": context.tool_name,
                "tool_result": context.tool_result,
                "has_agent": context.agent is not None,
                "has_task": context.task is not None,
                "has_crew": context.crew is not None,
            })
            return None

        register_before_tool_call_hook(before_hook)
        register_after_tool_call_hook(after_hook)

        try:
            agent = Agent(
                role="Math Assistant",
                goal="Perform division calculations accurately",
                backstory="You are a math assistant that helps with division",
                tools=[divide_numbers],
                verbose=True,
            )

            task = Task(
                description="Calculate 100 divided by 4 using the divide_numbers tool.",
                expected_output="The result of the division",
                agent=agent,
            )

            crew = Crew(
                agents=[agent],
                tasks=[task],
                verbose=True,
            )

            crew.kickoff()

            # Verify before hook was called with full context
            assert len(hook_calls["before"]) > 0, "Before hook was never called"
            before_call = hook_calls["before"][0]
            assert before_call["tool_name"] == "divide_numbers"
            assert "a" in before_call["tool_input"]
            assert "b" in before_call["tool_input"]
            assert before_call["has_agent"] is True
            assert before_call["has_task"] is True
            assert before_call["has_crew"] is True
            assert before_call["agent_role"] == "Math Assistant"

            # Verify after hook was called with full context
            assert len(hook_calls["after"]) > 0, "After hook was never called"
            after_call = hook_calls["after"][0]
            assert after_call["tool_name"] == "divide_numbers"
            assert "25" in str(after_call["tool_result"])
            assert after_call["has_agent"] is True
            assert after_call["has_task"] is True
            assert after_call["has_crew"] is True

        finally:
            unregister_before_tool_call_hook(before_hook)
            unregister_after_tool_call_hook(after_hook)

    @pytest.mark.vcr()
    def test_before_hook_blocks_tool_execution_in_crew(self):
        """Test that returning False from before hook blocks tool execution."""
        from crewai import Agent, Crew, Task
        from crewai.tools import tool

        hook_calls = {"before": [], "after": [], "tool_executed": False}

        @tool("dangerous_operation")
        def dangerous_operation(action: str) -> str:
            """Perform a dangerous operation that should be blocked."""
            # Flipped only if the tool body actually runs.
            hook_calls["tool_executed"] = True
            return f"Executed: {action}"

        def blocking_before_hook(context: ToolCallHookContext) -> bool | None:
            """Record the call and veto any call to dangerous_operation."""
            hook_calls["before"].append({
                "tool_name": context.tool_name,
                "tool_input": dict(context.tool_input),
            })
            # Block all calls to dangerous_operation
            if context.tool_name == "dangerous_operation":
                return False
            return None

        def after_hook(context: ToolCallHookContext) -> str | None:
            """Record the result the after hook is given."""
            hook_calls["after"].append({
                "tool_name": context.tool_name,
                "tool_result": context.tool_result,
            })
            return None

        register_before_tool_call_hook(blocking_before_hook)
        register_after_tool_call_hook(after_hook)

        try:
            agent = Agent(
                role="Test Agent",
                goal="Try to use the dangerous operation tool",
                backstory="You are a test agent",
                tools=[dangerous_operation],
                verbose=True,
            )

            task = Task(
                description="Use the dangerous_operation tool with action 'delete_all'.",
                expected_output="The result of the operation",
                agent=agent,
            )

            crew = Crew(
                agents=[agent],
                tasks=[task],
                verbose=True,
            )

            crew.kickoff()

            # Verify before hook was called
            assert len(hook_calls["before"]) > 0, "Before hook was never called"
            before_call = hook_calls["before"][0]
            assert before_call["tool_name"] == "dangerous_operation"

            # Verify the actual tool function was NOT executed
            assert hook_calls["tool_executed"] is False, "Tool should have been blocked"

            # Verify after hook was still called (with blocked message)
            assert len(hook_calls["after"]) > 0, "After hook was never called"
            after_call = hook_calls["after"][0]
            # str() keeps a None or non-string result an assertion failure
            # rather than an AttributeError on .lower().
            assert "blocked" in str(after_call["tool_result"]).lower(), (
                "After hook should receive a blocked message"
            )

        finally:
            unregister_before_tool_call_hook(blocking_before_hook)
            unregister_after_tool_call_hook(after_hook)
|