crewAI/lib/crewai/tests/hooks/test_llm_hooks.py
Lorenze Jay 528d812263 Lorenze/feat hooks (#3902)
* feat: implement LLM call hooks and enhance agent execution context

- Introduced LLM call hooks to allow modification of messages and responses during LLM interactions.
- Added support for before and after hooks in the CrewAgentExecutor, enabling dynamic adjustments to the execution flow.
- Created LLMCallHookContext for comprehensive access to the executor state, facilitating in-place modifications.
- Added validation for hook callables to ensure proper functionality.
- Enhanced tests for LLM hooks and tool hooks to verify their behavior and error handling capabilities.
- Updated LiteAgent and CrewAgentExecutor to accommodate the new crew context in their execution processes.

* fix verbose

* feat: introduce crew-scoped hook decorators and refactor hook registration

- Added decorators for before and after LLM and tool calls to enhance flexibility in modifying execution behavior.
- Implemented a centralized hook registration mechanism within CrewBase to automatically register crew-scoped hooks.
- Removed the obsolete base.py file as its functionality has been integrated into the new decorators and registration system.
- Enhanced tests for the new hook decorators to ensure proper registration and execution flow.
- Updated existing hook handling to accommodate the new decorator-based approach, improving code organization and maintainability.

* feat: enhance hook management with clear and unregister functions

- Introduced functions to unregister specific before and after hooks for both LLM and tool calls, improving flexibility in hook management.
- Added clear functions to remove all registered hooks of each type, facilitating easier state management and cleanup.
- Implemented a convenience function to clear all global hooks in one call, streamlining the process for testing and execution context resets.
- Enhanced tests to verify the functionality of unregistering and clearing hooks, ensuring robust behavior in various scenarios.
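The unregister and clear helpers described above can be pictured with a minimal, self-contained sketch. It does not import crewai: the registry list and the names `register_hook`, `unregister_hook`, `clear_hooks`, and `get_hooks` are hypothetical stand-ins, and the boolean and count return values are assumptions made for illustration, not the library's documented contract.

```python
from typing import Any, Callable, List

Hook = Callable[[Any], Any]

# Hypothetical module-level registry standing in for a global hook list.
_hooks: List[Hook] = []


def register_hook(hook: Hook) -> None:
    """Append a hook; hooks are kept in registration order."""
    _hooks.append(hook)


def unregister_hook(hook: Hook) -> bool:
    """Remove one hook; return True if it was registered (assumed contract)."""
    try:
        _hooks.remove(hook)
        return True
    except ValueError:
        return False


def clear_hooks() -> int:
    """Remove every hook and return how many were removed (assumed contract)."""
    count = len(_hooks)
    _hooks.clear()
    return count


def get_hooks() -> List[Hook]:
    """Return a copy so callers cannot mutate the registry by accident."""
    return _hooks.copy()
```

Returning a copy from the getter is the same property the tests in this file check with `hooks1 is not hooks2`.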

* refactor: enhance hook type management for LLM and tool hooks

- Updated hook type definitions to use generic protocols for better type safety and flexibility.
- Replaced Callable type annotations with specific BeforeLLMCallHookType and AfterLLMCallHookType for clarity.
- Improved the registration and retrieval functions for before and after hooks to align with the new type definitions.
- Enhanced the setup functions to handle hook execution results, allowing for blocking of LLM calls based on hook logic.
- Updated related tests to ensure proper functionality and type adherence across the hook management system.
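How a before hook's result could gate an LLM call can be sketched roughly as follows. This is an illustration under stated assumptions, not the CrewAgentExecutor code: `run_before_hooks`, `call_llm_with_hooks`, and the convention that a hook returning `False` blocks the call are all hypothetical choices made for this example.

```python
from typing import Callable, Dict, List, Optional

Message = Dict[str, str]
BeforeHook = Callable[[List[Message]], Optional[bool]]


def run_before_hooks(hooks: List[BeforeHook], messages: List[Message]) -> bool:
    """Run hooks in order; report False as soon as one hook returns False."""
    for hook in hooks:
        # A hook that returns None (no explicit return) lets execution continue.
        if hook(messages) is False:
            return False
    return True


def call_llm_with_hooks(
    hooks: List[BeforeHook],
    messages: List[Message],
    llm: Callable[[List[Message]], str],
) -> Optional[str]:
    """Call the LLM only if no before hook blocked it; None means blocked."""
    if not run_before_hooks(hooks, messages):
        return None
    return llm(messages)
```

In this sketch, hooks that ran before the blocking hook have already modified `messages` in place, so a caller that cares about that would need to copy the list first.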

* feat: add execution and tool hooks documentation

- Introduced new documentation for execution hooks, LLM call hooks, and tool call hooks to provide comprehensive guidance on their usage and implementation in CrewAI.
- Updated existing documentation to include references to the new hooks, enhancing the learning resources available for users.
- Ensured consistency across multiple languages (English, Portuguese, Korean) for the new documentation, improving accessibility for a wider audience.
- Added examples and troubleshooting sections to assist users in effectively utilizing hooks for agent operations.

---------

Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
2025-11-13 10:11:50 -08:00

312 lines
9.7 KiB
Python

"""Unit tests for LLM hooks functionality."""

from __future__ import annotations

from unittest.mock import Mock

import pytest

from crewai.hooks import (
    clear_all_llm_call_hooks,
    unregister_after_llm_call_hook,
    unregister_before_llm_call_hook,
)
from crewai.hooks.llm_hooks import (
    LLMCallHookContext,
    get_after_llm_call_hooks,
    get_before_llm_call_hooks,
    register_after_llm_call_hook,
    register_before_llm_call_hook,
)


@pytest.fixture
def mock_executor():
    """Create a mock executor for testing."""
    executor = Mock()
    executor.messages = [{"role": "system", "content": "Test message"}]
    executor.agent = Mock(role="Test Agent")
    executor.task = Mock(description="Test Task")
    executor.crew = Mock()
    executor.llm = Mock()
    executor.iterations = 0
    return executor


@pytest.fixture(autouse=True)
def clear_hooks():
    """Clear global hooks before and after each test."""
    # Import the private variables to clear them
    from crewai.hooks import llm_hooks

    # Store original hooks
    original_before = llm_hooks._before_llm_call_hooks.copy()
    original_after = llm_hooks._after_llm_call_hooks.copy()

    # Clear hooks
    llm_hooks._before_llm_call_hooks.clear()
    llm_hooks._after_llm_call_hooks.clear()

    yield

    # Restore original hooks
    llm_hooks._before_llm_call_hooks.clear()
    llm_hooks._after_llm_call_hooks.clear()
    llm_hooks._before_llm_call_hooks.extend(original_before)
    llm_hooks._after_llm_call_hooks.extend(original_after)


class TestLLMCallHookContext:
    """Test LLMCallHookContext initialization and attributes."""

    def test_context_initialization(self, mock_executor):
        """Test that context is initialized correctly with executor."""
        context = LLMCallHookContext(executor=mock_executor)

        assert context.executor == mock_executor
        assert context.messages == mock_executor.messages
        assert context.agent == mock_executor.agent
        assert context.task == mock_executor.task
        assert context.crew == mock_executor.crew
        assert context.llm == mock_executor.llm
        assert context.iterations == mock_executor.iterations
        assert context.response is None

    def test_context_with_response(self, mock_executor):
        """Test that context includes response when provided."""
        test_response = "Test LLM response"
        context = LLMCallHookContext(executor=mock_executor, response=test_response)

        assert context.response == test_response

    def test_messages_are_mutable_reference(self, mock_executor):
        """Test that modifying context.messages modifies executor.messages."""
        context = LLMCallHookContext(executor=mock_executor)

        # Add a message through context
        new_message = {"role": "user", "content": "New message"}
        context.messages.append(new_message)

        # Check that executor.messages is also modified
        assert new_message in mock_executor.messages
        assert len(mock_executor.messages) == 2


class TestBeforeLLMCallHooks:
    """Test before_llm_call hook registration and execution."""

    def test_register_before_hook(self):
        """Test that before hooks are registered correctly."""

        def test_hook(context):
            pass

        register_before_llm_call_hook(test_hook)
        hooks = get_before_llm_call_hooks()

        assert len(hooks) == 1
        assert hooks[0] == test_hook

    def test_multiple_before_hooks(self):
        """Test that multiple before hooks can be registered."""

        def hook1(context):
            pass

        def hook2(context):
            pass

        register_before_llm_call_hook(hook1)
        register_before_llm_call_hook(hook2)
        hooks = get_before_llm_call_hooks()

        assert len(hooks) == 2
        assert hook1 in hooks
        assert hook2 in hooks

    def test_before_hook_can_modify_messages(self, mock_executor):
        """Test that before hooks can modify messages in-place."""

        def add_message_hook(context):
            context.messages.append({"role": "system", "content": "Added by hook"})

        context = LLMCallHookContext(executor=mock_executor)
        add_message_hook(context)

        assert len(context.messages) == 2
        assert context.messages[1]["content"] == "Added by hook"

    def test_get_before_hooks_returns_copy(self):
        """Test that get_before_llm_call_hooks returns a copy."""

        def test_hook(context):
            pass

        register_before_llm_call_hook(test_hook)
        hooks1 = get_before_llm_call_hooks()
        hooks2 = get_before_llm_call_hooks()

        # They should be equal but not the same object
        assert hooks1 == hooks2
        assert hooks1 is not hooks2


class TestAfterLLMCallHooks:
    """Test after_llm_call hook registration and execution."""

    def test_register_after_hook(self):
        """Test that after hooks are registered correctly."""

        def test_hook(context):
            return None

        register_after_llm_call_hook(test_hook)
        hooks = get_after_llm_call_hooks()

        assert len(hooks) == 1
        assert hooks[0] == test_hook

    def test_multiple_after_hooks(self):
        """Test that multiple after hooks can be registered."""

        def hook1(context):
            return None

        def hook2(context):
            return None

        register_after_llm_call_hook(hook1)
        register_after_llm_call_hook(hook2)
        hooks = get_after_llm_call_hooks()

        assert len(hooks) == 2
        assert hook1 in hooks
        assert hook2 in hooks

    def test_after_hook_can_modify_response(self, mock_executor):
        """Test that after hooks can modify the response."""
        original_response = "Original response"

        def modify_response_hook(context):
            if context.response:
                return context.response.replace("Original", "Modified")
            return None

        context = LLMCallHookContext(executor=mock_executor, response=original_response)
        modified = modify_response_hook(context)

        assert modified == "Modified response"

    def test_after_hook_returns_none_keeps_original(self, mock_executor):
        """Test that returning None keeps the original response."""
        original_response = "Original response"

        def no_change_hook(context):
            return None

        context = LLMCallHookContext(executor=mock_executor, response=original_response)
        result = no_change_hook(context)

        assert result is None
        assert context.response == original_response

    def test_get_after_hooks_returns_copy(self):
        """Test that get_after_llm_call_hooks returns a copy."""

        def test_hook(context):
            return None

        register_after_llm_call_hook(test_hook)
        hooks1 = get_after_llm_call_hooks()
        hooks2 = get_after_llm_call_hooks()

        # They should be equal but not the same object
        assert hooks1 == hooks2
        assert hooks1 is not hooks2


class TestLLMHooksIntegration:
    """Test integration scenarios with multiple hooks."""

    def test_multiple_before_hooks_execute_in_order(self, mock_executor):
        """Test that multiple before hooks execute in registration order."""
        execution_order = []

        def hook1(context):
            execution_order.append(1)

        def hook2(context):
            execution_order.append(2)

        def hook3(context):
            execution_order.append(3)

        register_before_llm_call_hook(hook1)
        register_before_llm_call_hook(hook2)
        register_before_llm_call_hook(hook3)

        context = LLMCallHookContext(executor=mock_executor)
        hooks = get_before_llm_call_hooks()

        for hook in hooks:
            hook(context)

        assert execution_order == [1, 2, 3]

    def test_multiple_after_hooks_chain_modifications(self, mock_executor):
        """Test that multiple after hooks can chain modifications."""

        def hook1(context):
            if context.response:
                return context.response + " [hook1]"
            return None

        def hook2(context):
            if context.response:
                return context.response + " [hook2]"
            return None

        register_after_llm_call_hook(hook1)
        register_after_llm_call_hook(hook2)

        context = LLMCallHookContext(executor=mock_executor, response="Original")
        hooks = get_after_llm_call_hooks()

        # Simulate chaining (how it would be used in practice)
        result = context.response
        for hook in hooks:
            # Update context for next hook
            context.response = result
            modified = hook(context)
            if modified is not None:
                result = modified

        assert result == "Original [hook1] [hook2]"

    def test_unregister_before_hook(self):
        """Test that before hooks can be unregistered."""

        def test_hook(context):
            pass

        register_before_llm_call_hook(test_hook)
        unregister_before_llm_call_hook(test_hook)
        hooks = get_before_llm_call_hooks()

        assert len(hooks) == 0

    def test_unregister_after_hook(self):
        """Test that after hooks can be unregistered."""

        def test_hook(context):
            return None

        register_after_llm_call_hook(test_hook)
        unregister_after_llm_call_hook(test_hook)
        hooks = get_after_llm_call_hooks()

        assert len(hooks) == 0

    def test_clear_all_llm_call_hooks(self):
        """Test that all llm call hooks can be cleared."""

        def test_hook(context):
            pass

        register_before_llm_call_hook(test_hook)
        register_after_llm_call_hook(test_hook)
        clear_all_llm_call_hooks()

        hooks = get_before_llm_call_hooks()
        assert len(hooks) == 0
        # The after-hook registry should be emptied as well.
        assert len(get_after_llm_call_hooks()) == 0