mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
Lorenze/feat hooks (#3902)
* feat: implement LLM call hooks and enhance agent execution context - Introduced LLM call hooks to allow modification of messages and responses during LLM interactions. - Added support for before and after hooks in the CrewAgentExecutor, enabling dynamic adjustments to the execution flow. - Created LLMCallHookContext for comprehensive access to the executor state, facilitating in-place modifications. - Added validation for hook callables to ensure proper functionality. - Enhanced tests for LLM hooks and tool hooks to verify their behavior and error handling capabilities. - Updated LiteAgent and CrewAgentExecutor to accommodate the new crew context in their execution processes. * feat: implement LLM call hooks and enhance agent execution context - Introduced LLM call hooks to allow modification of messages and responses during LLM interactions. - Added support for before and after hooks in the CrewAgentExecutor, enabling dynamic adjustments to the execution flow. - Created LLMCallHookContext for comprehensive access to the executor state, facilitating in-place modifications. - Added validation for hook callables to ensure proper functionality. - Enhanced tests for LLM hooks and tool hooks to verify their behavior and error handling capabilities. - Updated LiteAgent and CrewAgentExecutor to accommodate the new crew context in their execution processes. * fix verbose * feat: introduce crew-scoped hook decorators and refactor hook registration - Added decorators for before and after LLM and tool calls to enhance flexibility in modifying execution behavior. - Implemented a centralized hook registration mechanism within CrewBase to automatically register crew-scoped hooks. - Removed the obsolete base.py file as its functionality has been integrated into the new decorators and registration system. - Enhanced tests for the new hook decorators to ensure proper registration and execution flow. 
- Updated existing hook handling to accommodate the new decorator-based approach, improving code organization and maintainability. * feat: enhance hook management with clear and unregister functions - Introduced functions to unregister specific before and after hooks for both LLM and tool calls, improving flexibility in hook management. - Added clear functions to remove all registered hooks of each type, facilitating easier state management and cleanup. - Implemented a convenience function to clear all global hooks in one call, streamlining the process for testing and execution context resets. - Enhanced tests to verify the functionality of unregistering and clearing hooks, ensuring robust behavior in various scenarios. * refactor: enhance hook type management for LLM and tool hooks - Updated hook type definitions to use generic protocols for better type safety and flexibility. - Replaced Callable type annotations with specific BeforeLLMCallHookType and AfterLLMCallHookType for clarity. - Improved the registration and retrieval functions for before and after hooks to align with the new type definitions. - Enhanced the setup functions to handle hook execution results, allowing for blocking of LLM calls based on hook logic. - Updated related tests to ensure proper functionality and type adherence across the hook management system. * feat: add execution and tool hooks documentation - Introduced new documentation for execution hooks, LLM call hooks, and tool call hooks to provide comprehensive guidance on their usage and implementation in CrewAI. - Updated existing documentation to include references to the new hooks, enhancing the learning resources available for users. - Ensured consistency across multiple languages (English, Portuguese, Korean) for the new documentation, improving accessibility for a wider audience. - Added examples and troubleshooting sections to assist users in effectively utilizing hooks for agent operations. 
--------- Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
This commit is contained in:
498
lib/crewai/tests/hooks/test_tool_hooks.py
Normal file
498
lib/crewai/tests/hooks/test_tool_hooks.py
Normal file
@@ -0,0 +1,498 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from crewai.hooks import clear_all_tool_call_hooks, unregister_after_tool_call_hook, unregister_before_tool_call_hook
|
||||
import pytest
|
||||
|
||||
from crewai.hooks.tool_hooks import (
|
||||
ToolCallHookContext,
|
||||
get_after_tool_call_hooks,
|
||||
get_before_tool_call_hooks,
|
||||
register_after_tool_call_hook,
|
||||
register_before_tool_call_hook,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_tool():
    """Provide a stand-in tool object carrying a name and a description."""
    fake_tool = Mock()
    # ``name`` is reserved by the Mock constructor, so both attributes are
    # applied after construction via configure_mock.
    fake_tool.configure_mock(
        name="test_tool",
        description="Test tool description",
    )
    return fake_tool
|
||||
|
||||
|
||||
@pytest.fixture
def mock_agent():
    """Provide a stand-in agent object whose ``role`` is "Test Agent"."""
    return Mock(role="Test Agent")
|
||||
|
||||
|
||||
@pytest.fixture
def mock_task():
    """Provide a stand-in task object whose ``description`` is "Test task"."""
    return Mock(description="Test task")
|
||||
|
||||
|
||||
@pytest.fixture
def mock_crew():
    """Provide a bare stand-in crew object with no preset attributes."""
    return Mock()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clear_hooks():
    """Run each test with empty global tool-hook registries.

    The hooks present before the test are snapshotted, both registries are
    emptied for the duration of the test, and the snapshot is put back
    afterwards so hooks registered by a test do not outlive it.
    """
    from crewai.hooks import tool_hooks

    registry_names = ("_before_tool_call_hooks", "_after_tool_call_hooks")

    # Snapshot the current contents, then empty each registry in place.
    saved = {name: getattr(tool_hooks, name).copy() for name in registry_names}
    for name in registry_names:
        getattr(tool_hooks, name).clear()

    yield

    # Look the registries up again on the module before restoring, then
    # replace whatever the test left behind with the snapshot.
    for name, previous_hooks in saved.items():
        registry = getattr(tool_hooks, name)
        registry.clear()
        registry.extend(previous_hooks)
|
||||
|
||||
|
||||
class TestToolCallHookContext:
    """Checks on the attributes exposed by a freshly built ToolCallHookContext."""

    def test_context_initialization(self, mock_tool, mock_agent, mock_task, mock_crew):
        """Each constructor argument is readable back under the same name."""
        init_kwargs = {
            "tool_name": "test_tool",
            "tool_input": {"arg1": "value1", "arg2": "value2"},
            "tool": mock_tool,
            "agent": mock_agent,
            "task": mock_task,
            "crew": mock_crew,
        }

        ctx = ToolCallHookContext(**init_kwargs)

        for attr_name, expected in init_kwargs.items():
            assert getattr(ctx, attr_name) == expected
        # No result was supplied, so the context should report none.
        assert ctx.tool_result is None

    def test_context_with_result(self, mock_tool):
        """A ``tool_result`` passed at construction is stored on the context."""
        expected_result = "Test tool result"

        ctx = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={"arg1": "value1"},
            tool=mock_tool,
            tool_result=expected_result,
        )

        assert ctx.tool_result == expected_result

    def test_tool_input_is_mutable_reference(self, mock_tool):
        """Writing through ``context.tool_input`` is visible in the caller's dict."""
        caller_dict = {"arg1": "value1"}
        ctx = ToolCallHookContext(
            tool_name="test_tool",
            tool_input=caller_dict,
            tool=mock_tool,
        )

        # Add a key via the context only ...
        ctx.tool_input["arg2"] = "value2"

        # ... and expect the dict that was handed in to have gained it.
        assert caller_dict.get("arg2") == "value2"
|
||||
|
||||
|
||||
class TestBeforeToolCallHooks:
    """Registration of before-tool-call hooks and the return values hooks use."""

    def test_register_before_hook(self):
        """A single registered hook is the only entry returned by the getter."""

        def noop_hook(context):
            return None

        register_before_tool_call_hook(noop_hook)
        registered = get_before_tool_call_hooks()

        assert len(registered) == 1
        assert registered[0] == noop_hook

    def test_multiple_before_hooks(self):
        """Two registered hooks are both present in the getter's result."""

        def first_hook(context):
            return None

        def second_hook(context):
            return None

        candidates = (first_hook, second_hook)
        for candidate in candidates:
            register_before_tool_call_hook(candidate)
        registered = get_before_tool_call_hooks()

        assert len(registered) == 2
        assert all(candidate in registered for candidate in candidates)

    def test_before_hook_can_block_execution(self, mock_tool):
        """A hook signals "block" by returning False for the targeted tool name."""

        def block_hook(context):
            # False means block; None means let the call proceed.
            return False if context.tool_name == "dangerous_tool" else None

        ctx = ToolCallHookContext(
            tool_name="dangerous_tool",
            tool_input={},
            tool=mock_tool,
        )

        assert block_hook(ctx) is False

    def test_before_hook_can_allow_execution(self, mock_tool):
        """A hook signals "allow" by returning None."""

        def allow_hook(context):
            return None

        ctx = ToolCallHookContext(
            tool_name="safe_tool",
            tool_input={},
            tool=mock_tool,
        )

        assert allow_hook(ctx) is None

    def test_before_hook_can_modify_input(self, mock_tool):
        """A hook can add keys to ``context.tool_input`` in place."""

        def modify_input_hook(context):
            context.tool_input["modified_by_hook"] = True
            return None

        ctx = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={"arg1": "value1"},
            tool=mock_tool,
        )

        modify_input_hook(ctx)

        assert ctx.tool_input.get("modified_by_hook") is True

    def test_get_before_hooks_returns_copy(self):
        """Two getter calls return equal contents in distinct objects."""

        def noop_hook(context):
            return None

        register_before_tool_call_hook(noop_hook)
        first_snapshot = get_before_tool_call_hooks()
        second_snapshot = get_before_tool_call_hooks()

        # Same contents, different container objects.
        assert first_snapshot is not second_snapshot
        assert first_snapshot == second_snapshot
