mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 15:52:34 +00:00
ensure hitl works
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from crewai.agents.parser import AgentFinish
|
||||||
from crewai.events.event_listener import event_listener
|
from crewai.events.event_listener import event_listener
|
||||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||||
@@ -29,7 +30,7 @@ class CrewAgentExecutorMixin:
|
|||||||
_i18n: I18N
|
_i18n: I18N
|
||||||
_printer: Printer = Printer()
|
_printer: Printer = Printer()
|
||||||
|
|
||||||
def _create_short_term_memory(self, output) -> None:
|
def _create_short_term_memory(self, output: AgentFinish) -> None:
|
||||||
"""Create and save a short-term memory item if conditions are met."""
|
"""Create and save a short-term memory item if conditions are met."""
|
||||||
if (
|
if (
|
||||||
self.crew
|
self.crew
|
||||||
@@ -53,7 +54,7 @@ class CrewAgentExecutorMixin:
|
|||||||
"error", f"Failed to add to short term memory: {e}"
|
"error", f"Failed to add to short term memory: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_external_memory(self, output) -> None:
|
def _create_external_memory(self, output: AgentFinish) -> None:
|
||||||
"""Create and save a external-term memory item if conditions are met."""
|
"""Create and save a external-term memory item if conditions are met."""
|
||||||
if (
|
if (
|
||||||
self.crew
|
self.crew
|
||||||
@@ -75,7 +76,7 @@ class CrewAgentExecutorMixin:
|
|||||||
"error", f"Failed to add to external memory: {e}"
|
"error", f"Failed to add to external memory: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_long_term_memory(self, output) -> None:
|
def _create_long_term_memory(self, output: AgentFinish) -> None:
|
||||||
"""Create and save long-term and entity memory items based on evaluation."""
|
"""Create and save long-term and entity memory items based on evaluation."""
|
||||||
if (
|
if (
|
||||||
self.crew
|
self.crew
|
||||||
@@ -136,40 +137,50 @@ class CrewAgentExecutorMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _ask_human_input(self, final_answer: str) -> str:
|
def _ask_human_input(self, final_answer: str) -> str:
|
||||||
"""Prompt human input with mode-appropriate messaging."""
|
"""Prompt human input with mode-appropriate messaging.
|
||||||
event_listener.formatter.pause_live_updates()
|
|
||||||
try:
|
|
||||||
self._printer.print(
|
|
||||||
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
Note: The final answer is already displayed via the AgentLogsExecutionEvent
|
||||||
|
panel, so we only show the feedback prompt here.
|
||||||
|
"""
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
formatter = event_listener.formatter
|
||||||
|
formatter.pause_live_updates()
|
||||||
|
|
||||||
|
try:
|
||||||
# Training mode prompt (single iteration)
|
# Training mode prompt (single iteration)
|
||||||
if self.crew and getattr(self.crew, "_train", False):
|
if self.crew and getattr(self.crew, "_train", False):
|
||||||
prompt = (
|
prompt_text = (
|
||||||
"\n\n=====\n"
|
"TRAINING MODE: Provide feedback to improve the agent's performance.\n\n"
|
||||||
"## TRAINING MODE: Provide feedback to improve the agent's performance.\n"
|
|
||||||
"This will be used to train better versions of the agent.\n"
|
"This will be used to train better versions of the agent.\n"
|
||||||
"Please provide detailed feedback about the result quality and reasoning process.\n"
|
"Please provide detailed feedback about the result quality and reasoning process."
|
||||||
"=====\n"
|
|
||||||
)
|
)
|
||||||
|
title = "🎓 Training Feedback Required"
|
||||||
# Regular human-in-the-loop prompt (multiple iterations)
|
# Regular human-in-the-loop prompt (multiple iterations)
|
||||||
else:
|
else:
|
||||||
prompt = (
|
prompt_text = (
|
||||||
"\n\n=====\n"
|
"Provide feedback on the Final Result above.\n\n"
|
||||||
"## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n"
|
"• If you are happy with the result, simply hit Enter without typing anything.\n"
|
||||||
"Please follow these guidelines:\n"
|
"• Otherwise, provide specific improvement requests.\n"
|
||||||
" - If you are happy with the result, simply hit Enter without typing anything.\n"
|
"• You can provide multiple rounds of feedback until satisfied."
|
||||||
" - Otherwise, provide specific improvement requests.\n"
|
|
||||||
" - You can provide multiple rounds of feedback until satisfied.\n"
|
|
||||||
"=====\n"
|
|
||||||
)
|
)
|
||||||
|
title = "💬 Human Feedback Required"
|
||||||
|
|
||||||
|
content = Text()
|
||||||
|
content.append(prompt_text, style="yellow")
|
||||||
|
|
||||||
|
prompt_panel = Panel(
|
||||||
|
content,
|
||||||
|
title=title,
|
||||||
|
border_style="yellow",
|
||||||
|
padding=(1, 2),
|
||||||
|
)
|
||||||
|
formatter.console.print(prompt_panel)
|
||||||
|
|
||||||
self._printer.print(content=prompt, color="bold_yellow")
|
|
||||||
response = input()
|
response = input()
|
||||||
if response.strip() != "":
|
if response.strip() != "":
|
||||||
self._printer.print(
|
formatter.console.print("\n[cyan]Processing your feedback...[/cyan]")
|
||||||
content="\nProcessing your feedback...", color="cyan"
|
|
||||||
)
|
|
||||||
return response
|
return response
|
||||||
finally:
|
finally:
|
||||||
event_listener.formatter.resume_live_updates()
|
formatter.resume_live_updates()
|
||||||
|
|||||||
@@ -541,7 +541,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
if self.agent is None:
|
if self.agent is None:
|
||||||
raise ValueError("Agent cannot be None")
|
raise ValueError("Agent cannot be None")
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
future = crewai_event_bus.emit(
|
||||||
self.agent,
|
self.agent,
|
||||||
AgentLogsExecutionEvent(
|
AgentLogsExecutionEvent(
|
||||||
agent_role=self.agent.role,
|
agent_role=self.agent.role,
|
||||||
@@ -551,6 +551,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if future is not None:
|
||||||
|
try:
|
||||||
|
future.result(timeout=5.0)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
def _handle_crew_training_output(
|
def _handle_crew_training_output(
|
||||||
self, result: AgentFinish, human_feedback: str | None = None
|
self, result: AgentFinish, human_feedback: str | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@@ -98,6 +98,24 @@ To enable tracing, do any one of these:
|
|||||||
return
|
return
|
||||||
self.console.print(*args, **kwargs)
|
self.console.print(*args, **kwargs)
|
||||||
|
|
||||||
|
def pause_live_updates(self) -> None:
|
||||||
|
"""Pause Live session updates to allow for human input without interference.
|
||||||
|
|
||||||
|
This stops any active streaming Live session to prevent console refresh
|
||||||
|
interference during HITL (Human-in-the-Loop) user input.
|
||||||
|
"""
|
||||||
|
if self._streaming_live:
|
||||||
|
self._streaming_live.stop()
|
||||||
|
self._streaming_live = None
|
||||||
|
|
||||||
|
def resume_live_updates(self) -> None:
|
||||||
|
"""Resume Live session updates after human input is complete.
|
||||||
|
|
||||||
|
New streaming sessions will be created on-demand when needed.
|
||||||
|
This method exists for API compatibility with HITL callers.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
def print_panel(
|
def print_panel(
|
||||||
self, content: Text, title: str, style: str = "blue", is_flow: bool = False
|
self, content: Text, title: str, style: str = "blue", is_flow: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@@ -7,22 +7,19 @@ from crewai.events.event_listener import event_listener
|
|||||||
class TestFlowHumanInputIntegration:
|
class TestFlowHumanInputIntegration:
|
||||||
"""Test integration between Flow execution and human input functionality."""
|
"""Test integration between Flow execution and human input functionality."""
|
||||||
|
|
||||||
def test_console_formatter_pause_resume_methods(self):
|
def test_console_formatter_pause_resume_methods_exist(self):
|
||||||
"""Test that ConsoleFormatter pause/resume methods work correctly."""
|
"""Test that ConsoleFormatter pause/resume methods exist and are callable."""
|
||||||
formatter = event_listener.formatter
|
formatter = event_listener.formatter
|
||||||
|
|
||||||
original_paused_state = formatter._live_paused
|
# Methods should exist and be callable
|
||||||
|
assert hasattr(formatter, "pause_live_updates")
|
||||||
try:
|
assert hasattr(formatter, "resume_live_updates")
|
||||||
formatter._live_paused = False
|
assert callable(formatter.pause_live_updates)
|
||||||
|
assert callable(formatter.resume_live_updates)
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
formatter.pause_live_updates()
|
formatter.pause_live_updates()
|
||||||
assert formatter._live_paused
|
|
||||||
|
|
||||||
formatter.resume_live_updates()
|
formatter.resume_live_updates()
|
||||||
assert not formatter._live_paused
|
|
||||||
finally:
|
|
||||||
formatter._live_paused = original_paused_state
|
|
||||||
|
|
||||||
@patch("builtins.input", return_value="")
|
@patch("builtins.input", return_value="")
|
||||||
def test_human_input_pauses_flow_updates(self, mock_input):
|
def test_human_input_pauses_flow_updates(self, mock_input):
|
||||||
@@ -38,11 +35,6 @@ class TestFlowHumanInputIntegration:
|
|||||||
|
|
||||||
formatter = event_listener.formatter
|
formatter = event_listener.formatter
|
||||||
|
|
||||||
original_paused_state = formatter._live_paused
|
|
||||||
|
|
||||||
try:
|
|
||||||
formatter._live_paused = False
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch.object(formatter, "pause_live_updates") as mock_pause,
|
patch.object(formatter, "pause_live_updates") as mock_pause,
|
||||||
patch.object(formatter, "resume_live_updates") as mock_resume,
|
patch.object(formatter, "resume_live_updates") as mock_resume,
|
||||||
@@ -53,8 +45,6 @@ class TestFlowHumanInputIntegration:
|
|||||||
mock_resume.assert_called_once()
|
mock_resume.assert_called_once()
|
||||||
mock_input.assert_called_once()
|
mock_input.assert_called_once()
|
||||||
assert result == ""
|
assert result == ""
|
||||||
finally:
|
|
||||||
formatter._live_paused = original_paused_state
|
|
||||||
|
|
||||||
@patch("builtins.input", side_effect=["feedback", ""])
|
@patch("builtins.input", side_effect=["feedback", ""])
|
||||||
def test_multiple_human_input_rounds(self, mock_input):
|
def test_multiple_human_input_rounds(self, mock_input):
|
||||||
@@ -70,9 +60,6 @@ class TestFlowHumanInputIntegration:
|
|||||||
|
|
||||||
formatter = event_listener.formatter
|
formatter = event_listener.formatter
|
||||||
|
|
||||||
original_paused_state = formatter._live_paused
|
|
||||||
|
|
||||||
try:
|
|
||||||
pause_calls = []
|
pause_calls = []
|
||||||
resume_calls = []
|
resume_calls = []
|
||||||
|
|
||||||
@@ -96,27 +83,23 @@ class TestFlowHumanInputIntegration:
|
|||||||
|
|
||||||
assert len(pause_calls) == 2
|
assert len(pause_calls) == 2
|
||||||
assert len(resume_calls) == 2
|
assert len(resume_calls) == 2
|
||||||
finally:
|
|
||||||
formatter._live_paused = original_paused_state
|
|
||||||
|
|
||||||
def test_pause_resume_with_no_live_session(self):
|
def test_pause_resume_with_no_live_session(self):
|
||||||
"""Test pause/resume methods handle case when no Live session exists."""
|
"""Test pause/resume methods handle case when no Live session exists."""
|
||||||
formatter = event_listener.formatter
|
formatter = event_listener.formatter
|
||||||
|
|
||||||
original_live = formatter._live
|
original_streaming_live = formatter._streaming_live
|
||||||
original_paused_state = formatter._live_paused
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
formatter._live = None
|
formatter._streaming_live = None
|
||||||
formatter._live_paused = False
|
|
||||||
|
|
||||||
|
# Should not raise when no session exists
|
||||||
formatter.pause_live_updates()
|
formatter.pause_live_updates()
|
||||||
formatter.resume_live_updates()
|
formatter.resume_live_updates()
|
||||||
|
|
||||||
assert not formatter._live_paused
|
assert formatter._streaming_live is None
|
||||||
finally:
|
finally:
|
||||||
formatter._live = original_live
|
formatter._streaming_live = original_streaming_live
|
||||||
formatter._live_paused = original_paused_state
|
|
||||||
|
|
||||||
def test_pause_resume_exception_handling(self):
|
def test_pause_resume_exception_handling(self):
|
||||||
"""Test that resume is called even if exception occurs during human input."""
|
"""Test that resume is called even if exception occurs during human input."""
|
||||||
@@ -131,9 +114,6 @@ class TestFlowHumanInputIntegration:
|
|||||||
|
|
||||||
formatter = event_listener.formatter
|
formatter = event_listener.formatter
|
||||||
|
|
||||||
original_paused_state = formatter._live_paused
|
|
||||||
|
|
||||||
try:
|
|
||||||
with (
|
with (
|
||||||
patch.object(formatter, "pause_live_updates") as mock_pause,
|
patch.object(formatter, "pause_live_updates") as mock_pause,
|
||||||
patch.object(formatter, "resume_live_updates") as mock_resume,
|
patch.object(formatter, "resume_live_updates") as mock_resume,
|
||||||
@@ -146,8 +126,6 @@ class TestFlowHumanInputIntegration:
|
|||||||
|
|
||||||
mock_pause.assert_called_once()
|
mock_pause.assert_called_once()
|
||||||
mock_resume.assert_called_once()
|
mock_resume.assert_called_once()
|
||||||
finally:
|
|
||||||
formatter._live_paused = original_paused_state
|
|
||||||
|
|
||||||
def test_training_mode_human_input(self):
|
def test_training_mode_human_input(self):
|
||||||
"""Test human input in training mode."""
|
"""Test human input in training mode."""
|
||||||
@@ -162,12 +140,10 @@ class TestFlowHumanInputIntegration:
|
|||||||
|
|
||||||
formatter = event_listener.formatter
|
formatter = event_listener.formatter
|
||||||
|
|
||||||
original_paused_state = formatter._live_paused
|
|
||||||
|
|
||||||
try:
|
|
||||||
with (
|
with (
|
||||||
patch.object(formatter, "pause_live_updates") as mock_pause,
|
patch.object(formatter, "pause_live_updates") as mock_pause,
|
||||||
patch.object(formatter, "resume_live_updates") as mock_resume,
|
patch.object(formatter, "resume_live_updates") as mock_resume,
|
||||||
|
patch.object(formatter.console, "print") as mock_console_print,
|
||||||
patch("builtins.input", return_value="training feedback"),
|
patch("builtins.input", return_value="training feedback"),
|
||||||
):
|
):
|
||||||
result = executor._ask_human_input("Test result")
|
result = executor._ask_human_input("Test result")
|
||||||
@@ -176,14 +152,13 @@ class TestFlowHumanInputIntegration:
|
|||||||
mock_resume.assert_called_once()
|
mock_resume.assert_called_once()
|
||||||
assert result == "training feedback"
|
assert result == "training feedback"
|
||||||
|
|
||||||
executor._printer.print.assert_called()
|
# Verify the training panel was printed via formatter's console
|
||||||
call_args = [
|
mock_console_print.assert_called()
|
||||||
call[1]["content"]
|
# Check that a Panel with training title was printed
|
||||||
for call in executor._printer.print.call_args_list
|
call_args = mock_console_print.call_args_list
|
||||||
]
|
training_panel_found = any(
|
||||||
training_prompt_found = any(
|
hasattr(call[0][0], "title") and "Training" in str(call[0][0].title)
|
||||||
"TRAINING MODE" in content for content in call_args
|
for call in call_args
|
||||||
|
if call[0]
|
||||||
)
|
)
|
||||||
assert training_prompt_found
|
assert training_panel_found
|
||||||
finally:
|
|
||||||
formatter._live_paused = original_paused_state
|
|
||||||
|
|||||||
@@ -1,116 +1,107 @@
|
|||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
from rich.tree import Tree
|
|
||||||
from rich.live import Live
|
from rich.live import Live
|
||||||
from crewai.events.utils.console_formatter import ConsoleFormatter
|
from crewai.events.utils.console_formatter import ConsoleFormatter
|
||||||
|
|
||||||
|
|
||||||
class TestConsoleFormatterPauseResume:
|
class TestConsoleFormatterPauseResume:
|
||||||
"""Test ConsoleFormatter pause/resume functionality."""
|
"""Test ConsoleFormatter pause/resume functionality for HITL features."""
|
||||||
|
|
||||||
def test_pause_live_updates_with_active_session(self):
|
def test_pause_stops_active_streaming_session(self):
|
||||||
"""Test pausing when Live session is active."""
|
"""Test pausing stops an active streaming Live session."""
|
||||||
formatter = ConsoleFormatter()
|
formatter = ConsoleFormatter()
|
||||||
|
|
||||||
mock_live = MagicMock(spec=Live)
|
mock_live = MagicMock(spec=Live)
|
||||||
formatter._live = mock_live
|
formatter._streaming_live = mock_live
|
||||||
formatter._live_paused = False
|
|
||||||
|
|
||||||
formatter.pause_live_updates()
|
formatter.pause_live_updates()
|
||||||
|
|
||||||
mock_live.stop.assert_called_once()
|
mock_live.stop.assert_called_once()
|
||||||
assert formatter._live_paused
|
assert formatter._streaming_live is None
|
||||||
|
|
||||||
def test_pause_live_updates_when_already_paused(self):
|
def test_pause_is_safe_when_no_session(self):
|
||||||
"""Test pausing when already paused does nothing."""
|
"""Test pausing when no streaming session exists doesn't error."""
|
||||||
|
formatter = ConsoleFormatter()
|
||||||
|
formatter._streaming_live = None
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
formatter.pause_live_updates()
|
||||||
|
|
||||||
|
assert formatter._streaming_live is None
|
||||||
|
|
||||||
|
def test_multiple_pauses_are_safe(self):
|
||||||
|
"""Test calling pause multiple times is safe."""
|
||||||
formatter = ConsoleFormatter()
|
formatter = ConsoleFormatter()
|
||||||
|
|
||||||
mock_live = MagicMock(spec=Live)
|
mock_live = MagicMock(spec=Live)
|
||||||
formatter._live = mock_live
|
formatter._streaming_live = mock_live
|
||||||
formatter._live_paused = True
|
|
||||||
|
|
||||||
formatter.pause_live_updates()
|
formatter.pause_live_updates()
|
||||||
|
mock_live.stop.assert_called_once()
|
||||||
|
assert formatter._streaming_live is None
|
||||||
|
|
||||||
mock_live.stop.assert_not_called()
|
# Second pause should not error (no session to stop)
|
||||||
assert formatter._live_paused
|
|
||||||
|
|
||||||
def test_pause_live_updates_with_no_session(self):
|
|
||||||
"""Test pausing when no Live session exists."""
|
|
||||||
formatter = ConsoleFormatter()
|
|
||||||
|
|
||||||
formatter._live = None
|
|
||||||
formatter._live_paused = False
|
|
||||||
|
|
||||||
formatter.pause_live_updates()
|
formatter.pause_live_updates()
|
||||||
|
|
||||||
assert formatter._live_paused
|
def test_resume_is_safe(self):
|
||||||
|
"""Test resume method exists and doesn't error."""
|
||||||
def test_resume_live_updates_when_paused(self):
|
|
||||||
"""Test resuming when paused."""
|
|
||||||
formatter = ConsoleFormatter()
|
formatter = ConsoleFormatter()
|
||||||
|
|
||||||
formatter._live_paused = True
|
# Should not raise
|
||||||
|
|
||||||
formatter.resume_live_updates()
|
formatter.resume_live_updates()
|
||||||
|
|
||||||
assert not formatter._live_paused
|
def test_streaming_after_pause_resume_creates_new_session(self):
|
||||||
|
"""Test that streaming after pause/resume creates new Live session."""
|
||||||
def test_resume_live_updates_when_not_paused(self):
|
|
||||||
"""Test resuming when not paused does nothing."""
|
|
||||||
formatter = ConsoleFormatter()
|
formatter = ConsoleFormatter()
|
||||||
|
formatter.verbose = True
|
||||||
|
|
||||||
formatter._live_paused = False
|
# Simulate having an active session
|
||||||
|
mock_live = MagicMock(spec=Live)
|
||||||
|
formatter._streaming_live = mock_live
|
||||||
|
|
||||||
|
# Pause stops the session
|
||||||
|
formatter.pause_live_updates()
|
||||||
|
assert formatter._streaming_live is None
|
||||||
|
|
||||||
|
# Resume (no-op, sessions created on demand)
|
||||||
formatter.resume_live_updates()
|
formatter.resume_live_updates()
|
||||||
|
|
||||||
assert not formatter._live_paused
|
# After resume, streaming should be able to start a new session
|
||||||
|
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
|
||||||
|
mock_live_instance = MagicMock()
|
||||||
|
mock_live_class.return_value = mock_live_instance
|
||||||
|
|
||||||
def test_print_after_resume_restarts_live_session(self):
|
# Simulate streaming chunk (this creates a new Live session)
|
||||||
"""Test that printing a Tree after resume creates new Live session."""
|
formatter.handle_llm_stream_chunk("test chunk", call_type=None)
|
||||||
|
|
||||||
|
mock_live_class.assert_called_once()
|
||||||
|
mock_live_instance.start.assert_called_once()
|
||||||
|
assert formatter._streaming_live == mock_live_instance
|
||||||
|
|
||||||
|
def test_pause_resume_cycle_with_streaming(self):
|
||||||
|
"""Test full pause/resume cycle during streaming."""
|
||||||
formatter = ConsoleFormatter()
|
formatter = ConsoleFormatter()
|
||||||
|
formatter.verbose = True
|
||||||
formatter._live_paused = True
|
|
||||||
formatter._live = None
|
|
||||||
|
|
||||||
formatter.resume_live_updates()
|
|
||||||
assert not formatter._live_paused
|
|
||||||
|
|
||||||
tree = Tree("Test")
|
|
||||||
|
|
||||||
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
|
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
|
||||||
mock_live_instance = MagicMock()
|
mock_live_instance = MagicMock()
|
||||||
mock_live_class.return_value = mock_live_instance
|
mock_live_class.return_value = mock_live_instance
|
||||||
|
|
||||||
formatter.print(tree)
|
# Start streaming
|
||||||
|
formatter.handle_llm_stream_chunk("chunk 1", call_type=None)
|
||||||
mock_live_class.assert_called_once()
|
assert formatter._streaming_live == mock_live_instance
|
||||||
mock_live_instance.start.assert_called_once()
|
|
||||||
assert formatter._live == mock_live_instance
|
|
||||||
|
|
||||||
def test_multiple_pause_resume_cycles(self):
|
|
||||||
"""Test multiple pause/resume cycles work correctly."""
|
|
||||||
formatter = ConsoleFormatter()
|
|
||||||
|
|
||||||
mock_live = MagicMock(spec=Live)
|
|
||||||
formatter._live = mock_live
|
|
||||||
formatter._live_paused = False
|
|
||||||
|
|
||||||
|
# Pause should stop the session
|
||||||
formatter.pause_live_updates()
|
formatter.pause_live_updates()
|
||||||
assert formatter._live_paused
|
mock_live_instance.stop.assert_called_once()
|
||||||
mock_live.stop.assert_called_once()
|
assert formatter._streaming_live is None
|
||||||
assert formatter._live is None # Live session should be cleared
|
|
||||||
|
|
||||||
|
# Resume (no-op)
|
||||||
formatter.resume_live_updates()
|
formatter.resume_live_updates()
|
||||||
assert not formatter._live_paused
|
|
||||||
|
|
||||||
formatter.pause_live_updates()
|
# Create a new mock for the next session
|
||||||
assert formatter._live_paused
|
mock_live_instance_2 = MagicMock()
|
||||||
|
mock_live_class.return_value = mock_live_instance_2
|
||||||
|
|
||||||
formatter.resume_live_updates()
|
# Streaming again creates new session
|
||||||
assert not formatter._live_paused
|
formatter.handle_llm_stream_chunk("chunk 2", call_type=None)
|
||||||
|
assert formatter._streaming_live == mock_live_instance_2
|
||||||
def test_pause_resume_state_initialization(self):
|
|
||||||
"""Test that _live_paused is properly initialized."""
|
|
||||||
formatter = ConsoleFormatter()
|
|
||||||
|
|
||||||
assert hasattr(formatter, "_live_paused")
|
|
||||||
assert not formatter._live_paused
|
|
||||||
|
|||||||
Reference in New Issue
Block a user