ensure hitl works

This commit is contained in:
lorenzejay
2025-12-29 11:24:46 -08:00
parent 2e429b4ef6
commit 480708de2f
5 changed files with 200 additions and 199 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import time import time
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from crewai.agents.parser import AgentFinish
from crewai.events.event_listener import event_listener from crewai.events.event_listener import event_listener
from crewai.memory.entity.entity_memory_item import EntityMemoryItem from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
@@ -29,7 +30,7 @@ class CrewAgentExecutorMixin:
_i18n: I18N _i18n: I18N
_printer: Printer = Printer() _printer: Printer = Printer()
def _create_short_term_memory(self, output) -> None: def _create_short_term_memory(self, output: AgentFinish) -> None:
"""Create and save a short-term memory item if conditions are met.""" """Create and save a short-term memory item if conditions are met."""
if ( if (
self.crew self.crew
@@ -53,7 +54,7 @@ class CrewAgentExecutorMixin:
"error", f"Failed to add to short term memory: {e}" "error", f"Failed to add to short term memory: {e}"
) )
def _create_external_memory(self, output) -> None: def _create_external_memory(self, output: AgentFinish) -> None:
"""Create and save a external-term memory item if conditions are met.""" """Create and save a external-term memory item if conditions are met."""
if ( if (
self.crew self.crew
@@ -75,7 +76,7 @@ class CrewAgentExecutorMixin:
"error", f"Failed to add to external memory: {e}" "error", f"Failed to add to external memory: {e}"
) )
def _create_long_term_memory(self, output) -> None: def _create_long_term_memory(self, output: AgentFinish) -> None:
"""Create and save long-term and entity memory items based on evaluation.""" """Create and save long-term and entity memory items based on evaluation."""
if ( if (
self.crew self.crew
@@ -136,40 +137,50 @@ class CrewAgentExecutorMixin:
) )
def _ask_human_input(self, final_answer: str) -> str: def _ask_human_input(self, final_answer: str) -> str:
"""Prompt human input with mode-appropriate messaging.""" """Prompt human input with mode-appropriate messaging.
event_listener.formatter.pause_live_updates()
try:
self._printer.print(
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
)
Note: The final answer is already displayed via the AgentLogsExecutionEvent
panel, so we only show the feedback prompt here.
"""
from rich.panel import Panel
from rich.text import Text
formatter = event_listener.formatter
formatter.pause_live_updates()
try:
# Training mode prompt (single iteration) # Training mode prompt (single iteration)
if self.crew and getattr(self.crew, "_train", False): if self.crew and getattr(self.crew, "_train", False):
prompt = ( prompt_text = (
"\n\n=====\n" "TRAINING MODE: Provide feedback to improve the agent's performance.\n\n"
"## TRAINING MODE: Provide feedback to improve the agent's performance.\n"
"This will be used to train better versions of the agent.\n" "This will be used to train better versions of the agent.\n"
"Please provide detailed feedback about the result quality and reasoning process.\n" "Please provide detailed feedback about the result quality and reasoning process."
"=====\n"
) )
title = "🎓 Training Feedback Required"
# Regular human-in-the-loop prompt (multiple iterations) # Regular human-in-the-loop prompt (multiple iterations)
else: else:
prompt = ( prompt_text = (
"\n\n=====\n" "Provide feedback on the Final Result above.\n\n"
"## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n" "• If you are happy with the result, simply hit Enter without typing anything.\n"
"Please follow these guidelines:\n" "• Otherwise, provide specific improvement requests.\n"
" - If you are happy with the result, simply hit Enter without typing anything.\n" "• You can provide multiple rounds of feedback until satisfied."
" - Otherwise, provide specific improvement requests.\n"
" - You can provide multiple rounds of feedback until satisfied.\n"
"=====\n"
) )
title = "💬 Human Feedback Required"
content = Text()
content.append(prompt_text, style="yellow")
prompt_panel = Panel(
content,
title=title,
border_style="yellow",
padding=(1, 2),
)
formatter.console.print(prompt_panel)
self._printer.print(content=prompt, color="bold_yellow")
response = input() response = input()
if response.strip() != "": if response.strip() != "":
self._printer.print( formatter.console.print("\n[cyan]Processing your feedback...[/cyan]")
content="\nProcessing your feedback...", color="cyan"
)
return response return response
finally: finally:
event_listener.formatter.resume_live_updates() formatter.resume_live_updates()

View File

