"""Unit tests for LLM hooks functionality."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from unittest.mock import Mock
|
|
|
|
from crewai.hooks import (
|
|
LLMCallBlockedError,
|
|
clear_all_llm_call_hooks,
|
|
unregister_after_llm_call_hook,
|
|
unregister_before_llm_call_hook,
|
|
)
|
|
import pytest
|
|
|
|
from crewai.hooks.llm_hooks import (
|
|
LLMCallHookContext,
|
|
get_after_llm_call_hooks,
|
|
get_before_llm_call_hooks,
|
|
register_after_llm_call_hook,
|
|
register_before_llm_call_hook,
|
|
)
|
|
|
|
|
|
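
# Hook contract, as exercised by the tests below:
#   - A before hook receives an LLMCallHookContext and returns a bool.
#     Returning False blocks the call, which surfaces as LLMCallBlockedError.
#   - An after hook receives the context with `response` populated and
#     returns either a replacement string or None to keep the original.
#   - `context.messages` aliases the executor's message list, so in-place
#     mutation is visible to the executor.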


@pytest.fixture
def mock_executor():
    """Create a mock executor for testing."""
    executor = Mock()
    executor.messages = [{"role": "system", "content": "Test message"}]
    executor.agent = Mock(role="Test Agent")
    executor.task = Mock(description="Test Task")
    executor.crew = Mock()
    executor.llm = Mock()
    executor.iterations = 0
    return executor


@pytest.fixture(autouse=True)
def clear_hooks():
    """Isolate the global hook registry: clear it before each test, restore it after."""
    # Snapshot the private module lists directly so the suite's pre-test
    # state can be restored once the test finishes.
    from crewai.hooks import llm_hooks

    # Store original hooks
    original_before = llm_hooks._before_llm_call_hooks.copy()
    original_after = llm_hooks._after_llm_call_hooks.copy()

    # Clear hooks
    llm_hooks._before_llm_call_hooks.clear()
    llm_hooks._after_llm_call_hooks.clear()

    yield

    # Restore original hooks
    llm_hooks._before_llm_call_hooks.clear()
    llm_hooks._after_llm_call_hooks.clear()
    llm_hooks._before_llm_call_hooks.extend(original_before)
    llm_hooks._after_llm_call_hooks.extend(original_after)


class TestLLMCallHookContext:
    """Test LLMCallHookContext initialization and attributes."""

    def test_context_initialization(self, mock_executor):
        """Test that context is initialized correctly with executor."""
        context = LLMCallHookContext(executor=mock_executor)

        assert context.executor == mock_executor
        assert context.messages == mock_executor.messages
        assert context.agent == mock_executor.agent
        assert context.task == mock_executor.task
        assert context.crew == mock_executor.crew
        assert context.llm == mock_executor.llm
        assert context.iterations == mock_executor.iterations
        assert context.response is None

    def test_context_with_response(self, mock_executor):
        """Test that context includes response when provided."""
        test_response = "Test LLM response"
        context = LLMCallHookContext(executor=mock_executor, response=test_response)

        assert context.response == test_response

    def test_messages_are_mutable_reference(self, mock_executor):
        """Test that modifying context.messages modifies executor.messages."""
        context = LLMCallHookContext(executor=mock_executor)

        # Add a message through context
        new_message = {"role": "user", "content": "New message"}
        context.messages.append(new_message)

        # Check that executor.messages is also modified
        assert new_message in mock_executor.messages
        assert len(mock_executor.messages) == 2
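
    # Because context.messages aliases executor.messages (verified above), a
    # before hook can steer the conversation in place. A minimal sketch,
    # assuming the same contract the registration tests below exercise:
    #
    #     def inject_policy(context: LLMCallHookContext) -> bool:
    #         context.messages.append(
    #             {"role": "system", "content": "Answer in one sentence."}
    #         )
    #         return True  # allow the call to proceed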

    def test_before_hook_returning_false_raises_blocked_error(self) -> None:
        """Test that kickoff raises LLMCallBlockedError when a before hook returns False."""
        from crewai import Agent, Crew, Task

        hook_called = {"before": False}

        def blocking_hook(context: LLMCallHookContext) -> bool:
            """Hook that blocks all LLM calls."""
            hook_called["before"] = True
            return False

        register_before_llm_call_hook(blocking_hook)

        try:
            agent = Agent(
                role="Test Agent",
                goal="Answer questions",
                backstory="You are a test agent",
                verbose=True,
            )

            task = Task(
                description="Say hello",
                expected_output="A greeting",
                agent=agent,
            )

            with pytest.raises(LLMCallBlockedError):
                crew = Crew(agents=[agent], tasks=[task], verbose=True)
                crew.kickoff()

            assert hook_called["before"] is True, "Before hook should have been called"
        finally:
            unregister_before_llm_call_hook(blocking_hook)

    def test_direct_llm_call_raises_blocked_error_when_hook_returns_false(self) -> None:
        """Test that direct LLM.call() raises LLMCallBlockedError when hook returns False."""
        from crewai.llm import LLM

        hook_called = {"before": False}

        def blocking_hook(context: LLMCallHookContext) -> bool:
            """Hook that blocks all LLM calls."""
            hook_called["before"] = True
            return False

        register_before_llm_call_hook(blocking_hook)

        try:
            llm = LLM(model="gpt-4o-mini")

            with pytest.raises(LLMCallBlockedError) as exc_info:
                llm.call([{"role": "user", "content": "Say hello"}])

            assert hook_called["before"] is True, "Before hook should have been called"
            assert "blocked" in str(exc_info.value).lower()
        finally:
            unregister_before_llm_call_hook(blocking_hook)

    def test_raises_with_llm_call_blocked_exception(self) -> None:
        """Test that an exception raised inside a before hook propagates from LLM.call()."""
        from crewai.llm import LLM

        def blocking_hook(context: LLMCallHookContext) -> bool:
            raise LLMCallBlockedError("llm call blocked")

        register_before_llm_call_hook(blocking_hook)

        try:
            llm = LLM(model="gpt-4o-mini")

            with pytest.raises(LLMCallBlockedError) as exc_info:
                llm.call([{"role": "user", "content": "Say hello"}])

            assert "blocked" in str(exc_info.value).lower()
        finally:
            unregister_before_llm_call_hook(blocking_hook)
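
    # In practice a blocking hook is usually conditional. A sketch, under the
    # same contract (returning False surfaces as LLMCallBlockedError); the
    # budget figure is illustrative, not part of the API:
    #
    #     def budget_hook(context: LLMCallHookContext) -> bool:
    #         total_chars = sum(len(str(m.get("content", ""))) for m in context.messages)
    #         return total_chars <= 10_000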


class TestBeforeLLMCallHooks:
    """Test before_llm_call hook registration and execution."""

    def test_register_before_hook(self):
        """Test that before hooks are registered correctly."""

        def test_hook(context):
            pass

        register_before_llm_call_hook(test_hook)
        hooks = get_before_llm_call_hooks()

        assert len(hooks) == 1
        assert hooks[0] == test_hook

    def test_multiple_before_hooks(self):
        """Test that multiple before hooks can be registered."""

        def hook1(context):
            pass

        def hook2(context):
            pass

        register_before_llm_call_hook(hook1)
        register_before_llm_call_hook(hook2)
        hooks = get_before_llm_call_hooks()

        assert len(hooks) == 2
        assert hook1 in hooks
        assert hook2 in hooks

    def test_before_hook_can_modify_messages(self, mock_executor):
        """Test that before hooks can modify messages in place."""

        def add_message_hook(context):
            context.messages.append({"role": "system", "content": "Added by hook"})

        context = LLMCallHookContext(executor=mock_executor)
        add_message_hook(context)

        assert len(context.messages) == 2
        assert context.messages[1]["content"] == "Added by hook"

    def test_get_before_hooks_returns_copy(self):
        """Test that get_before_llm_call_hooks returns a copy."""

        def test_hook(context):
            pass

        register_before_llm_call_hook(test_hook)
        hooks1 = get_before_llm_call_hooks()
        hooks2 = get_before_llm_call_hooks()

        # They should be equal but not the same object
        assert hooks1 == hooks2
        assert hooks1 is not hooks2


class TestAfterLLMCallHooks:
    """Test after_llm_call hook registration and execution."""

    def test_register_after_hook(self):
        """Test that after hooks are registered correctly."""

        def test_hook(context):
            return None

        register_after_llm_call_hook(test_hook)
        hooks = get_after_llm_call_hooks()

        assert len(hooks) == 1
        assert hooks[0] == test_hook

    def test_multiple_after_hooks(self):
        """Test that multiple after hooks can be registered."""

        def hook1(context):
            return None

        def hook2(context):
            return None

        register_after_llm_call_hook(hook1)
        register_after_llm_call_hook(hook2)
        hooks = get_after_llm_call_hooks()

        assert len(hooks) == 2
        assert hook1 in hooks
        assert hook2 in hooks

    def test_after_hook_can_modify_response(self, mock_executor):
        """Test that after hooks can modify the response."""
        original_response = "Original response"

        def modify_response_hook(context):
            if context.response:
                return context.response.replace("Original", "Modified")
            return None

        context = LLMCallHookContext(executor=mock_executor, response=original_response)
        modified = modify_response_hook(context)

        assert modified == "Modified response"
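
    # A practical after hook often rewrites rather than replaces. A sketch,
    # assuming the same return-string-or-None contract; the regex is
    # illustrative only:
    #
    #     import re
    #
    #     def redact_keys(context: LLMCallHookContext) -> str | None:
    #         if context.response and "sk-" in context.response:
    #             return re.sub(r"sk-\w+", "[REDACTED]", context.response)
    #         return None  # keep the original response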

    def test_after_hook_returns_none_keeps_original(self, mock_executor):
        """Test that returning None keeps the original response."""
        original_response = "Original response"

        def no_change_hook(context):
            return None

        context = LLMCallHookContext(executor=mock_executor, response=original_response)
        result = no_change_hook(context)

        assert result is None
        assert context.response == original_response

    def test_get_after_hooks_returns_copy(self):
        """Test that get_after_llm_call_hooks returns a copy."""

        def test_hook(context):
            return None

        register_after_llm_call_hook(test_hook)
        hooks1 = get_after_llm_call_hooks()
        hooks2 = get_after_llm_call_hooks()

        # They should be equal but not the same object
        assert hooks1 == hooks2
        assert hooks1 is not hooks2


class TestLLMHooksIntegration:
    """Test integration scenarios with multiple hooks."""

    def test_multiple_before_hooks_execute_in_order(self, mock_executor):
        """Test that multiple before hooks execute in registration order."""
        execution_order = []

        def hook1(context):
            execution_order.append(1)

        def hook2(context):
            execution_order.append(2)

        def hook3(context):
            execution_order.append(3)

        register_before_llm_call_hook(hook1)
        register_before_llm_call_hook(hook2)
        register_before_llm_call_hook(hook3)

        context = LLMCallHookContext(executor=mock_executor)
        hooks = get_before_llm_call_hooks()

        for hook in hooks:
            hook(context)

        assert execution_order == [1, 2, 3]

    def test_multiple_after_hooks_chain_modifications(self, mock_executor):
        """Test that multiple after hooks can chain modifications."""

        def hook1(context):
            if context.response:
                return context.response + " [hook1]"
            return None

        def hook2(context):
            if context.response:
                return context.response + " [hook2]"
            return None

        register_after_llm_call_hook(hook1)
        register_after_llm_call_hook(hook2)

        context = LLMCallHookContext(executor=mock_executor, response="Original")
        hooks = get_after_llm_call_hooks()

        # Simulate chaining (how it would be used in practice)
        result = context.response
        for hook in hooks:
            # Update context for next hook
            context.response = result
            modified = hook(context)
            if modified is not None:
                result = modified

        assert result == "Original [hook1] [hook2]"
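
    # The loop above is the composition rule in miniature: thread the latest
    # non-None return back into context.response before each subsequent hook
    # runs. As a standalone fold (a sketch, not part of the public API):
    #
    #     def apply_after_hooks(context: LLMCallHookContext) -> str | None:
    #         result = context.response
    #         for hook in get_after_llm_call_hooks():
    #             context.response = result
    #             modified = hook(context)
    #             if modified is not None:
    #                 result = modified
    #         return result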

    def test_unregister_before_hook(self):
        """Test that before hooks can be unregistered."""

        def test_hook(context):
            pass

        register_before_llm_call_hook(test_hook)
        unregister_before_llm_call_hook(test_hook)
        hooks = get_before_llm_call_hooks()
        assert len(hooks) == 0

    def test_unregister_after_hook(self):
        """Test that after hooks can be unregistered."""

        def test_hook(context):
            return None

        register_after_llm_call_hook(test_hook)
        unregister_after_llm_call_hook(test_hook)
        hooks = get_after_llm_call_hooks()
        assert len(hooks) == 0

    def test_clear_all_llm_call_hooks(self):
        """Test that all LLM call hooks can be cleared."""

        def test_hook(context):
            pass

        register_before_llm_call_hook(test_hook)
        register_after_llm_call_hook(test_hook)
        clear_all_llm_call_hooks()
        assert len(get_before_llm_call_hooks()) == 0
        assert len(get_after_llm_call_hooks()) == 0

    @pytest.mark.vcr()
    def test_lite_agent_hooks_integration_with_real_llm(self):
        """Test that LiteAgent runs before/after LLM call hooks and surfaces the after hook's modification."""
        import os

        from crewai.lite_agent import LiteAgent

        # Skip if no API key available
        if not os.environ.get("OPENAI_API_KEY"):
            pytest.skip("OPENAI_API_KEY not set - skipping real LLM test")

        # Track hook invocations
        hook_calls = {"before": [], "after": []}

        def before_llm_call_hook(context: LLMCallHookContext) -> bool:
            """Log and verify before hook execution."""
            print(f"\n[BEFORE HOOK] Agent: {context.agent.role if context.agent else 'None'}")
            print(f"[BEFORE HOOK] Iterations: {context.iterations}")
            print(f"[BEFORE HOOK] Message count: {len(context.messages)}")
            print(f"[BEFORE HOOK] Messages: {context.messages}")

            # Track the call
            hook_calls["before"].append({
                "iterations": context.iterations,
                "message_count": len(context.messages),
                "has_task": context.task is not None,
                "has_crew": context.crew is not None,
            })

            return True  # Allow execution

        def after_llm_call_hook(context: LLMCallHookContext) -> str | None:
            """Log and verify after hook execution."""
            print(f"\n[AFTER HOOK] Agent: {context.agent.role if context.agent else 'None'}")
            print(f"[AFTER HOOK] Iterations: {context.iterations}")
            print(f"[AFTER HOOK] Response: {context.response[:100] if context.response else 'None'}...")
            print(f"[AFTER HOOK] Final message count: {len(context.messages)}")

            # Track the call
            hook_calls["after"].append({
                "iterations": context.iterations,
                "has_response": context.response is not None,
                "response_length": len(context.response) if context.response else 0,
            })

            # Optionally modify response
            if context.response:
                return f"[HOOKED] {context.response}"
            return None

        # Register hooks
        register_before_llm_call_hook(before_llm_call_hook)
        register_after_llm_call_hook(after_llm_call_hook)

        try:
            # Create LiteAgent
            lite_agent = LiteAgent(
                role="Test Assistant",
                goal="Answer questions briefly",
                backstory="You are a helpful test assistant",
                verbose=True,
            )

            # Verify hooks are loaded
            assert len(lite_agent.before_llm_call_hooks) > 0, "Before hooks not loaded"
            assert len(lite_agent.after_llm_call_hooks) > 0, "After hooks not loaded"

            # Execute with a simple prompt
            result = lite_agent.kickoff("Say 'Hello World' and nothing else")

            # Verify hooks were called
            assert len(hook_calls["before"]) > 0, "Before hook was never called"
            assert len(hook_calls["after"]) > 0, "After hook was never called"

            # Verify context had correct attributes for LiteAgent (used in flows).
            # LiteAgent doesn't have task/crew context, unlike agents in CrewBase.
            before_call = hook_calls["before"][0]
            assert before_call["has_task"] is False, "Task should be None for LiteAgent in flows"
            assert before_call["has_crew"] is False, "Crew should be None for LiteAgent in flows"
            assert before_call["message_count"] > 0, "Should have messages"

            # Verify after hook received response
            after_call = hook_calls["after"][0]
            assert after_call["has_response"] is True, "After hook should have response"
            assert after_call["response_length"] > 0, "Response should not be empty"

            # Verify the response was modified by the after hook.
            # Note: the hook modifies the raw LLM response, but LiteAgent then parses
            # it to extract the "Final Answer" portion, so we check the messages to
            # see the modification.
            assert len(result.messages) > 2, "Should have assistant message in messages"
            last_message = result.messages[-1]
            assert last_message["role"] == "assistant", "Last message should be from assistant"
            assert "[HOOKED]" in last_message["content"], "Hook should have modified the assistant message"
        finally:
            # Clean up hooks
            unregister_before_llm_call_hook(before_llm_call_hook)
            unregister_after_llm_call_hook(after_llm_call_hook)

    @pytest.mark.vcr()
    def test_direct_llm_call_hooks_integration(self):
        """Test that hooks work for direct llm.call() without agents."""
        import os

        from crewai.llm import LLM

        # Skip if no API key available
        if not os.environ.get("OPENAI_API_KEY"):
            pytest.skip("OPENAI_API_KEY not set - skipping real LLM test")

        # Track hook invocations
        hook_calls = {"before": [], "after": []}

        def before_hook(context: LLMCallHookContext) -> bool:
            """Log and verify before hook execution."""
            print(f"\n[BEFORE HOOK] Agent: {context.agent}")
            print(f"[BEFORE HOOK] Task: {context.task}")
            print(f"[BEFORE HOOK] Crew: {context.crew}")
            print(f"[BEFORE HOOK] LLM: {context.llm}")
            print(f"[BEFORE HOOK] Iterations: {context.iterations}")
            print(f"[BEFORE HOOK] Message count: {len(context.messages)}")

            # Track the call
            hook_calls["before"].append({
                "agent": context.agent,
                "task": context.task,
                "crew": context.crew,
                "llm": context.llm is not None,
                "message_count": len(context.messages),
            })

            return True  # Allow execution

        def after_hook(context: LLMCallHookContext) -> str | None:
            """Log and verify after hook execution."""
            print(f"\n[AFTER HOOK] Agent: {context.agent}")
            print(f"[AFTER HOOK] Response: {context.response[:100] if context.response else 'None'}...")

            # Track the call
            hook_calls["after"].append({
                "has_response": context.response is not None,
                "response_length": len(context.response) if context.response else 0,
            })

            # Modify response
            if context.response:
                return f"[HOOKED] {context.response}"
            return None

        # Register hooks
        register_before_llm_call_hook(before_hook)
        register_after_llm_call_hook(after_hook)

        try:
            # Create LLM and make direct call
            llm = LLM(model="gpt-4o-mini")
            result = llm.call([{"role": "user", "content": "Say hello"}])

            print(f"\n[TEST] Final result: {result}")

            # Verify hooks were called
            assert len(hook_calls["before"]) > 0, "Before hook was never called"
            assert len(hook_calls["after"]) > 0, "After hook was never called"

            # Verify context had correct attributes for direct LLM calls
            before_call = hook_calls["before"][0]
            assert before_call["agent"] is None, "Agent should be None for direct LLM calls"
            assert before_call["task"] is None, "Task should be None for direct LLM calls"
            assert before_call["crew"] is None, "Crew should be None for direct LLM calls"
            assert before_call["llm"] is True, "LLM should be present"
            assert before_call["message_count"] > 0, "Should have messages"

            # Verify after hook received response
            after_call = hook_calls["after"][0]
            assert after_call["has_response"] is True, "After hook should have response"
            assert after_call["response_length"] > 0, "Response should not be empty"

            # Verify response was modified by after hook
            assert "[HOOKED]" in result, "Response should be modified by after hook"
        finally:
            # Clean up hooks
            unregister_before_llm_call_hook(before_hook)
            unregister_after_llm_call_hook(after_hook)