|
||||
|
||||
|
||||
class TestAfterToolCallHooks:
    """Registration of after-tool-call hooks and the return values hooks use."""

    def test_register_after_hook(self):
        """A single registered hook is the only entry returned by the getter."""

        def noop_hook(context):
            return None

        register_after_tool_call_hook(noop_hook)
        registered = get_after_tool_call_hooks()

        assert len(registered) == 1
        assert registered[0] == noop_hook

    def test_multiple_after_hooks(self):
        """Two registered hooks are both present in the getter's result."""

        def first_hook(context):
            return None

        def second_hook(context):
            return None

        candidates = (first_hook, second_hook)
        for candidate in candidates:
            register_after_tool_call_hook(candidate)
        registered = get_after_tool_call_hooks()

        assert len(registered) == 2
        assert all(candidate in registered for candidate in candidates)

    def test_after_hook_can_modify_result(self, mock_tool):
        """A hook can hand back a rewritten version of ``context.tool_result``."""

        def modify_result_hook(context):
            if not context.tool_result:
                return None
            return context.tool_result.replace("Original", "Modified")

        ctx = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={},
            tool=mock_tool,
            tool_result="Original result",
        )

        assert modify_result_hook(ctx) == "Modified result"

    def test_after_hook_returns_none_keeps_original(self, mock_tool):
        """A hook returning None leaves ``context.tool_result`` untouched."""
        untouched = "Original result"

        def no_change_hook(context):
            return None

        ctx = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={},
            tool=mock_tool,
            tool_result=untouched,
        )

        outcome = no_change_hook(ctx)

        assert outcome is None
        assert ctx.tool_result == untouched

    def test_get_after_hooks_returns_copy(self):
        """Two getter calls return equal contents in distinct objects."""

        def noop_hook(context):
            return None

        register_after_tool_call_hook(noop_hook)
        first_snapshot = get_after_tool_call_hooks()
        second_snapshot = get_after_tool_call_hooks()

        # Same contents, different container objects.
        assert first_snapshot is not second_snapshot
        assert first_snapshot == second_snapshot
