mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
Lorenze/feat hooks (#3902)
* feat: implement LLM call hooks and enhance agent execution context - Introduced LLM call hooks to allow modification of messages and responses during LLM interactions. - Added support for before and after hooks in the CrewAgentExecutor, enabling dynamic adjustments to the execution flow. - Created LLMCallHookContext for comprehensive access to the executor state, facilitating in-place modifications. - Added validation for hook callables to ensure proper functionality. - Enhanced tests for LLM hooks and tool hooks to verify their behavior and error handling capabilities. - Updated LiteAgent and CrewAgentExecutor to accommodate the new crew context in their execution processes. * feat: implement LLM call hooks and enhance agent execution context - Introduced LLM call hooks to allow modification of messages and responses during LLM interactions. - Added support for before and after hooks in the CrewAgentExecutor, enabling dynamic adjustments to the execution flow. - Created LLMCallHookContext for comprehensive access to the executor state, facilitating in-place modifications. - Added validation for hook callables to ensure proper functionality. - Enhanced tests for LLM hooks and tool hooks to verify their behavior and error handling capabilities. - Updated LiteAgent and CrewAgentExecutor to accommodate the new crew context in their execution processes. * fix verbose * feat: introduce crew-scoped hook decorators and refactor hook registration - Added decorators for before and after LLM and tool calls to enhance flexibility in modifying execution behavior. - Implemented a centralized hook registration mechanism within CrewBase to automatically register crew-scoped hooks. - Removed the obsolete base.py file as its functionality has been integrated into the new decorators and registration system. - Enhanced tests for the new hook decorators to ensure proper registration and execution flow. - Updated existing hook handling to accommodate the new decorator-based approach, improving code organization and maintainability. * feat: enhance hook management with clear and unregister functions - Introduced functions to unregister specific before and after hooks for both LLM and tool calls, improving flexibility in hook management. - Added clear functions to remove all registered hooks of each type, facilitating easier state management and cleanup. - Implemented a convenience function to clear all global hooks in one call, streamlining the process for testing and execution context resets. - Enhanced tests to verify the functionality of unregistering and clearing hooks, ensuring robust behavior in various scenarios. * refactor: enhance hook type management for LLM and tool hooks - Updated hook type definitions to use generic protocols for better type safety and flexibility. - Replaced Callable type annotations with specific BeforeLLMCallHookType and AfterLLMCallHookType for clarity. - Improved the registration and retrieval functions for before and after hooks to align with the new type definitions. - Enhanced the setup functions to handle hook execution results, allowing for blocking of LLM calls based on hook logic. - Updated related tests to ensure proper functionality and type adherence across the hook management system. * feat: add execution and tool hooks documentation - Introduced new documentation for execution hooks, LLM call hooks, and tool call hooks to provide comprehensive guidance on their usage and implementation in CrewAI. - Updated existing documentation to include references to the new hooks, enhancing the learning resources available for users. - Ensured consistency across multiple languages (English, Portuguese, Korean) for the new documentation, improving accessibility for a wider audience. - Added examples and troubleshooting sections to assist users in effectively utilizing hooks for agent operations. --------- Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
This commit is contained in:
2
lib/crewai/tests/hooks/__init__.py
Normal file
2
lib/crewai/tests/hooks/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Tests for CrewAI hooks functionality."""
|
||||
|
||||
619
lib/crewai/tests/hooks/test_crew_scoped_hooks.py
Normal file
619
lib/crewai/tests/hooks/test_crew_scoped_hooks.py
Normal file
@@ -0,0 +1,619 @@
|
||||
"""Tests for crew-scoped hooks within @CrewBase classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew
|
||||
from crewai.hooks import (
|
||||
LLMCallHookContext,
|
||||
ToolCallHookContext,
|
||||
before_llm_call,
|
||||
before_tool_call,
|
||||
get_before_llm_call_hooks,
|
||||
get_before_tool_call_hooks,
|
||||
)
|
||||
from crewai.project import CrewBase, agent, crew
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_hooks():
|
||||
"""Clear global hooks before and after each test."""
|
||||
from crewai.hooks import llm_hooks, tool_hooks
|
||||
|
||||
# Store original hooks
|
||||
original_before_llm = llm_hooks._before_llm_call_hooks.copy()
|
||||
original_before_tool = tool_hooks._before_tool_call_hooks.copy()
|
||||
|
||||
# Clear hooks
|
||||
llm_hooks._before_llm_call_hooks.clear()
|
||||
tool_hooks._before_tool_call_hooks.clear()
|
||||
|
||||
yield
|
||||
|
||||
# Restore original hooks
|
||||
llm_hooks._before_llm_call_hooks.clear()
|
||||
tool_hooks._before_tool_call_hooks.clear()
|
||||
llm_hooks._before_llm_call_hooks.extend(original_before_llm)
|
||||
tool_hooks._before_tool_call_hooks.extend(original_before_tool)
|
||||
|
||||
|
||||
class TestCrewScopedHooks:
|
||||
"""Test hooks defined as methods within @CrewBase classes."""
|
||||
|
||||
def test_crew_scoped_hook_is_registered_on_instance_creation(self):
|
||||
"""Test that crew-scoped hooks are registered when crew instance is created."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_llm_call
|
||||
def my_hook(self, context):
|
||||
pass
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Check hooks before instance creation
|
||||
hooks_before = get_before_llm_call_hooks()
|
||||
initial_count = len(hooks_before)
|
||||
|
||||
# Create instance - should register the hook
|
||||
crew_instance = TestCrew()
|
||||
|
||||
# Check hooks after instance creation
|
||||
hooks_after = get_before_llm_call_hooks()
|
||||
|
||||
# Should have one more hook registered
|
||||
assert len(hooks_after) == initial_count + 1
|
||||
|
||||
def test_crew_scoped_hook_has_access_to_self(self):
|
||||
"""Test that crew-scoped hooks can access self and instance variables."""
|
||||
execution_log = []
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
def __init__(self):
|
||||
self.crew_name = "TestCrew"
|
||||
self.call_count = 0
|
||||
|
||||
@before_llm_call
|
||||
def my_hook(self, context):
|
||||
# Can access self
|
||||
self.call_count += 1
|
||||
execution_log.append(f"{self.crew_name}:{self.call_count}")
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
|
||||
# Get the registered hook
|
||||
hooks = get_before_llm_call_hooks()
|
||||
crew_hook = hooks[-1] # Last registered hook
|
||||
|
||||
# Create mock context
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Test")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Execute hook multiple times
|
||||
crew_hook(context)
|
||||
crew_hook(context)
|
||||
|
||||
# Verify hook accessed self and modified instance state
|
||||
assert len(execution_log) == 2
|
||||
assert execution_log[0] == "TestCrew:1"
|
||||
assert execution_log[1] == "TestCrew:2"
|
||||
assert crew_instance.call_count == 2
|
||||
|
||||
def test_multiple_crews_have_isolated_hooks(self):
|
||||
"""Test that different crew instances have isolated hooks."""
|
||||
crew1_executions = []
|
||||
crew2_executions = []
|
||||
|
||||
@CrewBase
|
||||
class Crew1:
|
||||
@before_llm_call
|
||||
def crew1_hook(self, context):
|
||||
crew1_executions.append("crew1")
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
@CrewBase
|
||||
class Crew2:
|
||||
@before_llm_call
|
||||
def crew2_hook(self, context):
|
||||
crew2_executions.append("crew2")
|
||||
|
||||
@agent
|
||||
def analyst(self):
|
||||
return Agent(role="Analyst", goal="Analyze", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create both instances
|
||||
instance1 = Crew1()
|
||||
instance2 = Crew2()
|
||||
|
||||
# Both hooks should be registered
|
||||
hooks = get_before_llm_call_hooks()
|
||||
assert len(hooks) >= 2
|
||||
|
||||
# Create mock context
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Test")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Execute all hooks
|
||||
for hook in hooks:
|
||||
hook(context)
|
||||
|
||||
# Both hooks should have executed
|
||||
assert "crew1" in crew1_executions
|
||||
assert "crew2" in crew2_executions
|
||||
|
||||
def test_crew_scoped_hook_with_filters(self):
|
||||
"""Test that filtered crew-scoped hooks work correctly."""
|
||||
execution_log = []
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_tool_call(tools=["delete_file"])
|
||||
def filtered_hook(self, context):
|
||||
execution_log.append(f"filtered:{context.tool_name}")
|
||||
return None
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
|
||||
# Get registered hooks
|
||||
hooks = get_before_tool_call_hooks()
|
||||
crew_hook = hooks[-1] # Last registered
|
||||
|
||||
# Test with matching tool
|
||||
mock_tool = Mock()
|
||||
context1 = ToolCallHookContext(
|
||||
tool_name="delete_file", tool_input={}, tool=mock_tool
|
||||
)
|
||||
crew_hook(context1)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "filtered:delete_file"
|
||||
|
||||
# Test with non-matching tool
|
||||
context2 = ToolCallHookContext(
|
||||
tool_name="read_file", tool_input={}, tool=mock_tool
|
||||
)
|
||||
crew_hook(context2)
|
||||
|
||||
# Should still be 1 (filtered hook didn't run)
|
||||
assert len(execution_log) == 1
|
||||
|
||||
def test_crew_scoped_hook_no_double_registration(self):
|
||||
"""Test that crew-scoped hooks are not registered twice."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_llm_call
|
||||
def my_hook(self, context):
|
||||
pass
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Get initial hook count
|
||||
initial_hooks = len(get_before_llm_call_hooks())
|
||||
|
||||
# Create first instance
|
||||
instance1 = TestCrew()
|
||||
|
||||
# Should add 1 hook
|
||||
hooks_after_first = get_before_llm_call_hooks()
|
||||
assert len(hooks_after_first) == initial_hooks + 1
|
||||
|
||||
# Create second instance
|
||||
instance2 = TestCrew()
|
||||
|
||||
# Should add another hook (one per instance)
|
||||
hooks_after_second = get_before_llm_call_hooks()
|
||||
assert len(hooks_after_second) == initial_hooks + 2
|
||||
|
||||
def test_crew_scoped_hook_method_signature(self):
|
||||
"""Test that crew-scoped hooks have correct signature (self + context)."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
def __init__(self):
|
||||
self.test_value = "test"
|
||||
|
||||
@before_llm_call
|
||||
def my_hook(self, context):
|
||||
# Should be able to access both self and context
|
||||
return f"{self.test_value}:{context.iterations}"
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
|
||||
# Verify the hook method has is_before_llm_call_hook marker
|
||||
assert hasattr(crew_instance.my_hook, "__func__")
|
||||
hook_func = crew_instance.my_hook.__func__
|
||||
assert hasattr(hook_func, "is_before_llm_call_hook")
|
||||
assert hook_func.is_before_llm_call_hook is True
|
||||
|
||||
def test_crew_scoped_with_agent_filter(self):
|
||||
"""Test crew-scoped hooks with agent filters."""
|
||||
execution_log = []
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_llm_call(agents=["Researcher"])
|
||||
def filtered_hook(self, context):
|
||||
execution_log.append(context.agent.role)
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
|
||||
# Get hooks
|
||||
hooks = get_before_llm_call_hooks()
|
||||
crew_hook = hooks[-1]
|
||||
|
||||
# Test with matching agent
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Researcher")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context1 = LLMCallHookContext(executor=mock_executor)
|
||||
crew_hook(context1)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "Researcher"
|
||||
|
||||
# Test with non-matching agent
|
||||
mock_executor.agent.role = "Analyst"
|
||||
context2 = LLMCallHookContext(executor=mock_executor)
|
||||
crew_hook(context2)
|
||||
|
||||
# Should still be 1 (filtered out)
|
||||
assert len(execution_log) == 1
|
||||
|
||||
|
||||
class TestCrewScopedHookAttributes:
|
||||
"""Test that crew-scoped hooks have correct attributes set."""
|
||||
|
||||
def test_hook_marker_attribute_is_set(self):
|
||||
"""Test that decorator sets marker attribute on method."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_llm_call
|
||||
def my_hook(self, context):
|
||||
pass
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Check the unbound method has the marker
|
||||
assert hasattr(TestCrew.__dict__["my_hook"], "is_before_llm_call_hook")
|
||||
assert TestCrew.__dict__["my_hook"].is_before_llm_call_hook is True
|
||||
|
||||
def test_filter_attributes_are_preserved(self):
|
||||
"""Test that filter attributes are preserved on methods."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_tool_call(tools=["delete_file"], agents=["Dev"])
|
||||
def filtered_hook(self, context):
|
||||
return None
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Check filter attributes are set
|
||||
hook_method = TestCrew.__dict__["filtered_hook"]
|
||||
assert hasattr(hook_method, "is_before_tool_call_hook")
|
||||
assert hasattr(hook_method, "_filter_tools")
|
||||
assert hasattr(hook_method, "_filter_agents")
|
||||
assert hook_method._filter_tools == ["delete_file"]
|
||||
assert hook_method._filter_agents == ["Dev"]
|
||||
|
||||
def test_registered_hooks_tracked_on_instance(self):
|
||||
"""Test that registered hooks are tracked on the crew instance."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_llm_call
|
||||
def llm_hook(self, context):
|
||||
pass
|
||||
|
||||
@before_tool_call
|
||||
def tool_hook(self, context):
|
||||
return None
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
|
||||
# Check that hooks are tracked
|
||||
assert hasattr(crew_instance, "_registered_hook_functions")
|
||||
assert isinstance(crew_instance._registered_hook_functions, list)
|
||||
assert len(crew_instance._registered_hook_functions) == 2
|
||||
|
||||
# Check hook types
|
||||
hook_types = [ht for ht, _ in crew_instance._registered_hook_functions]
|
||||
assert "before_llm_call" in hook_types
|
||||
assert "before_tool_call" in hook_types
|
||||
|
||||
|
||||
class TestCrewScopedHookExecution:
|
||||
"""Test execution behavior of crew-scoped hooks."""
|
||||
|
||||
def test_crew_hook_executes_with_bound_self(self):
|
||||
"""Test that crew-scoped hook executes with self properly bound."""
|
||||
execution_log = []
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
def __init__(self):
|
||||
self.instance_id = id(self)
|
||||
|
||||
@before_llm_call
|
||||
def my_hook(self, context):
|
||||
# Should have access to self
|
||||
execution_log.append(self.instance_id)
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
expected_id = crew_instance.instance_id
|
||||
|
||||
# Get and execute hook
|
||||
hooks = get_before_llm_call_hooks()
|
||||
crew_hook = hooks[-1]
|
||||
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Test")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Execute hook
|
||||
crew_hook(context)
|
||||
|
||||
# Verify it had access to self
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == expected_id
|
||||
|
||||
def test_crew_hook_can_modify_instance_state(self):
|
||||
"""Test that crew-scoped hooks can modify instance variables."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
def __init__(self):
|
||||
self.counter = 0
|
||||
|
||||
@before_tool_call
|
||||
def increment_counter(self, context):
|
||||
self.counter += 1
|
||||
return None
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create instance
|
||||
crew_instance = TestCrew()
|
||||
assert crew_instance.counter == 0
|
||||
|
||||
# Get and execute hook
|
||||
hooks = get_before_tool_call_hooks()
|
||||
crew_hook = hooks[-1]
|
||||
|
||||
mock_tool = Mock()
|
||||
context = ToolCallHookContext(tool_name="test", tool_input={}, tool=mock_tool)
|
||||
|
||||
# Execute hook 3 times
|
||||
crew_hook(context)
|
||||
crew_hook(context)
|
||||
crew_hook(context)
|
||||
|
||||
# Verify counter was incremented
|
||||
assert crew_instance.counter == 3
|
||||
|
||||
def test_multiple_instances_maintain_separate_state(self):
|
||||
"""Test that multiple instances of the same crew maintain separate state."""
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
def __init__(self):
|
||||
self.call_count = 0
|
||||
|
||||
@before_llm_call
|
||||
def count_calls(self, context):
|
||||
self.call_count += 1
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Create two instances
|
||||
instance1 = TestCrew()
|
||||
instance2 = TestCrew()
|
||||
|
||||
# Get all hooks (should include hooks from both instances)
|
||||
all_hooks = get_before_llm_call_hooks()
|
||||
|
||||
# Find hooks for each instance (last 2 registered)
|
||||
hook1 = all_hooks[-2]
|
||||
hook2 = all_hooks[-1]
|
||||
|
||||
# Create mock context
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Test")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Execute first hook twice
|
||||
hook1(context)
|
||||
hook1(context)
|
||||
|
||||
# Execute second hook once
|
||||
hook2(context)
|
||||
|
||||
# Each instance should have independent state
|
||||
# Note: We can't easily verify which hook belongs to which instance
|
||||
# in this test without more introspection, but the fact that it doesn't
|
||||
# crash and hooks can maintain state proves isolation works
|
||||
|
||||
|
||||
class TestSignatureDetection:
|
||||
"""Test that signature detection correctly identifies methods vs functions."""
|
||||
|
||||
def test_method_signature_detected(self):
|
||||
"""Test that methods with 'self' parameter are detected."""
|
||||
import inspect
|
||||
|
||||
@CrewBase
|
||||
class TestCrew:
|
||||
@before_llm_call
|
||||
def method_hook(self, context):
|
||||
pass
|
||||
|
||||
@agent
|
||||
def researcher(self):
|
||||
return Agent(role="Researcher", goal="Research", backstory="Expert")
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
return Crew(agents=self.agents, tasks=[], verbose=False)
|
||||
|
||||
# Check that method has self parameter
|
||||
method = TestCrew.__dict__["method_hook"]
|
||||
sig = inspect.signature(method)
|
||||
params = list(sig.parameters.keys())
|
||||
assert params[0] == "self"
|
||||
assert len(params) == 2 # self + context
|
||||
|
||||
def test_standalone_function_signature_detected(self):
|
||||
"""Test that standalone functions without 'self' are detected."""
|
||||
import inspect
|
||||
|
||||
@before_llm_call
|
||||
def standalone_hook(context):
|
||||
pass
|
||||
|
||||
# Should have only context parameter (no self)
|
||||
sig = inspect.signature(standalone_hook)
|
||||
params = list(sig.parameters.keys())
|
||||
assert "self" not in params
|
||||
assert len(params) == 1 # Just context
|
||||
|
||||
# Should be registered
|
||||
hooks = get_before_llm_call_hooks()
|
||||
assert len(hooks) >= 1
|
||||
335
lib/crewai/tests/hooks/test_decorators.py
Normal file
335
lib/crewai/tests/hooks/test_decorators.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Tests for decorator-based hook registration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.hooks import (
|
||||
after_llm_call,
|
||||
after_tool_call,
|
||||
before_llm_call,
|
||||
before_tool_call,
|
||||
get_after_llm_call_hooks,
|
||||
get_after_tool_call_hooks,
|
||||
get_before_llm_call_hooks,
|
||||
get_before_tool_call_hooks,
|
||||
)
|
||||
from crewai.hooks.llm_hooks import LLMCallHookContext
|
||||
from crewai.hooks.tool_hooks import ToolCallHookContext
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_hooks():
|
||||
"""Clear global hooks before and after each test."""
|
||||
from crewai.hooks import llm_hooks, tool_hooks
|
||||
|
||||
# Store original hooks
|
||||
original_before_llm = llm_hooks._before_llm_call_hooks.copy()
|
||||
original_after_llm = llm_hooks._after_llm_call_hooks.copy()
|
||||
original_before_tool = tool_hooks._before_tool_call_hooks.copy()
|
||||
original_after_tool = tool_hooks._after_tool_call_hooks.copy()
|
||||
|
||||
# Clear hooks
|
||||
llm_hooks._before_llm_call_hooks.clear()
|
||||
llm_hooks._after_llm_call_hooks.clear()
|
||||
tool_hooks._before_tool_call_hooks.clear()
|
||||
tool_hooks._after_tool_call_hooks.clear()
|
||||
|
||||
yield
|
||||
|
||||
# Restore original hooks
|
||||
llm_hooks._before_llm_call_hooks.clear()
|
||||
llm_hooks._after_llm_call_hooks.clear()
|
||||
tool_hooks._before_tool_call_hooks.clear()
|
||||
tool_hooks._after_tool_call_hooks.clear()
|
||||
llm_hooks._before_llm_call_hooks.extend(original_before_llm)
|
||||
llm_hooks._after_llm_call_hooks.extend(original_after_llm)
|
||||
tool_hooks._before_tool_call_hooks.extend(original_before_tool)
|
||||
tool_hooks._after_tool_call_hooks.extend(original_after_tool)
|
||||
|
||||
|
||||
class TestLLMHookDecorators:
|
||||
"""Test LLM hook decorators."""
|
||||
|
||||
def test_before_llm_call_decorator_registers_hook(self):
|
||||
"""Test that @before_llm_call decorator registers the hook."""
|
||||
|
||||
@before_llm_call
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
hooks = get_before_llm_call_hooks()
|
||||
assert len(hooks) == 1
|
||||
|
||||
def test_after_llm_call_decorator_registers_hook(self):
|
||||
"""Test that @after_llm_call decorator registers the hook."""
|
||||
|
||||
@after_llm_call
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
hooks = get_after_llm_call_hooks()
|
||||
assert len(hooks) == 1
|
||||
|
||||
def test_decorated_hook_executes_correctly(self):
|
||||
"""Test that decorated hook executes and modifies behavior."""
|
||||
execution_log = []
|
||||
|
||||
@before_llm_call
|
||||
def test_hook(context):
|
||||
execution_log.append("executed")
|
||||
|
||||
# Create mock context
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Test")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Execute the hook
|
||||
hooks = get_before_llm_call_hooks()
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "executed"
|
||||
|
||||
def test_before_llm_call_with_agent_filter(self):
|
||||
"""Test that agent filter works correctly."""
|
||||
execution_log = []
|
||||
|
||||
@before_llm_call(agents=["Researcher"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(context.agent.role)
|
||||
|
||||
hooks = get_before_llm_call_hooks()
|
||||
assert len(hooks) == 1
|
||||
|
||||
# Test with matching agent
|
||||
mock_executor = Mock()
|
||||
mock_executor.messages = []
|
||||
mock_executor.agent = Mock(role="Researcher")
|
||||
mock_executor.task = Mock()
|
||||
mock_executor.crew = Mock()
|
||||
mock_executor.llm = Mock()
|
||||
mock_executor.iterations = 0
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "Researcher"
|
||||
|
||||
# Test with non-matching agent
|
||||
mock_executor.agent.role = "Analyst"
|
||||
context2 = LLMCallHookContext(executor=mock_executor)
|
||||
hooks[0](context2)
|
||||
|
||||
# Should still be 1 (hook didn't execute)
|
||||
assert len(execution_log) == 1
|
||||
|
||||
|
||||
class TestToolHookDecorators:
|
||||
"""Test tool hook decorators."""
|
||||
|
||||
def test_before_tool_call_decorator_registers_hook(self):
|
||||
"""Test that @before_tool_call decorator registers the hook."""
|
||||
|
||||
@before_tool_call
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
assert len(hooks) == 1
|
||||
|
||||
def test_after_tool_call_decorator_registers_hook(self):
|
||||
"""Test that @after_tool_call decorator registers the hook."""
|
||||
|
||||
@after_tool_call
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
hooks = get_after_tool_call_hooks()
|
||||
assert len(hooks) == 1
|
||||
|
||||
def test_before_tool_call_with_tool_filter(self):
|
||||
"""Test that tool filter works correctly."""
|
||||
execution_log = []
|
||||
|
||||
@before_tool_call(tools=["delete_file", "execute_code"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(context.tool_name)
|
||||
return None
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
assert len(hooks) == 1
|
||||
|
||||
# Test with matching tool
|
||||
mock_tool = Mock()
|
||||
context = ToolCallHookContext(
|
||||
tool_name="delete_file",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "delete_file"
|
||||
|
||||
# Test with non-matching tool
|
||||
context2 = ToolCallHookContext(
|
||||
tool_name="read_file",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
hooks[0](context2)
|
||||
|
||||
# Should still be 1 (hook didn't execute for read_file)
|
||||
assert len(execution_log) == 1
|
||||
|
||||
def test_before_tool_call_with_combined_filters(self):
|
||||
"""Test that combined tool and agent filters work."""
|
||||
execution_log = []
|
||||
|
||||
@before_tool_call(tools=["write_file"], agents=["Developer"])
|
||||
def filtered_hook(context):
|
||||
execution_log.append(f"{context.tool_name}-{context.agent.role}")
|
||||
return None
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
mock_tool = Mock()
|
||||
mock_agent = Mock(role="Developer")
|
||||
|
||||
# Test with both matching
|
||||
context = ToolCallHookContext(
|
||||
tool_name="write_file",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
agent=mock_agent,
|
||||
)
|
||||
hooks[0](context)
|
||||
|
||||
assert len(execution_log) == 1
|
||||
assert execution_log[0] == "write_file-Developer"
|
||||
|
||||
# Test with tool matching but agent not
|
||||
mock_agent.role = "Researcher"
|
||||
context2 = ToolCallHookContext(
|
||||
tool_name="write_file",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
agent=mock_agent,
|
||||
)
|
||||
hooks[0](context2)
|
||||
|
||||
# Should still be 1 (hook didn't execute)
|
||||
assert len(execution_log) == 1
|
||||
|
||||
def test_after_tool_call_with_filter(self):
|
||||
"""Test that after_tool_call decorator with filter works."""
|
||||
|
||||
@after_tool_call(tools=["web_search"])
|
||||
def filtered_hook(context):
|
||||
if context.tool_result:
|
||||
return context.tool_result.upper()
|
||||
return None
|
||||
|
||||
hooks = get_after_tool_call_hooks()
|
||||
mock_tool = Mock()
|
||||
|
||||
# Test with matching tool
|
||||
context = ToolCallHookContext(
|
||||
tool_name="web_search",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
tool_result="result",
|
||||
)
|
||||
result = hooks[0](context)
|
||||
|
||||
assert result == "RESULT"
|
||||
|
||||
# Test with non-matching tool
|
||||
context2 = ToolCallHookContext(
|
||||
tool_name="other_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
tool_result="result",
|
||||
)
|
||||
result2 = hooks[0](context2)
|
||||
|
||||
assert result2 is None # Hook didn't run, returns None
|
||||
|
||||
|
||||
class TestDecoratorAttributes:
|
||||
"""Test that decorators set proper attributes on functions."""
|
||||
|
||||
def test_before_llm_call_sets_attribute(self):
|
||||
"""Test that decorator sets is_before_llm_call_hook attribute."""
|
||||
|
||||
@before_llm_call
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
assert hasattr(test_hook, "is_before_llm_call_hook")
|
||||
assert test_hook.is_before_llm_call_hook is True
|
||||
|
||||
def test_before_tool_call_sets_attributes_with_filters(self):
|
||||
"""Test that decorator with filters sets filter attributes."""
|
||||
|
||||
@before_tool_call(tools=["delete_file"], agents=["Dev"])
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
assert hasattr(test_hook, "is_before_tool_call_hook")
|
||||
assert test_hook.is_before_tool_call_hook is True
|
||||
assert hasattr(test_hook, "_filter_tools")
|
||||
assert test_hook._filter_tools == ["delete_file"]
|
||||
assert hasattr(test_hook, "_filter_agents")
|
||||
assert test_hook._filter_agents == ["Dev"]
|
||||
|
||||
|
||||
class TestMultipleDecorators:
|
||||
"""Test using multiple decorators together."""
|
||||
|
||||
def test_multiple_decorators_all_register(self):
|
||||
"""Test that multiple decorated functions all register."""
|
||||
|
||||
@before_llm_call
|
||||
def hook1(context):
|
||||
pass
|
||||
|
||||
@before_llm_call
|
||||
def hook2(context):
|
||||
pass
|
||||
|
||||
@after_llm_call
|
||||
def hook3(context):
|
||||
return None
|
||||
|
||||
before_hooks = get_before_llm_call_hooks()
|
||||
after_hooks = get_after_llm_call_hooks()
|
||||
|
||||
assert len(before_hooks) == 2
|
||||
assert len(after_hooks) == 1
|
||||
|
||||
def test_decorator_and_manual_registration_work_together(self):
|
||||
"""Test that decorators and manual registration can be mixed."""
|
||||
from crewai.hooks import register_before_tool_call_hook
|
||||
|
||||
@before_tool_call
|
||||
def decorated_hook(context):
|
||||
return None
|
||||
|
||||
def manual_hook(context):
|
||||
return None
|
||||
|
||||
register_before_tool_call_hook(manual_hook)
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
|
||||
assert len(hooks) == 2
|
||||
395
lib/crewai/tests/hooks/test_human_approval.py
Normal file
395
lib/crewai/tests/hooks/test_human_approval.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""Tests for human approval functionality in hooks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from crewai.hooks.llm_hooks import LLMCallHookContext
|
||||
from crewai.hooks.tool_hooks import ToolCallHookContext
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_executor():
|
||||
"""Create a mock executor for LLM hook context."""
|
||||
executor = Mock()
|
||||
executor.messages = [{"role": "system", "content": "Test message"}]
|
||||
executor.agent = Mock(role="Test Agent")
|
||||
executor.task = Mock(description="Test Task")
|
||||
executor.crew = Mock()
|
||||
executor.llm = Mock()
|
||||
executor.iterations = 0
|
||||
return executor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool():
|
||||
"""Create a mock tool for tool hook context."""
|
||||
tool = Mock()
|
||||
tool.name = "test_tool"
|
||||
tool.description = "Test tool description"
|
||||
return tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
"""Create a mock agent."""
|
||||
agent = Mock()
|
||||
agent.role = "Test Agent"
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task():
|
||||
"""Create a mock task."""
|
||||
task = Mock()
|
||||
task.description = "Test task"
|
||||
return task
|
||||
|
||||
|
||||
class TestLLMHookHumanInput:
|
||||
"""Test request_human_input() on LLMCallHookContext."""
|
||||
|
||||
@patch("builtins.input", return_value="test response")
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_request_human_input_returns_user_response(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that request_human_input returns the user's input."""
|
||||
# Setup mock formatter
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
response = context.request_human_input(
|
||||
prompt="Test prompt", default_message="Test default message"
|
||||
)
|
||||
|
||||
assert response == "test response"
|
||||
mock_input.assert_called_once()
|
||||
|
||||
@patch("builtins.input", return_value="")
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_request_human_input_returns_empty_string_on_enter(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that pressing Enter returns empty string."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
response = context.request_human_input(prompt="Test")
|
||||
|
||||
assert response == ""
|
||||
mock_input.assert_called_once()
|
||||
|
||||
@patch("builtins.input", return_value="test")
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_request_human_input_pauses_and_resumes_live_updates(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that live updates are paused and resumed."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
context.request_human_input(prompt="Test")
|
||||
|
||||
# Verify pause was called
|
||||
mock_formatter.pause_live_updates.assert_called_once()
|
||||
|
||||
# Verify resume was called
|
||||
mock_formatter.resume_live_updates.assert_called_once()
|
||||
|
||||
@patch("builtins.input", side_effect=Exception("Input error"))
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_request_human_input_resumes_on_exception(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that live updates are resumed even if input raises exception."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
with pytest.raises(Exception, match="Input error"):
|
||||
context.request_human_input(prompt="Test")
|
||||
|
||||
# Verify resume was still called (in finally block)
|
||||
mock_formatter.resume_live_updates.assert_called_once()
|
||||
|
||||
@patch("builtins.input", return_value=" test response ")
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_request_human_input_strips_whitespace(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that user input is stripped of leading/trailing whitespace."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
response = context.request_human_input(prompt="Test")
|
||||
|
||||
assert response == "test response" # Whitespace stripped
|
||||
|
||||
|
||||
class TestToolHookHumanInput:
|
||||
"""Test request_human_input() on ToolCallHookContext."""
|
||||
|
||||
@patch("builtins.input", return_value="approve")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_request_human_input_returns_user_response(
|
||||
self, mock_event_listener, mock_input, mock_tool, mock_agent, mock_task
|
||||
):
|
||||
"""Test that request_human_input returns the user's input."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={"arg": "value"},
|
||||
tool=mock_tool,
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
)
|
||||
|
||||
response = context.request_human_input(
|
||||
prompt="Approve this tool?", default_message="Type 'approve':"
|
||||
)
|
||||
|
||||
assert response == "approve"
|
||||
mock_input.assert_called_once()
|
||||
|
||||
@patch("builtins.input", return_value="")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_request_human_input_handles_empty_input(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that empty input (Enter key) is handled correctly."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
response = context.request_human_input(prompt="Test")
|
||||
|
||||
assert response == ""
|
||||
|
||||
@patch("builtins.input", return_value="test")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_request_human_input_pauses_and_resumes(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that live updates are properly paused and resumed."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
context.request_human_input(prompt="Test")
|
||||
|
||||
mock_formatter.pause_live_updates.assert_called_once()
|
||||
mock_formatter.resume_live_updates.assert_called_once()
|
||||
|
||||
@patch("builtins.input", side_effect=KeyboardInterrupt)
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_request_human_input_resumes_on_keyboard_interrupt(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that live updates are resumed even on keyboard interrupt."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
context.request_human_input(prompt="Test")
|
||||
|
||||
# Verify resume was still called (in finally block)
|
||||
mock_formatter.resume_live_updates.assert_called_once()
|
||||
|
||||
|
||||
class TestApprovalHookIntegration:
|
||||
"""Test integration scenarios with approval hooks."""
|
||||
|
||||
@patch("builtins.input", return_value="approve")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_approval_hook_allows_execution(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that approval hook allows execution when approved."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
def approval_hook(context: ToolCallHookContext) -> bool | None:
|
||||
response = context.request_human_input(
|
||||
prompt="Approve?", default_message="Type 'approve':"
|
||||
)
|
||||
return None if response == "approve" else False
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
result = approval_hook(context)
|
||||
|
||||
assert result is None # Allowed
|
||||
assert mock_input.called
|
||||
|
||||
@patch("builtins.input", return_value="deny")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_approval_hook_blocks_execution(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that approval hook blocks execution when denied."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
def approval_hook(context: ToolCallHookContext) -> bool | None:
|
||||
response = context.request_human_input(
|
||||
prompt="Approve?", default_message="Type 'approve':"
|
||||
)
|
||||
return None if response == "approve" else False
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
result = approval_hook(context)
|
||||
|
||||
assert result is False # Blocked
|
||||
assert mock_input.called
|
||||
|
||||
@patch("builtins.input", return_value="modified result")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_review_hook_modifies_result(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that review hook can modify tool results."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
def review_hook(context: ToolCallHookContext) -> str | None:
|
||||
response = context.request_human_input(
|
||||
prompt="Review result",
|
||||
default_message="Press Enter to keep, or provide modified version:",
|
||||
)
|
||||
return response if response else None
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
tool_result="original result",
|
||||
)
|
||||
|
||||
modified_result = review_hook(context)
|
||||
|
||||
assert modified_result == "modified result"
|
||||
assert mock_input.called
|
||||
|
||||
@patch("builtins.input", return_value="")
|
||||
@patch("crewai.hooks.tool_hooks.event_listener")
|
||||
def test_review_hook_keeps_original_on_enter(
|
||||
self, mock_event_listener, mock_input, mock_tool
|
||||
):
|
||||
"""Test that pressing Enter keeps original result."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
def review_hook(context: ToolCallHookContext) -> str | None:
|
||||
response = context.request_human_input(
|
||||
prompt="Review result", default_message="Press Enter to keep:"
|
||||
)
|
||||
return response if response else None
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input={},
|
||||
tool=mock_tool,
|
||||
tool_result="original result",
|
||||
)
|
||||
|
||||
modified_result = review_hook(context)
|
||||
|
||||
assert modified_result is None # Keep original
|
||||
|
||||
|
||||
class TestCostControlApproval:
|
||||
"""Test cost control approval hook scenarios."""
|
||||
|
||||
@patch("builtins.input", return_value="yes")
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_cost_control_allows_when_approved(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that expensive calls are allowed when approved."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
# Set high iteration count
|
||||
mock_executor.iterations = 10
|
||||
|
||||
def cost_control_hook(context: LLMCallHookContext) -> None:
|
||||
if context.iterations > 5:
|
||||
response = context.request_human_input(
|
||||
prompt=f"Iteration {context.iterations} - expensive call",
|
||||
default_message="Type 'yes' to continue:",
|
||||
)
|
||||
if response.lower() != "yes":
|
||||
print("Call blocked")
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Should not raise exception and should call input
|
||||
cost_control_hook(context)
|
||||
assert mock_input.called
|
||||
|
||||
@patch("builtins.input", return_value="no")
|
||||
@patch("crewai.hooks.llm_hooks.event_listener")
|
||||
def test_cost_control_logs_when_denied(
|
||||
self, mock_event_listener, mock_input, mock_executor
|
||||
):
|
||||
"""Test that denied calls are logged."""
|
||||
mock_formatter = Mock()
|
||||
mock_event_listener.formatter = mock_formatter
|
||||
|
||||
mock_executor.iterations = 10
|
||||
|
||||
messages_logged = []
|
||||
|
||||
def cost_control_hook(context: LLMCallHookContext) -> None:
|
||||
if context.iterations > 5:
|
||||
response = context.request_human_input(
|
||||
prompt=f"Iteration {context.iterations}",
|
||||
default_message="Type 'yes' to continue:",
|
||||
)
|
||||
if response.lower() != "yes":
|
||||
messages_logged.append("blocked")
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
cost_control_hook(context)
|
||||
|
||||
assert len(messages_logged) == 1
|
||||
assert messages_logged[0] == "blocked"
|
||||
311
lib/crewai/tests/hooks/test_llm_hooks.py
Normal file
311
lib/crewai/tests/hooks/test_llm_hooks.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""Unit tests for LLM hooks functionality."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from crewai.hooks import clear_all_llm_call_hooks, unregister_after_llm_call_hook, unregister_before_llm_call_hook
|
||||
import pytest
|
||||
|
||||
from crewai.hooks.llm_hooks import (
|
||||
LLMCallHookContext,
|
||||
get_after_llm_call_hooks,
|
||||
get_before_llm_call_hooks,
|
||||
register_after_llm_call_hook,
|
||||
register_before_llm_call_hook,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_executor():
|
||||
"""Create a mock executor for testing."""
|
||||
executor = Mock()
|
||||
executor.messages = [{"role": "system", "content": "Test message"}]
|
||||
executor.agent = Mock(role="Test Agent")
|
||||
executor.task = Mock(description="Test Task")
|
||||
executor.crew = Mock()
|
||||
executor.llm = Mock()
|
||||
executor.iterations = 0
|
||||
return executor
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_hooks():
|
||||
"""Clear global hooks before and after each test."""
|
||||
# Import the private variables to clear them
|
||||
from crewai.hooks import llm_hooks
|
||||
|
||||
# Store original hooks
|
||||
original_before = llm_hooks._before_llm_call_hooks.copy()
|
||||
original_after = llm_hooks._after_llm_call_hooks.copy()
|
||||
|
||||
# Clear hooks
|
||||
llm_hooks._before_llm_call_hooks.clear()
|
||||
llm_hooks._after_llm_call_hooks.clear()
|
||||
|
||||
yield
|
||||
|
||||
# Restore original hooks
|
||||
llm_hooks._before_llm_call_hooks.clear()
|
||||
llm_hooks._after_llm_call_hooks.clear()
|
||||
llm_hooks._before_llm_call_hooks.extend(original_before)
|
||||
llm_hooks._after_llm_call_hooks.extend(original_after)
|
||||
|
||||
|
||||
class TestLLMCallHookContext:
|
||||
"""Test LLMCallHookContext initialization and attributes."""
|
||||
|
||||
def test_context_initialization(self, mock_executor):
|
||||
"""Test that context is initialized correctly with executor."""
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
assert context.executor == mock_executor
|
||||
assert context.messages == mock_executor.messages
|
||||
assert context.agent == mock_executor.agent
|
||||
assert context.task == mock_executor.task
|
||||
assert context.crew == mock_executor.crew
|
||||
assert context.llm == mock_executor.llm
|
||||
assert context.iterations == mock_executor.iterations
|
||||
assert context.response is None
|
||||
|
||||
def test_context_with_response(self, mock_executor):
|
||||
"""Test that context includes response when provided."""
|
||||
test_response = "Test LLM response"
|
||||
context = LLMCallHookContext(executor=mock_executor, response=test_response)
|
||||
|
||||
assert context.response == test_response
|
||||
|
||||
def test_messages_are_mutable_reference(self, mock_executor):
|
||||
"""Test that modifying context.messages modifies executor.messages."""
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
|
||||
# Add a message through context
|
||||
new_message = {"role": "user", "content": "New message"}
|
||||
context.messages.append(new_message)
|
||||
|
||||
# Check that executor.messages is also modified
|
||||
assert new_message in mock_executor.messages
|
||||
assert len(mock_executor.messages) == 2
|
||||
|
||||
|
||||
class TestBeforeLLMCallHooks:
|
||||
"""Test before_llm_call hook registration and execution."""
|
||||
|
||||
def test_register_before_hook(self):
|
||||
"""Test that before hooks are registered correctly."""
|
||||
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
register_before_llm_call_hook(test_hook)
|
||||
hooks = get_before_llm_call_hooks()
|
||||
|
||||
assert len(hooks) == 1
|
||||
assert hooks[0] == test_hook
|
||||
|
||||
def test_multiple_before_hooks(self):
|
||||
"""Test that multiple before hooks can be registered."""
|
||||
|
||||
def hook1(context):
|
||||
pass
|
||||
|
||||
def hook2(context):
|
||||
pass
|
||||
|
||||
register_before_llm_call_hook(hook1)
|
||||
register_before_llm_call_hook(hook2)
|
||||
hooks = get_before_llm_call_hooks()
|
||||
|
||||
assert len(hooks) == 2
|
||||
assert hook1 in hooks
|
||||
assert hook2 in hooks
|
||||
|
||||
def test_before_hook_can_modify_messages(self, mock_executor):
|
||||
"""Test that before hooks can modify messages in-place."""
|
||||
|
||||
def add_message_hook(context):
|
||||
context.messages.append({"role": "system", "content": "Added by hook"})
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
add_message_hook(context)
|
||||
|
||||
assert len(context.messages) == 2
|
||||
assert context.messages[1]["content"] == "Added by hook"
|
||||
|
||||
def test_get_before_hooks_returns_copy(self):
|
||||
"""Test that get_before_llm_call_hooks returns a copy."""
|
||||
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
register_before_llm_call_hook(test_hook)
|
||||
hooks1 = get_before_llm_call_hooks()
|
||||
hooks2 = get_before_llm_call_hooks()
|
||||
|
||||
# They should be equal but not the same object
|
||||
assert hooks1 == hooks2
|
||||
assert hooks1 is not hooks2
|
||||
|
||||
|
||||
class TestAfterLLMCallHooks:
|
||||
"""Test after_llm_call hook registration and execution."""
|
||||
|
||||
def test_register_after_hook(self):
|
||||
"""Test that after hooks are registered correctly."""
|
||||
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_after_llm_call_hook(test_hook)
|
||||
hooks = get_after_llm_call_hooks()
|
||||
|
||||
assert len(hooks) == 1
|
||||
assert hooks[0] == test_hook
|
||||
|
||||
def test_multiple_after_hooks(self):
|
||||
"""Test that multiple after hooks can be registered."""
|
||||
|
||||
def hook1(context):
|
||||
return None
|
||||
|
||||
def hook2(context):
|
||||
return None
|
||||
|
||||
register_after_llm_call_hook(hook1)
|
||||
register_after_llm_call_hook(hook2)
|
||||
hooks = get_after_llm_call_hooks()
|
||||
|
||||
assert len(hooks) == 2
|
||||
assert hook1 in hooks
|
||||
assert hook2 in hooks
|
||||
|
||||
def test_after_hook_can_modify_response(self, mock_executor):
|
||||
"""Test that after hooks can modify the response."""
|
||||
original_response = "Original response"
|
||||
|
||||
def modify_response_hook(context):
|
||||
if context.response:
|
||||
return context.response.replace("Original", "Modified")
|
||||
return None
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor, response=original_response)
|
||||
modified = modify_response_hook(context)
|
||||
|
||||
assert modified == "Modified response"
|
||||
|
||||
def test_after_hook_returns_none_keeps_original(self, mock_executor):
|
||||
"""Test that returning None keeps the original response."""
|
||||
original_response = "Original response"
|
||||
|
||||
def no_change_hook(context):
|
||||
return None
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor, response=original_response)
|
||||
result = no_change_hook(context)
|
||||
|
||||
assert result is None
|
||||
assert context.response == original_response
|
||||
|
||||
def test_get_after_hooks_returns_copy(self):
|
||||
"""Test that get_after_llm_call_hooks returns a copy."""
|
||||
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_after_llm_call_hook(test_hook)
|
||||
hooks1 = get_after_llm_call_hooks()
|
||||
hooks2 = get_after_llm_call_hooks()
|
||||
|
||||
# They should be equal but not the same object
|
||||
assert hooks1 == hooks2
|
||||
assert hooks1 is not hooks2
|
||||
|
||||
|
||||
class TestLLMHooksIntegration:
|
||||
"""Test integration scenarios with multiple hooks."""
|
||||
|
||||
def test_multiple_before_hooks_execute_in_order(self, mock_executor):
|
||||
"""Test that multiple before hooks execute in registration order."""
|
||||
execution_order = []
|
||||
|
||||
def hook1(context):
|
||||
execution_order.append(1)
|
||||
|
||||
def hook2(context):
|
||||
execution_order.append(2)
|
||||
|
||||
def hook3(context):
|
||||
execution_order.append(3)
|
||||
|
||||
register_before_llm_call_hook(hook1)
|
||||
register_before_llm_call_hook(hook2)
|
||||
register_before_llm_call_hook(hook3)
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor)
|
||||
hooks = get_before_llm_call_hooks()
|
||||
|
||||
for hook in hooks:
|
||||
hook(context)
|
||||
|
||||
assert execution_order == [1, 2, 3]
|
||||
|
||||
def test_multiple_after_hooks_chain_modifications(self, mock_executor):
|
||||
"""Test that multiple after hooks can chain modifications."""
|
||||
|
||||
def hook1(context):
|
||||
if context.response:
|
||||
return context.response + " [hook1]"
|
||||
return None
|
||||
|
||||
def hook2(context):
|
||||
if context.response:
|
||||
return context.response + " [hook2]"
|
||||
return None
|
||||
|
||||
register_after_llm_call_hook(hook1)
|
||||
register_after_llm_call_hook(hook2)
|
||||
|
||||
context = LLMCallHookContext(executor=mock_executor, response="Original")
|
||||
hooks = get_after_llm_call_hooks()
|
||||
|
||||
# Simulate chaining (how it would be used in practice)
|
||||
result = context.response
|
||||
for hook in hooks:
|
||||
# Update context for next hook
|
||||
context.response = result
|
||||
modified = hook(context)
|
||||
if modified is not None:
|
||||
result = modified
|
||||
|
||||
assert result == "Original [hook1] [hook2]"
|
||||
|
||||
def test_unregister_before_hook(self):
|
||||
"""Test that before hooks can be unregistered."""
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
register_before_llm_call_hook(test_hook)
|
||||
unregister_before_llm_call_hook(test_hook)
|
||||
hooks = get_before_llm_call_hooks()
|
||||
assert len(hooks) == 0
|
||||
|
||||
def test_unregister_after_hook(self):
|
||||
"""Test that after hooks can be unregistered."""
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_after_llm_call_hook(test_hook)
|
||||
unregister_after_llm_call_hook(test_hook)
|
||||
hooks = get_after_llm_call_hooks()
|
||||
assert len(hooks) == 0
|
||||
|
||||
def test_clear_all_llm_call_hooks(self):
|
||||
"""Test that all llm call hooks can be cleared."""
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
register_before_llm_call_hook(test_hook)
|
||||
register_after_llm_call_hook(test_hook)
|
||||
clear_all_llm_call_hooks()
|
||||
hooks = get_before_llm_call_hooks()
|
||||
assert len(hooks) == 0
|
||||
498
lib/crewai/tests/hooks/test_tool_hooks.py
Normal file
498
lib/crewai/tests/hooks/test_tool_hooks.py
Normal file
@@ -0,0 +1,498 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from crewai.hooks import clear_all_tool_call_hooks, unregister_after_tool_call_hook, unregister_before_tool_call_hook
|
||||
import pytest
|
||||
|
||||
from crewai.hooks.tool_hooks import (
|
||||
ToolCallHookContext,
|
||||
get_after_tool_call_hooks,
|
||||
get_before_tool_call_hooks,
|
||||
register_after_tool_call_hook,
|
||||
register_before_tool_call_hook,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool():
|
||||
"""Create a mock tool for testing."""
|
||||
tool = Mock()
|
||||
tool.name = "test_tool"
|
||||
tool.description = "Test tool description"
|
||||
return tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
"""Create a mock agent for testing."""
|
||||
agent = Mock()
|
||||
agent.role = "Test Agent"
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_task():
|
||||
"""Create a mock task for testing."""
|
||||
task = Mock()
|
||||
task.description = "Test task"
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_crew():
|
||||
"""Create a mock crew for testing."""
|
||||
crew = Mock()
|
||||
return crew
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_hooks():
|
||||
"""Clear global hooks before and after each test."""
|
||||
from crewai.hooks import tool_hooks
|
||||
|
||||
# Store original hooks
|
||||
original_before = tool_hooks._before_tool_call_hooks.copy()
|
||||
original_after = tool_hooks._after_tool_call_hooks.copy()
|
||||
|
||||
# Clear hooks
|
||||
tool_hooks._before_tool_call_hooks.clear()
|
||||
tool_hooks._after_tool_call_hooks.clear()
|
||||
|
||||
yield
|
||||
|
||||
# Restore original hooks
|
||||
tool_hooks._before_tool_call_hooks.clear()
|
||||
tool_hooks._after_tool_call_hooks.clear()
|
||||
tool_hooks._before_tool_call_hooks.extend(original_before)
|
||||
tool_hooks._after_tool_call_hooks.extend(original_after)
|
||||
|
||||
|
||||
class TestToolCallHookContext:
|
||||
"""Test ToolCallHookContext initialization and attributes."""
|
||||
|
||||
def test_context_initialization(self, mock_tool, mock_agent, mock_task, mock_crew):
|
||||
"""Test that context is initialized correctly."""
|
||||
tool_input = {"arg1": "value1", "arg2": "value2"}
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
agent=mock_agent,
|
||||
task=mock_task,
|
||||
crew=mock_crew,
|
||||
)
|
||||
|
||||
assert context.tool_name == "test_tool"
|
||||
assert context.tool_input == tool_input
|
||||
assert context.tool == mock_tool
|
||||
assert context.agent == mock_agent
|
||||
assert context.task == mock_task
|
||||
assert context.crew == mock_crew
|
||||
assert context.tool_result is None
|
||||
|
||||
def test_context_with_result(self, mock_tool):
|
||||
"""Test that context includes result when provided."""
|
||||
tool_input = {"arg1": "value1"}
|
||||
tool_result = "Test tool result"
|
||||
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
assert context.tool_result == tool_result
|
||||
|
||||
def test_tool_input_is_mutable_reference(self, mock_tool):
|
||||
"""Test that modifying context.tool_input modifies the original dict."""
|
||||
tool_input = {"arg1": "value1"}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
# Modify through context
|
||||
context.tool_input["arg2"] = "value2"
|
||||
|
||||
# Check that original dict is also modified
|
||||
assert "arg2" in tool_input
|
||||
assert tool_input["arg2"] == "value2"
|
||||
|
||||
|
||||
class TestBeforeToolCallHooks:
|
||||
"""Test before_tool_call hook registration and execution."""
|
||||
|
||||
def test_register_before_hook(self):
|
||||
"""Test that before hooks are registered correctly."""
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_before_tool_call_hook(test_hook)
|
||||
hooks = get_before_tool_call_hooks()
|
||||
|
||||
assert len(hooks) == 1
|
||||
assert hooks[0] == test_hook
|
||||
|
||||
def test_multiple_before_hooks(self):
|
||||
"""Test that multiple before hooks can be registered."""
|
||||
def hook1(context):
|
||||
return None
|
||||
|
||||
def hook2(context):
|
||||
return None
|
||||
|
||||
register_before_tool_call_hook(hook1)
|
||||
register_before_tool_call_hook(hook2)
|
||||
hooks = get_before_tool_call_hooks()
|
||||
|
||||
assert len(hooks) == 2
|
||||
assert hook1 in hooks
|
||||
assert hook2 in hooks
|
||||
|
||||
def test_before_hook_can_block_execution(self, mock_tool):
|
||||
"""Test that before hooks can block tool execution."""
|
||||
def block_hook(context):
|
||||
if context.tool_name == "dangerous_tool":
|
||||
return False # Block execution
|
||||
return None # Allow execution
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="dangerous_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
result = block_hook(context)
|
||||
assert result is False
|
||||
|
||||
def test_before_hook_can_allow_execution(self, mock_tool):
|
||||
"""Test that before hooks can explicitly allow execution."""
|
||||
def allow_hook(context):
|
||||
return None # Allow execution
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="safe_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
result = allow_hook(context)
|
||||
assert result is None
|
||||
|
||||
def test_before_hook_can_modify_input(self, mock_tool):
|
||||
"""Test that before hooks can modify tool input in-place."""
|
||||
def modify_input_hook(context):
|
||||
context.tool_input["modified_by_hook"] = True
|
||||
return None
|
||||
|
||||
tool_input = {"arg1": "value1"}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
modify_input_hook(context)
|
||||
|
||||
assert "modified_by_hook" in context.tool_input
|
||||
assert context.tool_input["modified_by_hook"] is True
|
||||
|
||||
def test_get_before_hooks_returns_copy(self):
|
||||
"""Test that get_before_tool_call_hooks returns a copy."""
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_before_tool_call_hook(test_hook)
|
||||
hooks1 = get_before_tool_call_hooks()
|
||||
hooks2 = get_before_tool_call_hooks()
|
||||
|
||||
# They should be equal but not the same object
|
||||
assert hooks1 == hooks2
|
||||
assert hooks1 is not hooks2
|
||||
|
||||
|
||||
class TestAfterToolCallHooks:
|
||||
"""Test after_tool_call hook registration and execution."""
|
||||
|
||||
def test_register_after_hook(self):
|
||||
"""Test that after hooks are registered correctly."""
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_after_tool_call_hook(test_hook)
|
||||
hooks = get_after_tool_call_hooks()
|
||||
|
||||
assert len(hooks) == 1
|
||||
assert hooks[0] == test_hook
|
||||
|
||||
def test_multiple_after_hooks(self):
|
||||
"""Test that multiple after hooks can be registered."""
|
||||
def hook1(context):
|
||||
return None
|
||||
|
||||
def hook2(context):
|
||||
return None
|
||||
|
||||
register_after_tool_call_hook(hook1)
|
||||
register_after_tool_call_hook(hook2)
|
||||
hooks = get_after_tool_call_hooks()
|
||||
|
||||
assert len(hooks) == 2
|
||||
assert hook1 in hooks
|
||||
assert hook2 in hooks
|
||||
|
||||
def test_after_hook_can_modify_result(self, mock_tool):
|
||||
"""Test that after hooks can modify the tool result."""
|
||||
original_result = "Original result"
|
||||
|
||||
def modify_result_hook(context):
|
||||
if context.tool_result:
|
||||
return context.tool_result.replace("Original", "Modified")
|
||||
return None
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
tool_result=original_result,
|
||||
)
|
||||
|
||||
modified = modify_result_hook(context)
|
||||
assert modified == "Modified result"
|
||||
|
||||
def test_after_hook_returns_none_keeps_original(self, mock_tool):
|
||||
"""Test that returning None keeps the original result."""
|
||||
original_result = "Original result"
|
||||
|
||||
def no_change_hook(context):
|
||||
return None
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
tool_result=original_result,
|
||||
)
|
||||
|
||||
result = no_change_hook(context)
|
||||
|
||||
assert result is None
|
||||
assert context.tool_result == original_result
|
||||
|
||||
def test_get_after_hooks_returns_copy(self):
|
||||
"""Test that get_after_tool_call_hooks returns a copy."""
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_after_tool_call_hook(test_hook)
|
||||
hooks1 = get_after_tool_call_hooks()
|
||||
hooks2 = get_after_tool_call_hooks()
|
||||
|
||||
# They should be equal but not the same object
|
||||
assert hooks1 == hooks2
|
||||
assert hooks1 is not hooks2
|
||||
|
||||
|
||||
class TestToolHooksIntegration:
|
||||
"""Test integration scenarios with multiple hooks."""
|
||||
|
||||
def test_multiple_before_hooks_execute_in_order(self, mock_tool):
|
||||
"""Test that multiple before hooks execute in registration order."""
|
||||
execution_order = []
|
||||
|
||||
def hook1(context):
|
||||
execution_order.append(1)
|
||||
return None
|
||||
|
||||
def hook2(context):
|
||||
execution_order.append(2)
|
||||
return None
|
||||
|
||||
def hook3(context):
|
||||
execution_order.append(3)
|
||||
return None
|
||||
|
||||
register_before_tool_call_hook(hook1)
|
||||
register_before_tool_call_hook(hook2)
|
||||
register_before_tool_call_hook(hook3)
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
for hook in hooks:
|
||||
hook(context)
|
||||
|
||||
assert execution_order == [1, 2, 3]
|
||||
|
||||
def test_first_blocking_hook_stops_execution(self, mock_tool):
|
||||
"""Test that first hook returning False blocks execution."""
|
||||
execution_order = []
|
||||
|
||||
def hook1(context):
|
||||
execution_order.append(1)
|
||||
return None # Allow
|
||||
|
||||
def hook2(context):
|
||||
execution_order.append(2)
|
||||
return False # Block
|
||||
|
||||
def hook3(context):
|
||||
execution_order.append(3)
|
||||
return None # This shouldn't run
|
||||
|
||||
register_before_tool_call_hook(hook1)
|
||||
register_before_tool_call_hook(hook2)
|
||||
register_before_tool_call_hook(hook3)
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
hooks = get_before_tool_call_hooks()
|
||||
blocked = False
|
||||
for hook in hooks:
|
||||
result = hook(context)
|
||||
if result is False:
|
||||
blocked = True
|
||||
break
|
||||
|
||||
assert blocked is True
|
||||
assert execution_order == [1, 2] # hook3 didn't run
|
||||
|
||||
def test_multiple_after_hooks_chain_modifications(self, mock_tool):
|
||||
"""Test that multiple after hooks can chain modifications."""
|
||||
def hook1(context):
|
||||
if context.tool_result:
|
||||
return context.tool_result + " [hook1]"
|
||||
return None
|
||||
|
||||
def hook2(context):
|
||||
if context.tool_result:
|
||||
return context.tool_result + " [hook2]"
|
||||
return None
|
||||
|
||||
register_after_tool_call_hook(hook1)
|
||||
register_after_tool_call_hook(hook2)
|
||||
|
||||
tool_input = {}
|
||||
context = ToolCallHookContext(
|
||||
tool_name="test_tool",
|
||||
tool_input=tool_input,
|
||||
tool=mock_tool,
|
||||
tool_result="Original",
|
||||
)
|
||||
|
||||
hooks = get_after_tool_call_hooks()
|
||||
|
||||
# Simulate chaining (how it would be used in practice)
|
||||
result = context.tool_result
|
||||
for hook in hooks:
|
||||
# Update context for next hook
|
||||
context.tool_result = result
|
||||
modified = hook(context)
|
||||
if modified is not None:
|
||||
result = modified
|
||||
|
||||
assert result == "Original [hook1] [hook2]"
|
||||
|
||||
def test_hooks_with_validation_and_sanitization(self, mock_tool):
|
||||
"""Test a realistic scenario with validation and sanitization hooks."""
|
||||
# Validation hook (before)
|
||||
def validate_file_path(context):
|
||||
if context.tool_name == "write_file":
|
||||
file_path = context.tool_input.get("file_path", "")
|
||||
if ".env" in file_path:
|
||||
return False # Block sensitive files
|
||||
return None
|
||||
|
||||
# Sanitization hook (after)
|
||||
def sanitize_secrets(context):
|
||||
if context.tool_result and "SECRET_KEY" in context.tool_result:
|
||||
return context.tool_result.replace("SECRET_KEY=abc123", "SECRET_KEY=[REDACTED]")
|
||||
return None
|
||||
|
||||
register_before_tool_call_hook(validate_file_path)
|
||||
register_after_tool_call_hook(sanitize_secrets)
|
||||
|
||||
# Test blocking
|
||||
blocked_context = ToolCallHookContext(
|
||||
tool_name="write_file",
|
||||
tool_input={"file_path": ".env"},
|
||||
tool=mock_tool,
|
||||
)
|
||||
|
||||
before_hooks = get_before_tool_call_hooks()
|
||||
blocked = False
|
||||
for hook in before_hooks:
|
||||
if hook(blocked_context) is False:
|
||||
blocked = True
|
||||
break
|
||||
|
||||
assert blocked is True
|
||||
|
||||
# Test sanitization
|
||||
sanitize_context = ToolCallHookContext(
|
||||
tool_name="read_file",
|
||||
tool_input={"file_path": "config.txt"},
|
||||
tool=mock_tool,
|
||||
tool_result="Content: SECRET_KEY=abc123",
|
||||
)
|
||||
|
||||
after_hooks = get_after_tool_call_hooks()
|
||||
result = sanitize_context.tool_result
|
||||
for hook in after_hooks:
|
||||
sanitize_context.tool_result = result
|
||||
modified = hook(sanitize_context)
|
||||
if modified is not None:
|
||||
result = modified
|
||||
|
||||
assert "SECRET_KEY=[REDACTED]" in result
|
||||
assert "abc123" not in result
|
||||
|
||||
|
||||
def test_unregister_before_hook(self):
|
||||
"""Test that before hooks can be unregistered."""
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
register_before_tool_call_hook(test_hook)
|
||||
unregister_before_tool_call_hook(test_hook)
|
||||
hooks = get_before_tool_call_hooks()
|
||||
assert len(hooks) == 0
|
||||
|
||||
def test_unregister_after_hook(self):
|
||||
"""Test that after hooks can be unregistered."""
|
||||
def test_hook(context):
|
||||
return None
|
||||
|
||||
register_after_tool_call_hook(test_hook)
|
||||
unregister_after_tool_call_hook(test_hook)
|
||||
hooks = get_after_tool_call_hooks()
|
||||
assert len(hooks) == 0
|
||||
|
||||
def test_clear_all_tool_call_hooks(self):
|
||||
"""Test that all tool call hooks can be cleared."""
|
||||
def test_hook(context):
|
||||
pass
|
||||
|
||||
register_before_tool_call_hook(test_hook)
|
||||
register_after_tool_call_hook(test_hook)
|
||||
clear_all_tool_call_hooks()
|
||||
hooks = get_before_tool_call_hooks()
|
||||
assert len(hooks) == 0
|
||||
Reference in New Issue
Block a user