from __future__ import annotations from unittest.mock import Mock from crewai.hooks import clear_all_tool_call_hooks, unregister_after_tool_call_hook, unregister_before_tool_call_hook import pytest from crewai.hooks.tool_hooks import ( ToolCallHookContext, get_after_tool_call_hooks, get_before_tool_call_hooks, register_after_tool_call_hook, register_before_tool_call_hook, ) @pytest.fixture def mock_tool(): """Create a mock tool for testing.""" tool = Mock() tool.name = "test_tool" tool.description = "Test tool description" return tool @pytest.fixture def mock_agent(): """Create a mock agent for testing.""" agent = Mock() agent.role = "Test Agent" return agent @pytest.fixture def mock_task(): """Create a mock task for testing.""" task = Mock() task.description = "Test task" return task @pytest.fixture def mock_crew(): """Create a mock crew for testing.""" crew = Mock() return crew @pytest.fixture(autouse=True) def clear_hooks(): """Clear global hooks before and after each test.""" from crewai.hooks import tool_hooks # Store original hooks original_before = tool_hooks._before_tool_call_hooks.copy() original_after = tool_hooks._after_tool_call_hooks.copy() # Clear hooks tool_hooks._before_tool_call_hooks.clear() tool_hooks._after_tool_call_hooks.clear() yield # Restore original hooks tool_hooks._before_tool_call_hooks.clear() tool_hooks._after_tool_call_hooks.clear() tool_hooks._before_tool_call_hooks.extend(original_before) tool_hooks._after_tool_call_hooks.extend(original_after) class TestToolCallHookContext: """Test ToolCallHookContext initialization and attributes.""" def test_context_initialization(self, mock_tool, mock_agent, mock_task, mock_crew): """Test that context is initialized correctly.""" tool_input = {"arg1": "value1", "arg2": "value2"} context = ToolCallHookContext( tool_name="test_tool", tool_input=tool_input, tool=mock_tool, agent=mock_agent, task=mock_task, crew=mock_crew, ) assert context.tool_name == "test_tool" assert context.tool_input == tool_input assert context.tool == mock_tool assert context.agent == mock_agent assert context.task == mock_task assert context.crew == mock_crew assert context.tool_result is None def test_context_with_result(self, mock_tool): """Test that context includes result when provided.""" tool_input = {"arg1": "value1"} tool_result = "Test tool result" context = ToolCallHookContext( tool_name="test_tool", tool_input=tool_input, tool=mock_tool, tool_result=tool_result, ) assert context.tool_result == tool_result def test_tool_input_is_mutable_reference(self, mock_tool): """Test that modifying context.tool_input modifies the original dict.""" tool_input = {"arg1": "value1"} context = ToolCallHookContext( tool_name="test_tool", tool_input=tool_input, tool=mock_tool, ) # Modify through context context.tool_input["arg2"] = "value2" # Check that original dict is also modified assert "arg2" in tool_input assert tool_input["arg2"] == "value2" class TestBeforeToolCallHooks: """Test before_tool_call hook registration and execution.""" def test_register_before_hook(self): """Test that before hooks are registered correctly.""" def test_hook(context): return None register_before_tool_call_hook(test_hook) hooks = get_before_tool_call_hooks() assert len(hooks) == 1 assert hooks[0] == test_hook def test_multiple_before_hooks(self): """Test that multiple before hooks can be registered.""" def hook1(context): return None def hook2(context): return None register_before_tool_call_hook(hook1) register_before_tool_call_hook(hook2) hooks = get_before_tool_call_hooks() assert len(hooks) == 2 assert hook1 in hooks assert hook2 in hooks def test_before_hook_can_block_execution(self, mock_tool): """Test that before hooks can block tool execution.""" def block_hook(context): if context.tool_name == "dangerous_tool": return False # Block execution return None # Allow execution tool_input = {} context = ToolCallHookContext( tool_name="dangerous_tool", tool_input=tool_input, tool=mock_tool, ) result = block_hook(context) assert result is False def test_before_hook_can_allow_execution(self, mock_tool): """Test that before hooks can explicitly allow execution.""" def allow_hook(context): return None # Allow execution tool_input = {} context = ToolCallHookContext( tool_name="safe_tool", tool_input=tool_input, tool=mock_tool, ) result = allow_hook(context) assert result is None def test_before_hook_can_modify_input(self, mock_tool): """Test that before hooks can modify tool input in-place.""" def modify_input_hook(context): context.tool_input["modified_by_hook"] = True return None tool_input = {"arg1": "value1"} context = ToolCallHookContext( tool_name="test_tool", tool_input=tool_input, tool=mock_tool, ) modify_input_hook(context) assert "modified_by_hook" in context.tool_input assert context.tool_input["modified_by_hook"] is True def test_get_before_hooks_returns_copy(self): """Test that get_before_tool_call_hooks returns a copy.""" def test_hook(context): return None register_before_tool_call_hook(test_hook) hooks1 = get_before_tool_call_hooks() hooks2 = get_before_tool_call_hooks() # They should be equal but not the same object assert hooks1 == hooks2 assert hooks1 is not hooks2 class TestAfterToolCallHooks: """Test after_tool_call hook registration and execution.""" def test_register_after_hook(self): """Test that after hooks are registered correctly.""" def test_hook(context): return None register_after_tool_call_hook(test_hook) hooks = get_after_tool_call_hooks() assert len(hooks) == 1 assert hooks[0] == test_hook def test_multiple_after_hooks(self): """Test that multiple after hooks can be registered.""" def hook1(context): return None def hook2(context): return None register_after_tool_call_hook(hook1) register_after_tool_call_hook(hook2) hooks = get_after_tool_call_hooks() assert len(hooks) == 2 assert hook1 in hooks assert hook2 in hooks def test_after_hook_can_modify_result(self, mock_tool): """Test that after hooks can modify the tool result.""" original_result = "Original result" def modify_result_hook(context): if context.tool_result: return context.tool_result.replace("Original", "Modified") return None tool_input = {} context = ToolCallHookContext( tool_name="test_tool", tool_input=tool_input, tool=mock_tool, tool_result=original_result, ) modified = modify_result_hook(context) assert modified == "Modified result" def test_after_hook_returns_none_keeps_original(self, mock_tool): """Test that returning None keeps the original result.""" original_result = "Original result" def no_change_hook(context): return None tool_input = {} context = ToolCallHookContext( tool_name="test_tool", tool_input=tool_input, tool=mock_tool, tool_result=original_result, ) result = no_change_hook(context) assert result is None assert context.tool_result == original_result def test_get_after_hooks_returns_copy(self): """Test that get_after_tool_call_hooks returns a copy.""" def test_hook(context): return None register_after_tool_call_hook(test_hook) hooks1 = get_after_tool_call_hooks() hooks2 = get_after_tool_call_hooks() # They should be equal but not the same object assert hooks1 == hooks2 assert hooks1 is not hooks2 class TestToolHooksIntegration: """Test integration scenarios with multiple hooks.""" def test_multiple_before_hooks_execute_in_order(self, mock_tool): """Test that multiple before hooks execute in registration order.""" execution_order = [] def hook1(context): execution_order.append(1) return None def hook2(context): execution_order.append(2) return None def hook3(context): execution_order.append(3) return None register_before_tool_call_hook(hook1) register_before_tool_call_hook(hook2) register_before_tool_call_hook(hook3) tool_input = {} context = ToolCallHookContext( tool_name="test_tool", tool_input=tool_input, tool=mock_tool, ) hooks = get_before_tool_call_hooks() for hook in hooks: hook(context) assert execution_order == [1, 2, 3] def test_first_blocking_hook_stops_execution(self, mock_tool): """Test that first hook returning False blocks execution.""" execution_order = [] def hook1(context): execution_order.append(1) return None # Allow def hook2(context): execution_order.append(2) return False # Block def hook3(context): execution_order.append(3) return None # This shouldn't run register_before_tool_call_hook(hook1) register_before_tool_call_hook(hook2) register_before_tool_call_hook(hook3) tool_input = {} context = ToolCallHookContext( tool_name="test_tool", tool_input=tool_input, tool=mock_tool, ) hooks = get_before_tool_call_hooks() blocked = False for hook in hooks: result = hook(context) if result is False: blocked = True break assert blocked is True assert execution_order == [1, 2] # hook3 didn't run def test_multiple_after_hooks_chain_modifications(self, mock_tool): """Test that multiple after hooks can chain modifications.""" def hook1(context): if context.tool_result: return context.tool_result + " [hook1]" return None def hook2(context): if context.tool_result: return context.tool_result + " [hook2]" return None register_after_tool_call_hook(hook1) register_after_tool_call_hook(hook2) tool_input = {} context = ToolCallHookContext( tool_name="test_tool", tool_input=tool_input, tool=mock_tool, tool_result="Original", ) hooks = get_after_tool_call_hooks() # Simulate chaining (how it would be used in practice) result = context.tool_result for hook in hooks: # Update context for next hook context.tool_result = result modified = hook(context) if modified is not None: result = modified assert result == "Original [hook1] [hook2]" def test_hooks_with_validation_and_sanitization(self, mock_tool): """Test a realistic scenario with validation and sanitization hooks.""" # Validation hook (before) def validate_file_path(context): if context.tool_name == "write_file": file_path = context.tool_input.get("file_path", "") if ".env" in file_path: return False # Block sensitive files return None # Sanitization hook (after) def sanitize_secrets(context): if context.tool_result and "SECRET_KEY" in context.tool_result: return context.tool_result.replace("SECRET_KEY=abc123", "SECRET_KEY=[REDACTED]") return None register_before_tool_call_hook(validate_file_path) register_after_tool_call_hook(sanitize_secrets) # Test blocking blocked_context = ToolCallHookContext( tool_name="write_file", tool_input={"file_path": ".env"}, tool=mock_tool, ) before_hooks = get_before_tool_call_hooks() blocked = False for hook in before_hooks: if hook(blocked_context) is False: blocked = True break assert blocked is True # Test sanitization sanitize_context = ToolCallHookContext( tool_name="read_file", tool_input={"file_path": "config.txt"}, tool=mock_tool, tool_result="Content: SECRET_KEY=abc123", ) after_hooks = get_after_tool_call_hooks() result = sanitize_context.tool_result for hook in after_hooks: sanitize_context.tool_result = result modified = hook(sanitize_context) if modified is not None: result = modified assert "SECRET_KEY=[REDACTED]" in result assert "abc123" not in result def test_unregister_before_hook(self): """Test that before hooks can be unregistered.""" def test_hook(context): pass register_before_tool_call_hook(test_hook) unregister_before_tool_call_hook(test_hook) hooks = get_before_tool_call_hooks() assert len(hooks) == 0 def test_unregister_after_hook(self): """Test that after hooks can be unregistered.""" def test_hook(context): return None register_after_tool_call_hook(test_hook) unregister_after_tool_call_hook(test_hook) hooks = get_after_tool_call_hooks() assert len(hooks) == 0 def test_clear_all_tool_call_hooks(self): """Test that all tool call hooks can be cleared.""" def test_hook(context): pass register_before_tool_call_hook(test_hook) register_after_tool_call_hook(test_hook) clear_all_tool_call_hooks() hooks = get_before_tool_call_hooks() assert len(hooks) == 0