mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
feat: implement before and after LLM call hooks in CrewAgentExecutor (#3893)
- Added support for before and after LLM call hooks to allow modification of messages and responses during LLM interactions. - Introduced LLMCallHookContext to provide hooks with access to the executor state, enabling in-place modifications of messages. - Updated get_llm_response function to utilize the new hooks, ensuring that modifications persist across iterations. - Enhanced tests to verify the functionality of the hooks and their error handling capabilities, ensuring robust execution flow.
This commit is contained in:
@@ -2714,3 +2714,293 @@ def test_agent_without_apps_no_platform_tools():
|
||||
|
||||
tools = crew._prepare_tools(agent, task, [])
|
||||
assert tools == []
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_before_llm_call_hook_modifies_messages():
|
||||
"""Test that before_llm_call hooks can modify messages."""
|
||||
from crewai.utilities.llm_call_hooks import LLMCallHookContext, register_before_llm_call_hook
|
||||
|
||||
hook_called = False
|
||||
original_message_count = 0
|
||||
|
||||
def before_hook(context: LLMCallHookContext) -> None:
|
||||
nonlocal hook_called, original_message_count
|
||||
hook_called = True
|
||||
original_message_count = len(context.messages)
|
||||
context.messages.append({
|
||||
"role": "user",
|
||||
"content": "Additional context: This is a test modification."
|
||||
})
|
||||
|
||||
register_before_llm_call_hook(before_hook)
|
||||
|
||||
try:
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Say hello",
|
||||
expected_output="A greeting",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
assert hook_called, "before_llm_call hook should have been called"
|
||||
assert len(agent.agent_executor.messages) > original_message_count
|
||||
assert result is not None
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_after_llm_call_hook_modifies_messages_for_next_iteration():
|
||||
"""Test that after_llm_call hooks can modify messages for the next iteration."""
|
||||
from crewai.utilities.llm_call_hooks import LLMCallHookContext, register_after_llm_call_hook
|
||||
|
||||
hook_call_count = 0
|
||||
hook_iterations = []
|
||||
messages_added_in_iteration_0 = False
|
||||
test_message_content = "HOOK_ADDED_MESSAGE_FOR_NEXT_ITERATION"
|
||||
|
||||
def after_hook(context: LLMCallHookContext) -> str | None:
|
||||
nonlocal hook_call_count, hook_iterations, messages_added_in_iteration_0
|
||||
hook_call_count += 1
|
||||
current_iteration = context.iterations
|
||||
hook_iterations.append(current_iteration)
|
||||
|
||||
if current_iteration == 0:
|
||||
messages_before = len(context.messages)
|
||||
context.messages.append({
|
||||
"role": "user",
|
||||
"content": test_message_content
|
||||
})
|
||||
messages_added_in_iteration_0 = True
|
||||
assert len(context.messages) == messages_before + 1
|
||||
|
||||
return None
|
||||
|
||||
register_after_llm_call_hook(after_hook)
|
||||
|
||||
try:
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
allow_delegation=False,
|
||||
max_iter=3,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Count to 3, taking your time",
|
||||
expected_output="A count",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
assert hook_call_count > 0, "after_llm_call hook should have been called"
|
||||
assert messages_added_in_iteration_0, "Message should have been added in iteration 0"
|
||||
|
||||
executor_messages = agent.agent_executor.messages
|
||||
message_contents = [msg.get("content", "") for msg in executor_messages if isinstance(msg, dict)]
|
||||
assert any(test_message_content in content for content in message_contents), (
|
||||
f"Message added by hook in iteration 0 should be present in executor messages. "
|
||||
f"Messages: {message_contents}"
|
||||
)
|
||||
|
||||
assert len(executor_messages) > 2, "Executor should have more than initial messages"
|
||||
assert result is not None
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_after_llm_call_hook_modifies_messages():
|
||||
"""Test that after_llm_call hooks can modify messages for next iteration."""
|
||||
from crewai.utilities.llm_call_hooks import LLMCallHookContext, register_after_llm_call_hook
|
||||
|
||||
hook_called = False
|
||||
messages_before_hook = 0
|
||||
|
||||
def after_hook(context: LLMCallHookContext) -> str | None:
|
||||
nonlocal hook_called, messages_before_hook
|
||||
hook_called = True
|
||||
messages_before_hook = len(context.messages)
|
||||
context.messages.append({
|
||||
"role": "user",
|
||||
"content": "Remember: This is iteration 2 context."
|
||||
})
|
||||
return None # Don't modify response
|
||||
|
||||
register_after_llm_call_hook(after_hook)
|
||||
|
||||
try:
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
allow_delegation=False,
|
||||
max_iter=2,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Count to 2",
|
||||
expected_output="A count",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
assert hook_called, "after_llm_call hook should have been called"
|
||||
assert len(agent.agent_executor.messages) > messages_before_hook
|
||||
assert result is not None
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_llm_call_hooks_with_crew():
|
||||
"""Test that LLM call hooks work with crew execution."""
|
||||
from crewai.utilities.llm_call_hooks import (
|
||||
LLMCallHookContext,
|
||||
register_after_llm_call_hook,
|
||||
register_before_llm_call_hook,
|
||||
)
|
||||
|
||||
before_hook_called = False
|
||||
after_hook_called = False
|
||||
|
||||
def before_hook(context: LLMCallHookContext) -> None:
|
||||
nonlocal before_hook_called
|
||||
before_hook_called = True
|
||||
assert context.executor is not None
|
||||
assert context.agent is not None
|
||||
assert context.task is not None
|
||||
context.messages.append({
|
||||
"role": "system",
|
||||
"content": "Additional system context from hook."
|
||||
})
|
||||
|
||||
def after_hook(context: LLMCallHookContext) -> str | None:
|
||||
nonlocal after_hook_called
|
||||
after_hook_called = True
|
||||
assert context.response is not None
|
||||
assert len(context.messages) > 0
|
||||
return None
|
||||
|
||||
register_before_llm_call_hook(before_hook)
|
||||
register_after_llm_call_hook(after_hook)
|
||||
|
||||
try:
|
||||
agent = Agent(
|
||||
role="Researcher",
|
||||
goal="Research topics",
|
||||
backstory="You are a researcher",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Research AI frameworks",
|
||||
expected_output="A research summary",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
assert before_hook_called, "before_llm_call hook should have been called"
|
||||
assert after_hook_called, "after_llm_call hook should have been called"
|
||||
assert result is not None
|
||||
assert result.raw is not None
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_llm_call_hooks_can_modify_executor_attributes():
|
||||
"""Test that hooks can access and modify executor attributes like tools."""
|
||||
from crewai.utilities.llm_call_hooks import LLMCallHookContext, register_before_llm_call_hook
|
||||
from crewai.tools import tool
|
||||
|
||||
@tool
|
||||
def test_tool() -> str:
|
||||
"""A test tool."""
|
||||
return "test result"
|
||||
|
||||
hook_called = False
|
||||
original_tools_count = 0
|
||||
|
||||
def before_hook(context: LLMCallHookContext) -> None:
|
||||
nonlocal hook_called, original_tools_count
|
||||
hook_called = True
|
||||
original_tools_count = len(context.executor.tools)
|
||||
assert context.executor.max_iter > 0
|
||||
assert context.executor.iterations >= 0
|
||||
assert context.executor.tools is not None
|
||||
|
||||
register_before_llm_call_hook(before_hook)
|
||||
|
||||
try:
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
tools=[test_tool],
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Use the test tool",
|
||||
expected_output="Tool result",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
assert hook_called, "before_llm_call hook should have been called"
|
||||
assert original_tools_count >= 0
|
||||
assert result is not None
|
||||
finally:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_llm_call_hooks_error_handling():
|
||||
"""Test that hook errors don't break execution."""
|
||||
from crewai.utilities.llm_call_hooks import LLMCallHookContext, register_before_llm_call_hook
|
||||
|
||||
hook_called = False
|
||||
|
||||
def error_hook(context: LLMCallHookContext) -> None:
|
||||
nonlocal hook_called
|
||||
hook_called = True
|
||||
raise ValueError("Test hook error")
|
||||
|
||||
register_before_llm_call_hook(error_hook)
|
||||
|
||||
try:
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Say hello",
|
||||
expected_output="A greeting",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
result = agent.execute_task(task)
|
||||
|
||||
assert hook_called, "before_llm_call hook should have been called"
|
||||
assert result is not None
|
||||
finally:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user