|
||||
|
||||
|
||||
class TestToolHooksIntegration:
    """Scenarios that combine several registered hooks.

    These tests drive the hooks with loops written in the test itself; they
    check the registry contents and the hook return-value conventions rather
    than any executor code.
    """

    def test_multiple_before_hooks_execute_in_order(self, mock_tool):
        """Iterating the registered before hooks runs them in registration order."""
        execution_order = []

        def make_hook(marker):
            # Each hook records its marker so the call order can be asserted.
            def hook(context):
                execution_order.append(marker)
                return None

            return hook

        for marker in (1, 2, 3):
            register_before_tool_call_hook(make_hook(marker))

        context = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={},
            tool=mock_tool,
        )

        for hook in get_before_tool_call_hooks():
            hook(context)

        assert execution_order == [1, 2, 3]

    def test_first_blocking_hook_stops_execution(self, mock_tool):
        """Stopping at the first hook that returns False skips the hooks after it."""
        execution_order = []

        def make_hook(marker, verdict):
            # ``verdict`` is what the hook returns: None allows, False blocks.
            def hook(context):
                execution_order.append(marker)
                return verdict

            return hook

        register_before_tool_call_hook(make_hook(1, None))   # allows
        register_before_tool_call_hook(make_hook(2, False))  # blocks
        register_before_tool_call_hook(make_hook(3, None))   # must not run

        context = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={},
            tool=mock_tool,
        )

        blocked = False
        for hook in get_before_tool_call_hooks():
            if hook(context) is False:
                blocked = True
                break

        assert blocked is True
        assert execution_order == [1, 2]  # the third hook never ran

    def test_multiple_after_hooks_chain_modifications(self, mock_tool):
        """Feeding each hook's output to the next accumulates every modification."""

        def hook1(context):
            if context.tool_result:
                return context.tool_result + " [hook1]"
            return None

        def hook2(context):
            if context.tool_result:
                return context.tool_result + " [hook2]"
            return None

        register_after_tool_call_hook(hook1)
        register_after_tool_call_hook(hook2)

        context = ToolCallHookContext(
            tool_name="test_tool",
            tool_input={},
            tool=mock_tool,
            tool_result="Original",
        )

        # Chain by hand: publish the running result on the context before each
        # hook, and adopt the hook's return value unless it is None.
        result = context.tool_result
        for hook in get_after_tool_call_hooks():
            context.tool_result = result
            modified = hook(context)
            if modified is not None:
                result = modified

        assert result == "Original [hook1] [hook2]"

    def test_hooks_with_validation_and_sanitization(self, mock_tool):
        """A blocking before hook and a redacting after hook used together."""

        def validate_file_path(context):
            # Before hook: refuse write_file calls whose path mentions ".env".
            if context.tool_name == "write_file":
                file_path = context.tool_input.get("file_path", "")
                if ".env" in file_path:
                    return False
            return None

        def sanitize_secrets(context):
            # After hook: replace the known secret value with a placeholder.
            if context.tool_result and "SECRET_KEY" in context.tool_result:
                return context.tool_result.replace("SECRET_KEY=abc123", "SECRET_KEY=[REDACTED]")
            return None

        register_before_tool_call_hook(validate_file_path)
        register_after_tool_call_hook(sanitize_secrets)

        # Blocking path: a write to ".env" must be rejected by a before hook.
        blocked_context = ToolCallHookContext(
            tool_name="write_file",
            tool_input={"file_path": ".env"},
            tool=mock_tool,
        )

        # any() stops at the first hook that returns False.
        blocked = any(
            hook(blocked_context) is False for hook in get_before_tool_call_hooks()
        )

        assert blocked is True

        # Sanitizing path: the secret in the result must be redacted.
        sanitize_context = ToolCallHookContext(
            tool_name="read_file",
            tool_input={"file_path": "config.txt"},
            tool=mock_tool,
            tool_result="Content: SECRET_KEY=abc123",
        )

        result = sanitize_context.tool_result
        for hook in get_after_tool_call_hooks():
            sanitize_context.tool_result = result
            modified = hook(sanitize_context)
            if modified is not None:
                result = modified

        assert "SECRET_KEY=[REDACTED]" in result
        assert "abc123" not in result

    def test_unregister_before_hook(self):
        """Unregistering the only before hook leaves the before registry empty."""

        def test_hook(context):
            pass

        register_before_tool_call_hook(test_hook)
        unregister_before_tool_call_hook(test_hook)

        assert len(get_before_tool_call_hooks()) == 0

    def test_unregister_after_hook(self):
        """Unregistering the only after hook leaves the after registry empty."""

        def test_hook(context):
            return None

        register_after_tool_call_hook(test_hook)
        unregister_after_tool_call_hook(test_hook)

        assert len(get_after_tool_call_hooks()) == 0

    def test_clear_all_tool_call_hooks(self):
        """Clearing all tool call hooks empties both the before and after registries."""

        def test_hook(context):
            pass

        register_before_tool_call_hook(test_hook)
        register_after_tool_call_hook(test_hook)
        clear_all_tool_call_hooks()

        assert len(get_before_tool_call_hooks()) == 0
        # The after registry was populated too, so it must be checked as well;
        # otherwise a clear that skipped after hooks would go unnoticed.
        assert len(get_after_tool_call_hooks()) == 0
|
||||
Reference in New Issue
Block a user