mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 07:38:29 +00:00
ensure hitl works
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.agents.parser import AgentFinish
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
@@ -29,7 +30,7 @@ class CrewAgentExecutorMixin:
|
||||
_i18n: I18N
|
||||
_printer: Printer = Printer()
|
||||
|
||||
def _create_short_term_memory(self, output) -> None:
|
||||
def _create_short_term_memory(self, output: AgentFinish) -> None:
|
||||
"""Create and save a short-term memory item if conditions are met."""
|
||||
if (
|
||||
self.crew
|
||||
@@ -53,7 +54,7 @@ class CrewAgentExecutorMixin:
|
||||
"error", f"Failed to add to short term memory: {e}"
|
||||
)
|
||||
|
||||
def _create_external_memory(self, output) -> None:
|
||||
def _create_external_memory(self, output: AgentFinish) -> None:
|
||||
"""Create and save a external-term memory item if conditions are met."""
|
||||
if (
|
||||
self.crew
|
||||
@@ -75,7 +76,7 @@ class CrewAgentExecutorMixin:
|
||||
"error", f"Failed to add to external memory: {e}"
|
||||
)
|
||||
|
||||
def _create_long_term_memory(self, output) -> None:
|
||||
def _create_long_term_memory(self, output: AgentFinish) -> None:
|
||||
"""Create and save long-term and entity memory items based on evaluation."""
|
||||
if (
|
||||
self.crew
|
||||
@@ -136,40 +137,50 @@ class CrewAgentExecutorMixin:
|
||||
)
|
||||
|
||||
def _ask_human_input(self, final_answer: str) -> str:
|
||||
"""Prompt human input with mode-appropriate messaging."""
|
||||
event_listener.formatter.pause_live_updates()
|
||||
try:
|
||||
self._printer.print(
|
||||
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
|
||||
)
|
||||
"""Prompt human input with mode-appropriate messaging.
|
||||
|
||||
Note: The final answer is already displayed via the AgentLogsExecutionEvent
|
||||
panel, so we only show the feedback prompt here.
|
||||
"""
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
formatter = event_listener.formatter
|
||||
formatter.pause_live_updates()
|
||||
|
||||
try:
|
||||
# Training mode prompt (single iteration)
|
||||
if self.crew and getattr(self.crew, "_train", False):
|
||||
prompt = (
|
||||
"\n\n=====\n"
|
||||
"## TRAINING MODE: Provide feedback to improve the agent's performance.\n"
|
||||
prompt_text = (
|
||||
"TRAINING MODE: Provide feedback to improve the agent's performance.\n\n"
|
||||
"This will be used to train better versions of the agent.\n"
|
||||
"Please provide detailed feedback about the result quality and reasoning process.\n"
|
||||
"=====\n"
|
||||
"Please provide detailed feedback about the result quality and reasoning process."
|
||||
)
|
||||
title = "🎓 Training Feedback Required"
|
||||
# Regular human-in-the-loop prompt (multiple iterations)
|
||||
else:
|
||||
prompt = (
|
||||
"\n\n=====\n"
|
||||
"## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n"
|
||||
"Please follow these guidelines:\n"
|
||||
" - If you are happy with the result, simply hit Enter without typing anything.\n"
|
||||
" - Otherwise, provide specific improvement requests.\n"
|
||||
" - You can provide multiple rounds of feedback until satisfied.\n"
|
||||
"=====\n"
|
||||
prompt_text = (
|
||||
"Provide feedback on the Final Result above.\n\n"
|
||||
"• If you are happy with the result, simply hit Enter without typing anything.\n"
|
||||
"• Otherwise, provide specific improvement requests.\n"
|
||||
"• You can provide multiple rounds of feedback until satisfied."
|
||||
)
|
||||
title = "💬 Human Feedback Required"
|
||||
|
||||
content = Text()
|
||||
content.append(prompt_text, style="yellow")
|
||||
|
||||
prompt_panel = Panel(
|
||||
content,
|
||||
title=title,
|
||||
border_style="yellow",
|
||||
padding=(1, 2),
|
||||
)
|
||||
formatter.console.print(prompt_panel)
|
||||
|
||||
self._printer.print(content=prompt, color="bold_yellow")
|
||||
response = input()
|
||||
if response.strip() != "":
|
||||
self._printer.print(
|
||||
content="\nProcessing your feedback...", color="cyan"
|
||||
)
|
||||
formatter.console.print("\n[cyan]Processing your feedback...[/cyan]")
|
||||
return response
|
||||
finally:
|
||||
event_listener.formatter.resume_live_updates()
|
||||
formatter.resume_live_updates()
|
||||
|
||||
@@ -541,7 +541,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
if self.agent is None:
|
||||
raise ValueError("Agent cannot be None")
|
||||
|
||||
crewai_event_bus.emit(
|
||||
future = crewai_event_bus.emit(
|
||||
self.agent,
|
||||
AgentLogsExecutionEvent(
|
||||
agent_role=self.agent.role,
|
||||
@@ -551,6 +551,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
),
|
||||
)
|
||||
|
||||
if future is not None:
|
||||
try:
|
||||
future.result(timeout=5.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _handle_crew_training_output(
|
||||
self, result: AgentFinish, human_feedback: str | None = None
|
||||
) -> None:
|
||||
|
||||
@@ -98,6 +98,24 @@ To enable tracing, do any one of these:
|
||||
return
|
||||
self.console.print(*args, **kwargs)
|
||||
|
||||
def pause_live_updates(self) -> None:
|
||||
"""Pause Live session updates to allow for human input without interference.
|
||||
|
||||
This stops any active streaming Live session to prevent console refresh
|
||||
interference during HITL (Human-in-the-Loop) user input.
|
||||
"""
|
||||
if self._streaming_live:
|
||||
self._streaming_live.stop()
|
||||
self._streaming_live = None
|
||||
|
||||
def resume_live_updates(self) -> None:
|
||||
"""Resume Live session updates after human input is complete.
|
||||
|
||||
New streaming sessions will be created on-demand when needed.
|
||||
This method exists for API compatibility with HITL callers.
|
||||
"""
|
||||
pass
|
||||
|
||||
def print_panel(
|
||||
self, content: Text, title: str, style: str = "blue", is_flow: bool = False
|
||||
) -> None:
|
||||
|
||||
@@ -7,22 +7,19 @@ from crewai.events.event_listener import event_listener
|
||||
class TestFlowHumanInputIntegration:
|
||||
"""Test integration between Flow execution and human input functionality."""
|
||||
|
||||
def test_console_formatter_pause_resume_methods(self):
|
||||
"""Test that ConsoleFormatter pause/resume methods work correctly."""
|
||||
def test_console_formatter_pause_resume_methods_exist(self):
|
||||
"""Test that ConsoleFormatter pause/resume methods exist and are callable."""
|
||||
formatter = event_listener.formatter
|
||||
|
||||
original_paused_state = formatter._live_paused
|
||||
# Methods should exist and be callable
|
||||
assert hasattr(formatter, "pause_live_updates")
|
||||
assert hasattr(formatter, "resume_live_updates")
|
||||
assert callable(formatter.pause_live_updates)
|
||||
assert callable(formatter.resume_live_updates)
|
||||
|
||||
try:
|
||||
formatter._live_paused = False
|
||||
|
||||
formatter.pause_live_updates()
|
||||
assert formatter._live_paused
|
||||
|
||||
formatter.resume_live_updates()
|
||||
assert not formatter._live_paused
|
||||
finally:
|
||||
formatter._live_paused = original_paused_state
|
||||
# Should not raise
|
||||
formatter.pause_live_updates()
|
||||
formatter.resume_live_updates()
|
||||
|
||||
@patch("builtins.input", return_value="")
|
||||
def test_human_input_pauses_flow_updates(self, mock_input):
|
||||
@@ -38,23 +35,16 @@ class TestFlowHumanInputIntegration:
|
||||
|
||||
formatter = event_listener.formatter
|
||||
|
||||
original_paused_state = formatter._live_paused
|
||||
with (
|
||||
patch.object(formatter, "pause_live_updates") as mock_pause,
|
||||
patch.object(formatter, "resume_live_updates") as mock_resume,
|
||||
):
|
||||
result = executor._ask_human_input("Test result")
|
||||
|
||||
try:
|
||||
formatter._live_paused = False
|
||||
|
||||
with (
|
||||
patch.object(formatter, "pause_live_updates") as mock_pause,
|
||||
patch.object(formatter, "resume_live_updates") as mock_resume,
|
||||
):
|
||||
result = executor._ask_human_input("Test result")
|
||||
|
||||
mock_pause.assert_called_once()
|
||||
mock_resume.assert_called_once()
|
||||
mock_input.assert_called_once()
|
||||
assert result == ""
|
||||
finally:
|
||||
formatter._live_paused = original_paused_state
|
||||
mock_pause.assert_called_once()
|
||||
mock_resume.assert_called_once()
|
||||
mock_input.assert_called_once()
|
||||
assert result == ""
|
||||
|
||||
@patch("builtins.input", side_effect=["feedback", ""])
|
||||
def test_multiple_human_input_rounds(self, mock_input):
|
||||
@@ -70,53 +60,46 @@ class TestFlowHumanInputIntegration:
|
||||
|
||||
formatter = event_listener.formatter
|
||||
|
||||
original_paused_state = formatter._live_paused
|
||||
pause_calls = []
|
||||
resume_calls = []
|
||||
|
||||
try:
|
||||
pause_calls = []
|
||||
resume_calls = []
|
||||
def track_pause():
|
||||
pause_calls.append(True)
|
||||
|
||||
def track_pause():
|
||||
pause_calls.append(True)
|
||||
def track_resume():
|
||||
resume_calls.append(True)
|
||||
|
||||
def track_resume():
|
||||
resume_calls.append(True)
|
||||
with (
|
||||
patch.object(formatter, "pause_live_updates", side_effect=track_pause),
|
||||
patch.object(
|
||||
formatter, "resume_live_updates", side_effect=track_resume
|
||||
),
|
||||
):
|
||||
result1 = executor._ask_human_input("Test result 1")
|
||||
assert result1 == "feedback"
|
||||
|
||||
with (
|
||||
patch.object(formatter, "pause_live_updates", side_effect=track_pause),
|
||||
patch.object(
|
||||
formatter, "resume_live_updates", side_effect=track_resume
|
||||
),
|
||||
):
|
||||
result1 = executor._ask_human_input("Test result 1")
|
||||
assert result1 == "feedback"
|
||||
result2 = executor._ask_human_input("Test result 2")
|
||||
assert result2 == ""
|
||||
|
||||
result2 = executor._ask_human_input("Test result 2")
|
||||
assert result2 == ""
|
||||
|
||||
assert len(pause_calls) == 2
|
||||
assert len(resume_calls) == 2
|
||||
finally:
|
||||
formatter._live_paused = original_paused_state
|
||||
assert len(pause_calls) == 2
|
||||
assert len(resume_calls) == 2
|
||||
|
||||
def test_pause_resume_with_no_live_session(self):
|
||||
"""Test pause/resume methods handle case when no Live session exists."""
|
||||
formatter = event_listener.formatter
|
||||
|
||||
original_live = formatter._live
|
||||
original_paused_state = formatter._live_paused
|
||||
original_streaming_live = formatter._streaming_live
|
||||
|
||||
try:
|
||||
formatter._live = None
|
||||
formatter._live_paused = False
|
||||
formatter._streaming_live = None
|
||||
|
||||
# Should not raise when no session exists
|
||||
formatter.pause_live_updates()
|
||||
formatter.resume_live_updates()
|
||||
|
||||
assert not formatter._live_paused
|
||||
assert formatter._streaming_live is None
|
||||
finally:
|
||||
formatter._live = original_live
|
||||
formatter._live_paused = original_paused_state
|
||||
formatter._streaming_live = original_streaming_live
|
||||
|
||||
def test_pause_resume_exception_handling(self):
|
||||
"""Test that resume is called even if exception occurs during human input."""
|
||||
@@ -131,23 +114,18 @@ class TestFlowHumanInputIntegration:
|
||||
|
||||
formatter = event_listener.formatter
|
||||
|
||||
original_paused_state = formatter._live_paused
|
||||
with (
|
||||
patch.object(formatter, "pause_live_updates") as mock_pause,
|
||||
patch.object(formatter, "resume_live_updates") as mock_resume,
|
||||
patch(
|
||||
"builtins.input", side_effect=KeyboardInterrupt("Test exception")
|
||||
),
|
||||
):
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
executor._ask_human_input("Test result")
|
||||
|
||||
try:
|
||||
with (
|
||||
patch.object(formatter, "pause_live_updates") as mock_pause,
|
||||
patch.object(formatter, "resume_live_updates") as mock_resume,
|
||||
patch(
|
||||
"builtins.input", side_effect=KeyboardInterrupt("Test exception")
|
||||
),
|
||||
):
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
executor._ask_human_input("Test result")
|
||||
|
||||
mock_pause.assert_called_once()
|
||||
mock_resume.assert_called_once()
|
||||
finally:
|
||||
formatter._live_paused = original_paused_state
|
||||
mock_pause.assert_called_once()
|
||||
mock_resume.assert_called_once()
|
||||
|
||||
def test_training_mode_human_input(self):
|
||||
"""Test human input in training mode."""
|
||||
@@ -162,28 +140,25 @@ class TestFlowHumanInputIntegration:
|
||||
|
||||
formatter = event_listener.formatter
|
||||
|
||||
original_paused_state = formatter._live_paused
|
||||
with (
|
||||
patch.object(formatter, "pause_live_updates") as mock_pause,
|
||||
patch.object(formatter, "resume_live_updates") as mock_resume,
|
||||
patch.object(formatter.console, "print") as mock_console_print,
|
||||
patch("builtins.input", return_value="training feedback"),
|
||||
):
|
||||
result = executor._ask_human_input("Test result")
|
||||
|
||||
try:
|
||||
with (
|
||||
patch.object(formatter, "pause_live_updates") as mock_pause,
|
||||
patch.object(formatter, "resume_live_updates") as mock_resume,
|
||||
patch("builtins.input", return_value="training feedback"),
|
||||
):
|
||||
result = executor._ask_human_input("Test result")
|
||||
mock_pause.assert_called_once()
|
||||
mock_resume.assert_called_once()
|
||||
assert result == "training feedback"
|
||||
|
||||
mock_pause.assert_called_once()
|
||||
mock_resume.assert_called_once()
|
||||
assert result == "training feedback"
|
||||
|
||||
executor._printer.print.assert_called()
|
||||
call_args = [
|
||||
call[1]["content"]
|
||||
for call in executor._printer.print.call_args_list
|
||||
]
|
||||
training_prompt_found = any(
|
||||
"TRAINING MODE" in content for content in call_args
|
||||
)
|
||||
assert training_prompt_found
|
||||
finally:
|
||||
formatter._live_paused = original_paused_state
|
||||
# Verify the training panel was printed via formatter's console
|
||||
mock_console_print.assert_called()
|
||||
# Check that a Panel with training title was printed
|
||||
call_args = mock_console_print.call_args_list
|
||||
training_panel_found = any(
|
||||
hasattr(call[0][0], "title") and "Training" in str(call[0][0].title)
|
||||
for call in call_args
|
||||
if call[0]
|
||||
)
|
||||
assert training_panel_found
|
||||
|
||||
@@ -1,116 +1,107 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from rich.tree import Tree
|
||||
from rich.live import Live
|
||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||
|
||||
|
||||
class TestConsoleFormatterPauseResume:
|
||||
"""Test ConsoleFormatter pause/resume functionality."""
|
||||
"""Test ConsoleFormatter pause/resume functionality for HITL features."""
|
||||
|
||||
def test_pause_live_updates_with_active_session(self):
|
||||
"""Test pausing when Live session is active."""
|
||||
def test_pause_stops_active_streaming_session(self):
|
||||
"""Test pausing stops an active streaming Live session."""
|
||||
formatter = ConsoleFormatter()
|
||||
|
||||
mock_live = MagicMock(spec=Live)
|
||||
formatter._live = mock_live
|
||||
formatter._live_paused = False
|
||||
formatter._streaming_live = mock_live
|
||||
|
||||
formatter.pause_live_updates()
|
||||
|
||||
mock_live.stop.assert_called_once()
|
||||
assert formatter._live_paused
|
||||
assert formatter._streaming_live is None
|
||||
|
||||
def test_pause_live_updates_when_already_paused(self):
|
||||
"""Test pausing when already paused does nothing."""
|
||||
def test_pause_is_safe_when_no_session(self):
|
||||
"""Test pausing when no streaming session exists doesn't error."""
|
||||
formatter = ConsoleFormatter()
|
||||
formatter._streaming_live = None
|
||||
|
||||
# Should not raise
|
||||
formatter.pause_live_updates()
|
||||
|
||||
assert formatter._streaming_live is None
|
||||
|
||||
def test_multiple_pauses_are_safe(self):
|
||||
"""Test calling pause multiple times is safe."""
|
||||
formatter = ConsoleFormatter()
|
||||
|
||||
mock_live = MagicMock(spec=Live)
|
||||
formatter._live = mock_live
|
||||
formatter._live_paused = True
|
||||
formatter._streaming_live = mock_live
|
||||
|
||||
formatter.pause_live_updates()
|
||||
mock_live.stop.assert_called_once()
|
||||
assert formatter._streaming_live is None
|
||||
|
||||
mock_live.stop.assert_not_called()
|
||||
assert formatter._live_paused
|
||||
|
||||
def test_pause_live_updates_with_no_session(self):
|
||||
"""Test pausing when no Live session exists."""
|
||||
formatter = ConsoleFormatter()
|
||||
|
||||
formatter._live = None
|
||||
formatter._live_paused = False
|
||||
|
||||
# Second pause should not error (no session to stop)
|
||||
formatter.pause_live_updates()
|
||||
|
||||
assert formatter._live_paused
|
||||
|
||||
def test_resume_live_updates_when_paused(self):
|
||||
"""Test resuming when paused."""
|
||||
def test_resume_is_safe(self):
|
||||
"""Test resume method exists and doesn't error."""
|
||||
formatter = ConsoleFormatter()
|
||||
|
||||
formatter._live_paused = True
|
||||
|
||||
# Should not raise
|
||||
formatter.resume_live_updates()
|
||||
|
||||
assert not formatter._live_paused
|
||||
|
||||
def test_resume_live_updates_when_not_paused(self):
|
||||
"""Test resuming when not paused does nothing."""
|
||||
def test_streaming_after_pause_resume_creates_new_session(self):
|
||||
"""Test that streaming after pause/resume creates new Live session."""
|
||||
formatter = ConsoleFormatter()
|
||||
formatter.verbose = True
|
||||
|
||||
formatter._live_paused = False
|
||||
# Simulate having an active session
|
||||
mock_live = MagicMock(spec=Live)
|
||||
formatter._streaming_live = mock_live
|
||||
|
||||
# Pause stops the session
|
||||
formatter.pause_live_updates()
|
||||
assert formatter._streaming_live is None
|
||||
|
||||
# Resume (no-op, sessions created on demand)
|
||||
formatter.resume_live_updates()
|
||||
|
||||
assert not formatter._live_paused
|
||||
# After resume, streaming should be able to start a new session
|
||||
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
|
||||
mock_live_instance = MagicMock()
|
||||
mock_live_class.return_value = mock_live_instance
|
||||
|
||||
def test_print_after_resume_restarts_live_session(self):
|
||||
"""Test that printing a Tree after resume creates new Live session."""
|
||||
# Simulate streaming chunk (this creates a new Live session)
|
||||
formatter.handle_llm_stream_chunk("test chunk", call_type=None)
|
||||
|
||||
mock_live_class.assert_called_once()
|
||||
mock_live_instance.start.assert_called_once()
|
||||
assert formatter._streaming_live == mock_live_instance
|
||||
|
||||
def test_pause_resume_cycle_with_streaming(self):
|
||||
"""Test full pause/resume cycle during streaming."""
|
||||
formatter = ConsoleFormatter()
|
||||
|
||||
formatter._live_paused = True
|
||||
formatter._live = None
|
||||
|
||||
formatter.resume_live_updates()
|
||||
assert not formatter._live_paused
|
||||
|
||||
tree = Tree("Test")
|
||||
formatter.verbose = True
|
||||
|
||||
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
|
||||
mock_live_instance = MagicMock()
|
||||
mock_live_class.return_value = mock_live_instance
|
||||
|
||||
formatter.print(tree)
|
||||
# Start streaming
|
||||
formatter.handle_llm_stream_chunk("chunk 1", call_type=None)
|
||||
assert formatter._streaming_live == mock_live_instance
|
||||
|
||||
mock_live_class.assert_called_once()
|
||||
mock_live_instance.start.assert_called_once()
|
||||
assert formatter._live == mock_live_instance
|
||||
# Pause should stop the session
|
||||
formatter.pause_live_updates()
|
||||
mock_live_instance.stop.assert_called_once()
|
||||
assert formatter._streaming_live is None
|
||||
|
||||
def test_multiple_pause_resume_cycles(self):
|
||||
"""Test multiple pause/resume cycles work correctly."""
|
||||
formatter = ConsoleFormatter()
|
||||
# Resume (no-op)
|
||||
formatter.resume_live_updates()
|
||||
|
||||
mock_live = MagicMock(spec=Live)
|
||||
formatter._live = mock_live
|
||||
formatter._live_paused = False
|
||||
# Create a new mock for the next session
|
||||
mock_live_instance_2 = MagicMock()
|
||||
mock_live_class.return_value = mock_live_instance_2
|
||||
|
||||
formatter.pause_live_updates()
|
||||
assert formatter._live_paused
|
||||
mock_live.stop.assert_called_once()
|
||||
assert formatter._live is None # Live session should be cleared
|
||||
|
||||
formatter.resume_live_updates()
|
||||
assert not formatter._live_paused
|
||||
|
||||
formatter.pause_live_updates()
|
||||
assert formatter._live_paused
|
||||
|
||||
formatter.resume_live_updates()
|
||||
assert not formatter._live_paused
|
||||
|
||||
def test_pause_resume_state_initialization(self):
|
||||
"""Test that _live_paused is properly initialized."""
|
||||
formatter = ConsoleFormatter()
|
||||
|
||||
assert hasattr(formatter, "_live_paused")
|
||||
assert not formatter._live_paused
|
||||
# Streaming again creates new session
|
||||
formatter.handle_llm_stream_chunk("chunk 2", call_type=None)
|
||||
assert formatter._streaming_live == mock_live_instance_2
|
||||
|
||||
Reference in New Issue
Block a user