@@ -541,7 +541,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if self.agent is None: if self.agent is None:
raise ValueError("Agent cannot be None") raise ValueError("Agent cannot be None")
crewai_event_bus.emit( future = crewai_event_bus.emit(
self.agent, self.agent,
AgentLogsExecutionEvent( AgentLogsExecutionEvent(
agent_role=self.agent.role, agent_role=self.agent.role,
@@ -551,6 +551,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
), ),
) )
if future is not None:
try:
future.result(timeout=5.0)
except Exception:
pass
def _handle_crew_training_output( def _handle_crew_training_output(
self, result: AgentFinish, human_feedback: str | None = None self, result: AgentFinish, human_feedback: str | None = None
) -> None: ) -> None:

View File

@@ -98,6 +98,24 @@ To enable tracing, do any one of these:
return return
self.console.print(*args, **kwargs) self.console.print(*args, **kwargs)
def pause_live_updates(self) -> None:
    """Stop the active streaming Live session before prompting for input.

    HITL (human-in-the-loop) callers invoke this ahead of reading user
    input so that console refreshes from a running Live session do not
    interfere with the prompt. The session is discarded, not suspended.

    The reference is always cleared, even if ``stop()`` raises, so a failed
    stop cannot leave a stale session attached to the formatter. Any
    exception raised by ``stop()`` still propagates to the caller.
    """
    live = self._streaming_live
    if not live:
        # No streaming session is active; repeated calls are harmless.
        return
    try:
        live.stop()
    finally:
        self._streaming_live = None
def resume_live_updates(self) -> None:
    """Do nothing; kept so HITL callers can pair it with ``pause_live_updates``.

    No session is restored here: a new streaming session is created
    on demand the next time one is needed.
    """
    return None
def print_panel( def print_panel(
self, content: Text, title: str, style: str = "blue", is_flow: bool = False self, content: Text, title: str, style: str = "blue", is_flow: bool = False
) -> None: ) -> None:

View File

