ensure hitl works

This commit is contained in:
lorenzejay
2025-12-29 11:24:46 -08:00
parent 2e429b4ef6
commit 480708de2f
5 changed files with 200 additions and 199 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import time
from typing import TYPE_CHECKING
from crewai.agents.parser import AgentFinish
from crewai.events.event_listener import event_listener
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
@@ -29,7 +30,7 @@ class CrewAgentExecutorMixin:
_i18n: I18N
_printer: Printer = Printer()
def _create_short_term_memory(self, output) -> None:
def _create_short_term_memory(self, output: AgentFinish) -> None:
"""Create and save a short-term memory item if conditions are met."""
if (
self.crew
@@ -53,7 +54,7 @@ class CrewAgentExecutorMixin:
"error", f"Failed to add to short term memory: {e}"
)
def _create_external_memory(self, output) -> None:
def _create_external_memory(self, output: AgentFinish) -> None:
"""Create and save a external-term memory item if conditions are met."""
if (
self.crew
@@ -75,7 +76,7 @@ class CrewAgentExecutorMixin:
"error", f"Failed to add to external memory: {e}"
)
def _create_long_term_memory(self, output) -> None:
def _create_long_term_memory(self, output: AgentFinish) -> None:
"""Create and save long-term and entity memory items based on evaluation."""
if (
self.crew
@@ -136,40 +137,50 @@ class CrewAgentExecutorMixin:
)
def _ask_human_input(self, final_answer: str) -> str:
"""Prompt human input with mode-appropriate messaging."""
event_listener.formatter.pause_live_updates()
try:
self._printer.print(
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
)
"""Prompt human input with mode-appropriate messaging.
Note: The final answer is already displayed via the AgentLogsExecutionEvent
panel, so we only show the feedback prompt here.
"""
from rich.panel import Panel
from rich.text import Text
formatter = event_listener.formatter
formatter.pause_live_updates()
try:
# Training mode prompt (single iteration)
if self.crew and getattr(self.crew, "_train", False):
prompt = (
"\n\n=====\n"
"## TRAINING MODE: Provide feedback to improve the agent's performance.\n"
prompt_text = (
"TRAINING MODE: Provide feedback to improve the agent's performance.\n\n"
"This will be used to train better versions of the agent.\n"
"Please provide detailed feedback about the result quality and reasoning process.\n"
"=====\n"
"Please provide detailed feedback about the result quality and reasoning process."
)
title = "🎓 Training Feedback Required"
# Regular human-in-the-loop prompt (multiple iterations)
else:
prompt = (
"\n\n=====\n"
"## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n"
"Please follow these guidelines:\n"
" - If you are happy with the result, simply hit Enter without typing anything.\n"
" - Otherwise, provide specific improvement requests.\n"
" - You can provide multiple rounds of feedback until satisfied.\n"
"=====\n"
prompt_text = (
"Provide feedback on the Final Result above.\n\n"
"• If you are happy with the result, simply hit Enter without typing anything.\n"
"• Otherwise, provide specific improvement requests.\n"
"• You can provide multiple rounds of feedback until satisfied."
)
title = "💬 Human Feedback Required"
content = Text()
content.append(prompt_text, style="yellow")
prompt_panel = Panel(
content,
title=title,
border_style="yellow",
padding=(1, 2),
)
formatter.console.print(prompt_panel)
self._printer.print(content=prompt, color="bold_yellow")
response = input()
if response.strip() != "":
self._printer.print(
content="\nProcessing your feedback...", color="cyan"
)
formatter.console.print("\n[cyan]Processing your feedback...[/cyan]")
return response
finally:
event_listener.formatter.resume_live_updates()
formatter.resume_live_updates()

View File

@@ -541,7 +541,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if self.agent is None:
raise ValueError("Agent cannot be None")
crewai_event_bus.emit(
future = crewai_event_bus.emit(
self.agent,
AgentLogsExecutionEvent(
agent_role=self.agent.role,
@@ -551,6 +551,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
),
)
if future is not None:
try:
future.result(timeout=5.0)
except Exception:
pass
def _handle_crew_training_output(
self, result: AgentFinish, human_feedback: str | None = None
) -> None:

View File

@@ -98,6 +98,24 @@ To enable tracing, do any one of these:
return
self.console.print(*args, **kwargs)
def pause_live_updates(self) -> None:
"""Pause Live session updates to allow for human input without interference.
This stops any active streaming Live session to prevent console refresh
interference during HITL (Human-in-the-Loop) user input.
"""
if self._streaming_live:
self._streaming_live.stop()
self._streaming_live = None
def resume_live_updates(self) -> None:
"""Resume Live session updates after human input is complete.
New streaming sessions will be created on-demand when needed.
This method exists for API compatibility with HITL callers.
"""
pass
def print_panel(
self, content: Text, title: str, style: str = "blue", is_flow: bool = False
) -> None:

View File

@@ -7,22 +7,19 @@ from crewai.events.event_listener import event_listener
class TestFlowHumanInputIntegration:
"""Test integration between Flow execution and human input functionality."""
def test_console_formatter_pause_resume_methods(self):
"""Test that ConsoleFormatter pause/resume methods work correctly."""
def test_console_formatter_pause_resume_methods_exist(self):
"""Test that ConsoleFormatter pause/resume methods exist and are callable."""
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
# Methods should exist and be callable
assert hasattr(formatter, "pause_live_updates")
assert hasattr(formatter, "resume_live_updates")
assert callable(formatter.pause_live_updates)
assert callable(formatter.resume_live_updates)
try:
formatter._live_paused = False
formatter.pause_live_updates()
assert formatter._live_paused
formatter.resume_live_updates()
assert not formatter._live_paused
finally:
formatter._live_paused = original_paused_state
# Should not raise
formatter.pause_live_updates()
formatter.resume_live_updates()
@patch("builtins.input", return_value="")
def test_human_input_pauses_flow_updates(self, mock_input):
@@ -38,23 +35,16 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
):
result = executor._ask_human_input("Test result")
try:
formatter._live_paused = False
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
):
result = executor._ask_human_input("Test result")
mock_pause.assert_called_once()
mock_resume.assert_called_once()
mock_input.assert_called_once()
assert result == ""
finally:
formatter._live_paused = original_paused_state
mock_pause.assert_called_once()
mock_resume.assert_called_once()
mock_input.assert_called_once()
assert result == ""
@patch("builtins.input", side_effect=["feedback", ""])
def test_multiple_human_input_rounds(self, mock_input):
@@ -70,53 +60,46 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
pause_calls = []
resume_calls = []
try:
pause_calls = []
resume_calls = []
def track_pause():
pause_calls.append(True)
def track_pause():
pause_calls.append(True)
def track_resume():
resume_calls.append(True)
def track_resume():
resume_calls.append(True)
with (
patch.object(formatter, "pause_live_updates", side_effect=track_pause),
patch.object(
formatter, "resume_live_updates", side_effect=track_resume
),
):
result1 = executor._ask_human_input("Test result 1")
assert result1 == "feedback"
with (
patch.object(formatter, "pause_live_updates", side_effect=track_pause),
patch.object(
formatter, "resume_live_updates", side_effect=track_resume
),
):
result1 = executor._ask_human_input("Test result 1")
assert result1 == "feedback"
result2 = executor._ask_human_input("Test result 2")
assert result2 == ""
result2 = executor._ask_human_input("Test result 2")
assert result2 == ""
assert len(pause_calls) == 2
assert len(resume_calls) == 2
finally:
formatter._live_paused = original_paused_state
assert len(pause_calls) == 2
assert len(resume_calls) == 2
def test_pause_resume_with_no_live_session(self):
"""Test pause/resume methods handle case when no Live session exists."""
formatter = event_listener.formatter
original_live = formatter._live
original_paused_state = formatter._live_paused
original_streaming_live = formatter._streaming_live
try:
formatter._live = None
formatter._live_paused = False
formatter._streaming_live = None
# Should not raise when no session exists
formatter.pause_live_updates()
formatter.resume_live_updates()
assert not formatter._live_paused
assert formatter._streaming_live is None
finally:
formatter._live = original_live
formatter._live_paused = original_paused_state
formatter._streaming_live = original_streaming_live
def test_pause_resume_exception_handling(self):
"""Test that resume is called even if exception occurs during human input."""
@@ -131,23 +114,18 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch(
"builtins.input", side_effect=KeyboardInterrupt("Test exception")
),
):
with pytest.raises(KeyboardInterrupt):
executor._ask_human_input("Test result")
try:
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch(
"builtins.input", side_effect=KeyboardInterrupt("Test exception")
),
):
with pytest.raises(KeyboardInterrupt):
executor._ask_human_input("Test result")
mock_pause.assert_called_once()
mock_resume.assert_called_once()
finally:
formatter._live_paused = original_paused_state
mock_pause.assert_called_once()
mock_resume.assert_called_once()
def test_training_mode_human_input(self):
"""Test human input in training mode."""
@@ -162,28 +140,25 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch.object(formatter.console, "print") as mock_console_print,
patch("builtins.input", return_value="training feedback"),
):
result = executor._ask_human_input("Test result")
try:
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch("builtins.input", return_value="training feedback"),
):
result = executor._ask_human_input("Test result")
mock_pause.assert_called_once()
mock_resume.assert_called_once()
assert result == "training feedback"
mock_pause.assert_called_once()
mock_resume.assert_called_once()
assert result == "training feedback"
executor._printer.print.assert_called()
call_args = [
call[1]["content"]
for call in executor._printer.print.call_args_list
]
training_prompt_found = any(
"TRAINING MODE" in content for content in call_args
)
assert training_prompt_found
finally:
formatter._live_paused = original_paused_state
# Verify the training panel was printed via formatter's console
mock_console_print.assert_called()
# Check that a Panel with training title was printed
call_args = mock_console_print.call_args_list
training_panel_found = any(
hasattr(call[0][0], "title") and "Training" in str(call[0][0].title)
for call in call_args
if call[0]
)
assert training_panel_found

View File

@@ -1,116 +1,107 @@
from unittest.mock import MagicMock, patch
from rich.tree import Tree
from rich.live import Live
from crewai.events.utils.console_formatter import ConsoleFormatter
class TestConsoleFormatterPauseResume:
"""Test ConsoleFormatter pause/resume functionality."""
"""Test ConsoleFormatter pause/resume functionality for HITL features."""
def test_pause_live_updates_with_active_session(self):
"""Test pausing when Live session is active."""
def test_pause_stops_active_streaming_session(self):
"""Test pausing stops an active streaming Live session."""
formatter = ConsoleFormatter()
mock_live = MagicMock(spec=Live)
formatter._live = mock_live
formatter._live_paused = False
formatter._streaming_live = mock_live
formatter.pause_live_updates()
mock_live.stop.assert_called_once()
assert formatter._live_paused
assert formatter._streaming_live is None
def test_pause_live_updates_when_already_paused(self):
"""Test pausing when already paused does nothing."""
def test_pause_is_safe_when_no_session(self):
"""Test pausing when no streaming session exists doesn't error."""
formatter = ConsoleFormatter()
formatter._streaming_live = None
# Should not raise
formatter.pause_live_updates()
assert formatter._streaming_live is None
def test_multiple_pauses_are_safe(self):
"""Test calling pause multiple times is safe."""
formatter = ConsoleFormatter()
mock_live = MagicMock(spec=Live)
formatter._live = mock_live
formatter._live_paused = True
formatter._streaming_live = mock_live
formatter.pause_live_updates()
mock_live.stop.assert_called_once()
assert formatter._streaming_live is None
mock_live.stop.assert_not_called()
assert formatter._live_paused
def test_pause_live_updates_with_no_session(self):
"""Test pausing when no Live session exists."""
formatter = ConsoleFormatter()
formatter._live = None
formatter._live_paused = False
# Second pause should not error (no session to stop)
formatter.pause_live_updates()
assert formatter._live_paused
def test_resume_live_updates_when_paused(self):
"""Test resuming when paused."""
def test_resume_is_safe(self):
"""Test resume method exists and doesn't error."""
formatter = ConsoleFormatter()
formatter._live_paused = True
# Should not raise
formatter.resume_live_updates()
assert not formatter._live_paused
def test_resume_live_updates_when_not_paused(self):
"""Test resuming when not paused does nothing."""
def test_streaming_after_pause_resume_creates_new_session(self):
"""Test that streaming after pause/resume creates new Live session."""
formatter = ConsoleFormatter()
formatter.verbose = True
formatter._live_paused = False
# Simulate having an active session
mock_live = MagicMock(spec=Live)
formatter._streaming_live = mock_live
# Pause stops the session
formatter.pause_live_updates()
assert formatter._streaming_live is None
# Resume (no-op, sessions created on demand)
formatter.resume_live_updates()
assert not formatter._live_paused
# After resume, streaming should be able to start a new session
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
mock_live_instance = MagicMock()
mock_live_class.return_value = mock_live_instance
def test_print_after_resume_restarts_live_session(self):
"""Test that printing a Tree after resume creates new Live session."""
# Simulate streaming chunk (this creates a new Live session)
formatter.handle_llm_stream_chunk("test chunk", call_type=None)
mock_live_class.assert_called_once()
mock_live_instance.start.assert_called_once()
assert formatter._streaming_live == mock_live_instance
def test_pause_resume_cycle_with_streaming(self):
"""Test full pause/resume cycle during streaming."""
formatter = ConsoleFormatter()
formatter._live_paused = True
formatter._live = None
formatter.resume_live_updates()
assert not formatter._live_paused
tree = Tree("Test")
formatter.verbose = True
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
mock_live_instance = MagicMock()
mock_live_class.return_value = mock_live_instance
formatter.print(tree)
# Start streaming
formatter.handle_llm_stream_chunk("chunk 1", call_type=None)
assert formatter._streaming_live == mock_live_instance
mock_live_class.assert_called_once()
mock_live_instance.start.assert_called_once()
assert formatter._live == mock_live_instance
# Pause should stop the session
formatter.pause_live_updates()
mock_live_instance.stop.assert_called_once()
assert formatter._streaming_live is None
def test_multiple_pause_resume_cycles(self):
"""Test multiple pause/resume cycles work correctly."""
formatter = ConsoleFormatter()
# Resume (no-op)
formatter.resume_live_updates()
mock_live = MagicMock(spec=Live)
formatter._live = mock_live
formatter._live_paused = False
# Create a new mock for the next session
mock_live_instance_2 = MagicMock()
mock_live_class.return_value = mock_live_instance_2
formatter.pause_live_updates()
assert formatter._live_paused
mock_live.stop.assert_called_once()
assert formatter._live is None # Live session should be cleared
formatter.resume_live_updates()
assert not formatter._live_paused
formatter.pause_live_updates()
assert formatter._live_paused
formatter.resume_live_updates()
assert not formatter._live_paused
def test_pause_resume_state_initialization(self):
"""Test that _live_paused is properly initialized."""
formatter = ConsoleFormatter()
assert hasattr(formatter, "_live_paused")
assert not formatter._live_paused
# Streaming again creates new session
formatter.handle_llm_stream_chunk("chunk 2", call_type=None)
assert formatter._streaming_live == mock_live_instance_2