@@ -7,22 +7,19 @@ from crewai.events.event_listener import event_listener
class TestFlowHumanInputIntegration: class TestFlowHumanInputIntegration:
"""Test integration between Flow execution and human input functionality.""" """Test integration between Flow execution and human input functionality."""
def test_console_formatter_pause_resume_methods(self): def test_console_formatter_pause_resume_methods_exist(self):
"""Test that ConsoleFormatter pause/resume methods work correctly.""" """Test that ConsoleFormatter pause/resume methods exist and are callable."""
formatter = event_listener.formatter formatter = event_listener.formatter
original_paused_state = formatter._live_paused # Methods should exist and be callable
assert hasattr(formatter, "pause_live_updates")
try: assert hasattr(formatter, "resume_live_updates")
formatter._live_paused = False assert callable(formatter.pause_live_updates)
assert callable(formatter.resume_live_updates)
# Should not raise
formatter.pause_live_updates() formatter.pause_live_updates()
assert formatter._live_paused
formatter.resume_live_updates() formatter.resume_live_updates()
assert not formatter._live_paused
finally:
formatter._live_paused = original_paused_state
@patch("builtins.input", return_value="") @patch("builtins.input", return_value="")
def test_human_input_pauses_flow_updates(self, mock_input): def test_human_input_pauses_flow_updates(self, mock_input):
@@ -38,11 +35,6 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter formatter = event_listener.formatter
original_paused_state = formatter._live_paused
try:
formatter._live_paused = False
with ( with (
patch.object(formatter, "pause_live_updates") as mock_pause, patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume, patch.object(formatter, "resume_live_updates") as mock_resume,
@@ -53,8 +45,6 @@ class TestFlowHumanInputIntegration:
mock_resume.assert_called_once() mock_resume.assert_called_once()
mock_input.assert_called_once() mock_input.assert_called_once()
assert result == "" assert result == ""
finally:
formatter._live_paused = original_paused_state
@patch("builtins.input", side_effect=["feedback", ""]) @patch("builtins.input", side_effect=["feedback", ""])
def test_multiple_human_input_rounds(self, mock_input): def test_multiple_human_input_rounds(self, mock_input):
@@ -70,9 +60,6 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter formatter = event_listener.formatter
original_paused_state = formatter._live_paused
try:
pause_calls = [] pause_calls = []
resume_calls = [] resume_calls = []
@@ -96,27 +83,23 @@ class TestFlowHumanInputIntegration:
assert len(pause_calls) == 2 assert len(pause_calls) == 2
assert len(resume_calls) == 2 assert len(resume_calls) == 2
finally:
formatter._live_paused = original_paused_state
def test_pause_resume_with_no_live_session(self): def test_pause_resume_with_no_live_session(self):
"""Test pause/resume methods handle case when no Live session exists.""" """Test pause/resume methods handle case when no Live session exists."""
formatter = event_listener.formatter formatter = event_listener.formatter
original_live = formatter._live original_streaming_live = formatter._streaming_live
original_paused_state = formatter._live_paused
try: try:
formatter._live = None formatter._streaming_live = None
formatter._live_paused = False
# Should not raise when no session exists
formatter.pause_live_updates() formatter.pause_live_updates()
formatter.resume_live_updates() formatter.resume_live_updates()
assert not formatter._live_paused assert formatter._streaming_live is None
finally: finally:
formatter._live = original_live formatter._streaming_live = original_streaming_live
formatter._live_paused = original_paused_state
def test_pause_resume_exception_handling(self): def test_pause_resume_exception_handling(self):
"""Test that resume is called even if exception occurs during human input.""" """Test that resume is called even if exception occurs during human input."""
@@ -131,9 +114,6 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter formatter = event_listener.formatter
original_paused_state = formatter._live_paused
try:
with ( with (
patch.object(formatter, "pause_live_updates") as mock_pause, patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume, patch.object(formatter, "resume_live_updates") as mock_resume,
@@ -146,8 +126,6 @@ class TestFlowHumanInputIntegration:
mock_pause.assert_called_once() mock_pause.assert_called_once()
mock_resume.assert_called_once() mock_resume.assert_called_once()
finally:
formatter._live_paused = original_paused_state
def test_training_mode_human_input(self): def test_training_mode_human_input(self):
"""Test human input in training mode.""" """Test human input in training mode."""
@@ -162,12 +140,10 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter formatter = event_listener.formatter
original_paused_state = formatter._live_paused
try:
with ( with (
patch.object(formatter, "pause_live_updates") as mock_pause, patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume, patch.object(formatter, "resume_live_updates") as mock_resume,
patch.object(formatter.console, "print") as mock_console_print,
patch("builtins.input", return_value="training feedback"), patch("builtins.input", return_value="training feedback"),
): ):
result = executor._ask_human_input("Test result") result = executor._ask_human_input("Test result")
@@ -176,14 +152,13 @@ class TestFlowHumanInputIntegration:
mock_resume.assert_called_once() mock_resume.assert_called_once()
assert result == "training feedback" assert result == "training feedback"
executor._printer.print.assert_called() # Verify the training panel was printed via formatter's console
call_args = [ mock_console_print.assert_called()
call[1]["content"] # Check that a Panel with training title was printed
for call in executor._printer.print.call_args_list call_args = mock_console_print.call_args_list
] training_panel_found = any(
training_prompt_found = any( hasattr(call[0][0], "title") and "Training" in str(call[0][0].title)
"TRAINING MODE" in content for content in call_args for call in call_args
if call[0]
) )
assert training_prompt_found assert training_panel_found
finally:
formatter._live_paused = original_paused_state

View File

@@ -1,116 +1,107 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from rich.tree import Tree
from rich.live import Live from rich.live import Live
from crewai.events.utils.console_formatter import ConsoleFormatter from crewai.events.utils.console_formatter import ConsoleFormatter
class TestConsoleFormatterPauseResume: class TestConsoleFormatterPauseResume:
"""Test ConsoleFormatter pause/resume functionality.""" """Test ConsoleFormatter pause/resume functionality for HITL features."""
def test_pause_live_updates_with_active_session(self): def test_pause_stops_active_streaming_session(self):
"""Test pausing when Live session is active.""" """Test pausing stops an active streaming Live session."""
formatter = ConsoleFormatter() formatter = ConsoleFormatter()
mock_live = MagicMock(spec=Live) mock_live = MagicMock(spec=Live)
formatter._live = mock_live formatter._streaming_live = mock_live
formatter._live_paused = False
formatter.pause_live_updates() formatter.pause_live_updates()
mock_live.stop.assert_called_once() mock_live.stop.assert_called_once()
assert formatter._live_paused assert formatter._streaming_live is None
def test_pause_live_updates_when_already_paused(self): def test_pause_is_safe_when_no_session(self):
"""Test pausing when already paused does nothing.""" """Test pausing when no streaming session exists doesn't error."""
formatter = ConsoleFormatter()
formatter._streaming_live = None
# Should not raise
formatter.pause_live_updates()
assert formatter._streaming_live is None
def test_multiple_pauses_are_safe(self):
"""Test calling pause multiple times is safe."""
formatter = ConsoleFormatter() formatter = ConsoleFormatter()
mock_live = MagicMock(spec=Live) mock_live = MagicMock(spec=Live)
formatter._live = mock_live formatter._streaming_live = mock_live
formatter._live_paused = True
formatter.pause_live_updates() formatter.pause_live_updates()
mock_live.stop.assert_called_once()
assert formatter._streaming_live is None
mock_live.stop.assert_not_called() # Second pause should not error (no session to stop)
assert formatter._live_paused
def test_pause_live_updates_with_no_session(self):
"""Test pausing when no Live session exists."""
formatter = ConsoleFormatter()
formatter._live = None
formatter._live_paused = False
formatter.pause_live_updates() formatter.pause_live_updates()
assert formatter._live_paused def test_resume_is_safe(self):
"""Test resume method exists and doesn't error."""
def test_resume_live_updates_when_paused(self):
"""Test resuming when paused."""
formatter = ConsoleFormatter() formatter = ConsoleFormatter()
formatter._live_paused = True # Should not raise
formatter.resume_live_updates() formatter.resume_live_updates()
assert not formatter._live_paused def test_streaming_after_pause_resume_creates_new_session(self):
"""Test that streaming after pause/resume creates new Live session."""
def test_resume_live_updates_when_not_paused(self):
"""Test resuming when not paused does nothing."""
formatter = ConsoleFormatter() formatter = ConsoleFormatter()
formatter.verbose = True
formatter._live_paused = False # Simulate having an active session
mock_live = MagicMock(spec=Live)
formatter._streaming_live = mock_live
# Pause stops the session
formatter.pause_live_updates()
assert formatter._streaming_live is None
# Resume (no-op, sessions created on demand)
formatter.resume_live_updates() formatter.resume_live_updates()
assert not formatter._live_paused # After resume, streaming should be able to start a new session
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
mock_live_instance = MagicMock()
mock_live_class.return_value = mock_live_instance
def test_print_after_resume_restarts_live_session(self): # Simulate streaming chunk (this creates a new Live session)
"""Test that printing a Tree after resume creates new Live session.""" formatter.handle_llm_stream_chunk("test chunk", call_type=None)
mock_live_class.assert_called_once()
mock_live_instance.start.assert_called_once()
assert formatter._streaming_live == mock_live_instance
def test_pause_resume_cycle_with_streaming(self):
"""Test full pause/resume cycle during streaming."""
formatter = ConsoleFormatter() formatter = ConsoleFormatter()
formatter.verbose = True
formatter._live_paused = True
formatter._live = None
formatter.resume_live_updates()
assert not formatter._live_paused
tree = Tree("Test")
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class: with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
mock_live_instance = MagicMock() mock_live_instance = MagicMock()
mock_live_class.return_value = mock_live_instance mock_live_class.return_value = mock_live_instance
formatter.print(tree) # Start streaming
formatter.handle_llm_stream_chunk("chunk 1", call_type=None)
mock_live_class.assert_called_once() assert formatter._streaming_live == mock_live_instance
mock_live_instance.start.assert_called_once()
assert formatter._live == mock_live_instance
def test_multiple_pause_resume_cycles(self):
"""Test multiple pause/resume cycles work correctly."""
formatter = ConsoleFormatter()
mock_live = MagicMock(spec=Live)
formatter._live = mock_live
formatter._live_paused = False
# Pause should stop the session
formatter.pause_live_updates() formatter.pause_live_updates()
assert formatter._live_paused mock_live_instance.stop.assert_called_once()
mock_live.stop.assert_called_once() assert formatter._streaming_live is None
assert formatter._live is None # Live session should be cleared
# Resume (no-op)
formatter.resume_live_updates() formatter.resume_live_updates()
assert not formatter._live_paused
formatter.pause_live_updates() # Create a new mock for the next session
assert formatter._live_paused mock_live_instance_2 = MagicMock()
mock_live_class.return_value = mock_live_instance_2
formatter.resume_live_updates() # Streaming again creates new session
assert not formatter._live_paused formatter.handle_llm_stream_chunk("chunk 2", call_type=None)
assert formatter._streaming_live == mock_live_instance_2
def test_pause_resume_state_initialization(self):
"""Test that _live_paused is properly initialized."""
formatter = ConsoleFormatter()
assert hasattr(formatter, "_live_paused")
assert not formatter._live_paused