"""Tests for Flow.ask() user input method. This module tests the ask() method on Flow, including basic usage, timeout behavior, provider resolution, event emission, auto-checkpoint durability, input history tracking, and integration with flow machinery. """ from __future__ import annotations import time from datetime import datetime from typing import Any from unittest.mock import MagicMock, patch from crewai.flow import Flow, flow_config, listen, start from crewai.flow.async_feedback.providers import ConsoleProvider from crewai.flow.flow import FlowState from crewai.flow.input_provider import InputProvider, InputResponse # ── Test helpers ───────────────────────────────────────────────── class MockInputProvider: """Mock input provider that returns pre-configured responses.""" def __init__(self, responses: list[str | None]) -> None: self.responses = responses self._call_count = 0 self.messages: list[str] = [] self.received_metadata: list[dict[str, Any] | None] = [] def request_input( self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None ) -> str | None: self.messages.append(message) self.received_metadata.append(metadata) if self._call_count >= len(self.responses): return None response = self.responses[self._call_count] self._call_count += 1 return response class SlowMockProvider: """Mock provider that delays before returning, for timeout tests.""" def __init__(self, delay: float, response: str = "delayed") -> None: self.delay = delay self.response = response def request_input( self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None ) -> str | None: time.sleep(self.delay) return self.response # ── Basic Functionality ────────────────────────────────────────── class TestAskBasic: """Tests for basic ask() functionality.""" def test_ask_returns_user_input(self) -> None: """ask() returns the string from the input provider.""" class TestFlow(Flow): input_provider = MockInputProvider(["hello"]) @start() def my_method(self): return self.ask("Say something:") flow = TestFlow() result = flow.kickoff() assert result == "hello" def test_ask_in_async_method(self) -> None: """ask() works inside an async flow method.""" class TestFlow(Flow): input_provider = MockInputProvider(["async hello"]) @start() async def my_method(self): return self.ask("Say something:") flow = TestFlow() result = flow.kickoff() assert result == "async hello" def test_ask_in_start_method(self) -> None: """ask() works inside a @start() method, flow completes normally.""" execution_log: list[str] = [] class TestFlow(Flow): input_provider = MockInputProvider(["AI"]) @start() def gather(self): topic = self.ask("Topic?") execution_log.append(f"got:{topic}") return topic flow = TestFlow() result = flow.kickoff() assert result == "AI" assert execution_log == ["got:AI"] def test_ask_in_listen_method(self) -> None: """ask() works inside a @listen() method.""" class TestFlow(Flow): input_provider = MockInputProvider(["detailed"]) @start() def step1(self): return "topic" @listen("step1") def step2(self): depth = self.ask("How deep?") return f"researching at {depth} level" flow = TestFlow() result = flow.kickoff() assert result == "researching at detailed level" def test_ask_multiple_calls(self) -> None: """Multiple ask() calls in one method return correct values in order.""" class TestFlow(Flow): input_provider = MockInputProvider(["AI", "detailed", "english"]) @start() def gather(self): topic = self.ask("Topic?") depth = self.ask("Depth?") lang = self.ask("Language?") return {"topic": topic, "depth": depth, "lang": lang} flow = TestFlow() result = flow.kickoff() assert result == {"topic": "AI", "depth": "detailed", "lang": "english"} def test_ask_conditional(self) -> None: """ask() called conditionally based on previous answer.""" class TestFlow(Flow): input_provider = MockInputProvider(["AI", "LLMs"]) @start() def gather(self): topic = self.ask("Topic?") if topic == "AI": focus = self.ask("Specific area?") else: focus = "general" return {"topic": topic, "focus": focus} flow = TestFlow() result = flow.kickoff() assert result == {"topic": "AI", "focus": "LLMs"} def test_ask_returns_empty_string_on_enter(self) -> None: """Empty string means user pressed Enter (intentional empty input).""" class TestFlow(Flow): input_provider = MockInputProvider([""]) @start() def my_method(self): result = self.ask("Optional input:") return result flow = TestFlow() result = flow.kickoff() assert result == "" assert result is not None # Explicitly not None # ── Timeout ────────────────────────────────────────────────────── class TestAskTimeout: """Tests for timeout behavior.""" def test_ask_timeout_returns_none(self) -> None: """ask() returns None when timeout expires.""" class TestFlow(Flow): input_provider = SlowMockProvider(delay=5.0) @start() def my_method(self): return self.ask("Question?", timeout=0.1) flow = TestFlow() result = flow.kickoff() assert result is None def test_ask_timeout_in_async_method(self) -> None: """ask() timeout works inside an async flow method.""" class TestFlow(Flow): input_provider = SlowMockProvider(delay=5.0) @start() async def my_method(self): return self.ask("Question?", timeout=0.1) flow = TestFlow() result = flow.kickoff() assert result is None def test_ask_loop_with_timeout_termination(self) -> None: """while (msg := ask(...)) is not None pattern terminates on timeout.""" messages_received: list[str] = [] class TestFlow(Flow): input_provider = MockInputProvider(["hello", "world", None]) @start() def chat(self): while (msg := self.ask("You:")) is not None: messages_received.append(msg) return len(messages_received) flow = TestFlow() result = flow.kickoff() assert result == 2 assert messages_received == ["hello", "world"] def test_ask_no_timeout_waits_indefinitely(self) -> None: """ask() with no timeout blocks until provider returns.""" class TestFlow(Flow): input_provider = MockInputProvider(["answer"]) @start() def my_method(self): return self.ask("Question?") # no timeout flow = TestFlow() result = flow.kickoff() assert result == "answer" # ── Provider Resolution ────────────────────────────────────────── class TestProviderResolution: """Tests for provider resolution priority chain.""" def test_ask_uses_flow_level_provider(self) -> None: """Per-flow input_provider is used when set.""" provider = MockInputProvider(["from flow"]) class TestFlow(Flow): input_provider = provider @start() def my_method(self): return self.ask("Q?") flow = TestFlow() flow.kickoff() assert provider.messages == ["Q?"] def test_ask_uses_global_config_provider(self) -> None: """flow_config.input_provider is used as fallback.""" provider = MockInputProvider(["from config"]) original = flow_config.input_provider try: flow_config.input_provider = provider class TestFlow(Flow): @start() def my_method(self): return self.ask("Q?") flow = TestFlow() result = flow.kickoff() assert result == "from config" assert provider.messages == ["Q?"] finally: flow_config.input_provider = original def test_ask_defaults_to_console_provider(self) -> None: """When no provider configured, ConsoleProvider is used.""" original = flow_config.input_provider try: flow_config.input_provider = None class TestFlow(Flow): # No input_provider set @start() def my_method(self): return self.ask("Q?") flow = TestFlow() resolved = flow._resolve_input_provider() assert isinstance(resolved, ConsoleProvider) finally: flow_config.input_provider = original def test_flow_provider_overrides_global(self) -> None: """Per-flow provider takes precedence over global config.""" flow_provider = MockInputProvider(["from flow"]) global_provider = MockInputProvider(["from global"]) original = flow_config.input_provider try: flow_config.input_provider = global_provider class TestFlow(Flow): input_provider = flow_provider @start() def my_method(self): return self.ask("Q?") flow = TestFlow() result = flow.kickoff() assert result == "from flow" assert flow_provider.messages == ["Q?"] assert global_provider.messages == [] # not called finally: flow_config.input_provider = original # ── Events ─────────────────────────────────────────────────────── class TestAskEvents: """Tests for event emission during ask().""" def test_ask_emits_input_requested_event(self) -> None: """FlowInputRequestedEvent is emitted when ask() is called.""" from crewai.events.event_bus import crewai_event_bus from crewai.events.types.flow_events import FlowInputRequestedEvent events_captured: list[FlowInputRequestedEvent] = [] class TestFlow(Flow): input_provider = MockInputProvider(["answer"]) @start() def my_method(self): return self.ask("What topic?") flow = TestFlow() original_emit = crewai_event_bus.emit def capture_emit(source: Any, event: Any) -> Any: if isinstance(event, FlowInputRequestedEvent): events_captured.append(event) return original_emit(source, event) with patch.object(crewai_event_bus, "emit", side_effect=capture_emit): flow.kickoff() assert len(events_captured) == 1 assert events_captured[0].message == "What topic?" assert events_captured[0].type == "flow_input_requested" def test_ask_emits_input_received_event(self) -> None: """FlowInputReceivedEvent is emitted after input is received.""" from crewai.events.event_bus import crewai_event_bus from crewai.events.types.flow_events import FlowInputReceivedEvent events_captured: list[FlowInputReceivedEvent] = [] class TestFlow(Flow): input_provider = MockInputProvider(["my answer"]) @start() def my_method(self): return self.ask("Question?") flow = TestFlow() original_emit = crewai_event_bus.emit def capture_emit(source: Any, event: Any) -> Any: if isinstance(event, FlowInputReceivedEvent): events_captured.append(event) return original_emit(source, event) with patch.object(crewai_event_bus, "emit", side_effect=capture_emit): flow.kickoff() assert len(events_captured) == 1 assert events_captured[0].message == "Question?" assert events_captured[0].response == "my answer" assert events_captured[0].type == "flow_input_received" def test_ask_timeout_emits_received_with_none(self) -> None: """FlowInputReceivedEvent has response=None on timeout.""" from crewai.events.event_bus import crewai_event_bus from crewai.events.types.flow_events import FlowInputReceivedEvent events_captured: list[FlowInputReceivedEvent] = [] class TestFlow(Flow): input_provider = SlowMockProvider(delay=5.0) @start() def my_method(self): return self.ask("Question?", timeout=0.1) flow = TestFlow() original_emit = crewai_event_bus.emit def capture_emit(source: Any, event: Any) -> Any: if isinstance(event, FlowInputReceivedEvent): events_captured.append(event) return original_emit(source, event) with patch.object(crewai_event_bus, "emit", side_effect=capture_emit): flow.kickoff() assert len(events_captured) == 1 assert events_captured[0].response is None # ── Auto-checkpoint (Durability) ───────────────────────────────── class TestAskCheckpoint: """Tests for auto-checkpoint durability before ask() waits.""" def test_ask_checkpoints_state_before_waiting(self) -> None: """State is saved to persistence before waiting for input.""" mock_persistence = MagicMock() mock_persistence.load_state.return_value = None class TestFlow(Flow): input_provider = MockInputProvider(["answer"]) @start() def my_method(self): self.state["important"] = "data" return self.ask("Question?") flow = TestFlow(persistence=mock_persistence) flow.kickoff() # Find the _ask_checkpoint call among save_state calls checkpoint_calls = [ c for c in mock_persistence.save_state.call_args_list if c.kwargs.get("method_name") == "_ask_checkpoint" or (len(c.args) >= 2 and c.args[1] == "_ask_checkpoint") ] assert len(checkpoint_calls) >= 1 def test_ask_no_checkpoint_without_persistence(self) -> None: """No error when persistence is not configured.""" class TestFlow(Flow): input_provider = MockInputProvider(["answer"]) @start() def my_method(self): return self.ask("Question?") flow = TestFlow() # No persistence result = flow.kickoff() assert result == "answer" # Works fine without persistence def test_state_recoverable_after_checkpoint(self) -> None: """State set before ask() is checkpointed and recoverable. The auto-checkpoint happens *before* the provider is called, so state values set prior to ask() are persisted. This means if the server crashes while waiting for input, previously gathered data is safe. """ mock_persistence = MagicMock() mock_persistence.load_state.return_value = None class GatherFlow(Flow): input_provider = MockInputProvider(["AI", "detailed"]) @start() def gather(self): # First ask: nothing in state yet topic = self.ask("Topic?") self.state["topic"] = topic # Second ask: state now has topic, checkpoint saves it depth = self.ask("Depth?") self.state["depth"] = depth return {"topic": topic, "depth": depth} flow = GatherFlow(persistence=mock_persistence) result = flow.kickoff() assert result == {"topic": "AI", "depth": "detailed"} # Find the checkpoint calls checkpoint_calls = [ c for c in mock_persistence.save_state.call_args_list if c.kwargs.get("method_name") == "_ask_checkpoint" or (len(c.args) >= 2 and c.args[1] == "_ask_checkpoint") ] assert len(checkpoint_calls) == 2 # The second checkpoint (before asking "Depth?") should have topic second_checkpoint = checkpoint_calls[1] # state_data is the third positional arg or keyword arg if second_checkpoint.kwargs.get("state_data"): state_data = second_checkpoint.kwargs["state_data"] else: state_data = second_checkpoint.args[2] assert state_data.get("topic") == "AI" # ── Input History ──────────────────────────────────────────────── class TestInputHistory: """Tests for _input_history tracking.""" def test_input_history_accumulated(self) -> None: """_input_history tracks all ask/response pairs.""" class TestFlow(Flow): input_provider = MockInputProvider(["AI", "detailed"]) @start() def gather(self): self.ask("Topic?") self.ask("Depth?") return "done" flow = TestFlow() flow.kickoff() assert len(flow._input_history) == 2 assert flow._input_history[0]["message"] == "Topic?" assert flow._input_history[0]["response"] == "AI" assert flow._input_history[1]["message"] == "Depth?" assert flow._input_history[1]["response"] == "detailed" def test_input_history_includes_method_name(self) -> None: """Input history records which method called ask().""" class TestFlow(Flow): input_provider = MockInputProvider(["AI"]) @start() def gather_info(self): self.ask("Topic?") return "done" flow = TestFlow() flow.kickoff() assert len(flow._input_history) == 1 assert flow._input_history[0]["method_name"] == "gather_info" def test_input_history_includes_timestamp(self) -> None: """Input history records timestamps.""" class TestFlow(Flow): input_provider = MockInputProvider(["AI"]) @start() def my_method(self): self.ask("Topic?") return "done" flow = TestFlow() before = datetime.now() flow.kickoff() after = datetime.now() assert len(flow._input_history) == 1 ts = flow._input_history[0]["timestamp"] assert isinstance(ts, datetime) assert before <= ts <= after def test_input_history_records_none_on_timeout(self) -> None: """Input history records None response on timeout.""" class TestFlow(Flow): input_provider = SlowMockProvider(delay=5.0) @start() def my_method(self): self.ask("Question?", timeout=0.1) return "done" flow = TestFlow() flow.kickoff() assert len(flow._input_history) == 1 assert flow._input_history[0]["response"] is None # ── Integration ────────────────────────────────────────────────── class TestAskIntegration: """Integration tests for ask() with other flow features.""" def test_ask_works_with_listen_chain(self) -> None: """ask() in a start method, result flows to listener.""" execution_log: list[str] = [] class TestFlow(Flow): input_provider = MockInputProvider(["AI agents"]) @start() def gather(self): topic = self.ask("Topic?") execution_log.append(f"gathered:{topic}") return topic @listen("gather") def process(self): execution_log.append("processing") return "processed" flow = TestFlow() flow.kickoff() assert "gathered:AI agents" in execution_log assert "processing" in execution_log def test_ask_with_structured_state(self) -> None: """ask() works with Pydantic-based flow state.""" class ResearchState(FlowState): topic: str = "" depth: str = "" class TestFlow(Flow[ResearchState]): initial_state = ResearchState input_provider = MockInputProvider(["AI", "detailed"]) @start() def gather(self): self.state.topic = self.ask("Topic?") self.state.depth = self.ask("Depth?") return {"topic": self.state.topic, "depth": self.state.depth} flow = TestFlow() result = flow.kickoff() assert result == {"topic": "AI", "depth": "detailed"} assert flow.state.topic == "AI" assert flow.state.depth == "detailed" def test_ask_in_async_method_with_listen_chain(self) -> None: """ask() in an async start method, result flows to listener.""" execution_log: list[str] = [] class TestFlow(Flow): input_provider = MockInputProvider(["async topic"]) @start() async def gather(self): topic = self.ask("Topic?") execution_log.append(f"gathered:{topic}") return topic @listen("gather") def process(self): execution_log.append("processing") return "processed" flow = TestFlow() flow.kickoff() assert "gathered:async topic" in execution_log assert "processing" in execution_log def test_ask_with_state_persistence_recovery(self) -> None: """Ask checkpoints state so previously gathered values survive.""" mock_persistence = MagicMock() mock_persistence.load_state.return_value = None class RecoverableFlow(Flow): input_provider = MockInputProvider(["AI", "detailed"]) @start() def gather(self): if not self.state.get("topic"): self.state["topic"] = self.ask("Topic?") if not self.state.get("depth"): self.state["depth"] = self.ask("Depth?") return { "topic": self.state["topic"], "depth": self.state["depth"], } flow = RecoverableFlow(persistence=mock_persistence) result = flow.kickoff() assert result["topic"] == "AI" assert result["depth"] == "detailed" # Verify checkpoints were made checkpoint_calls = [ c for c in mock_persistence.save_state.call_args_list if c.kwargs.get("method_name") == "_ask_checkpoint" or (len(c.args) >= 2 and c.args[1] == "_ask_checkpoint") ] # Two ask() calls = two checkpoints assert len(checkpoint_calls) == 2 def test_ask_and_human_feedback_coexist(self) -> None: """ask() and @human_feedback can be used in the same flow.""" from crewai.flow import human_feedback class TestFlow(Flow): input_provider = MockInputProvider(["AI"]) @start() def gather(self): topic = self.ask("Topic?") return topic @listen("gather") @human_feedback(message="Review this topic:") def review(self): return f"Researching: {self.state.get('_last_topic', 'unknown')}" flow = TestFlow() with patch.object(flow, "_request_human_feedback", return_value="looks good"): flow.kickoff() # Flow completed with both ask and human_feedback assert flow.last_human_feedback is not None def test_ask_preserves_flow_lifecycle(self) -> None: """Flow events (started, finished) still fire normally with ask().""" from crewai.events.event_bus import crewai_event_bus from crewai.events.types.flow_events import ( FlowFinishedEvent, FlowStartedEvent, ) events_seen: list[str] = [] class TestFlow(Flow): input_provider = MockInputProvider(["answer"]) @start() def my_method(self): return self.ask("Q?") flow = TestFlow() original_emit = crewai_event_bus.emit def capture_emit(source: Any, event: Any) -> Any: if isinstance(event, FlowStartedEvent): events_seen.append("started") elif isinstance(event, FlowFinishedEvent): events_seen.append("finished") return original_emit(source, event) with patch.object(crewai_event_bus, "emit", side_effect=capture_emit): flow.kickoff() assert "started" in events_seen assert "finished" in events_seen # ── Console Provider ───────────────────────────────────────────── class TestConsoleProviderInput: """Tests for ConsoleProvider.request_input() (used by Flow.ask()).""" def test_console_provider_pauses_live_updates(self) -> None: """ConsoleProvider pauses and resumes formatter live updates.""" from crewai.events.event_listener import event_listener mock_formatter = MagicMock() mock_formatter.console = MagicMock() provider = ConsoleProvider(verbose=True) with ( patch.object(event_listener, "formatter", mock_formatter), patch("builtins.input", return_value="test input"), ): result = provider.request_input("Question?", MagicMock()) mock_formatter.pause_live_updates.assert_called_once() mock_formatter.resume_live_updates.assert_called_once() assert result == "test input" def test_console_provider_displays_message(self) -> None: """ConsoleProvider displays the message with Rich console.""" from crewai.events.event_listener import event_listener mock_formatter = MagicMock() mock_console = MagicMock() mock_formatter.console = mock_console provider = ConsoleProvider(verbose=True) with ( patch.object(event_listener, "formatter", mock_formatter), patch("builtins.input", return_value="answer"), ): provider.request_input("What topic?", MagicMock()) # Verify the message was printed print_calls = [str(c) for c in mock_console.print.call_args_list] assert any("What topic?" in c for c in print_calls) def test_console_provider_non_verbose(self) -> None: """ConsoleProvider in non-verbose mode uses plain input.""" from crewai.events.event_listener import event_listener mock_formatter = MagicMock() mock_formatter.console = MagicMock() provider = ConsoleProvider(verbose=False) with ( patch.object(event_listener, "formatter", mock_formatter), patch("builtins.input", return_value="plain answer") as mock_input, ): result = provider.request_input("Q?", MagicMock()) assert result == "plain answer" mock_input.assert_called_once_with("Q? ") def test_console_provider_strips_response(self) -> None: """ConsoleProvider strips whitespace from response.""" from crewai.events.event_listener import event_listener mock_formatter = MagicMock() mock_formatter.console = MagicMock() provider = ConsoleProvider(verbose=False) with ( patch.object(event_listener, "formatter", mock_formatter), patch("builtins.input", return_value=" spaced answer "), ): result = provider.request_input("Q?", MagicMock()) assert result == "spaced answer" def test_console_provider_implements_protocol(self) -> None: """ConsoleProvider satisfies the InputProvider protocol.""" provider = ConsoleProvider() assert isinstance(provider, InputProvider) # ── InputProvider Protocol ─────────────────────────────────────── class TestInputProviderProtocol: """Tests for the InputProvider protocol.""" def test_custom_provider_satisfies_protocol(self) -> None: """A class with request_input satisfies the InputProvider protocol.""" class MyProvider: def request_input(self, message: str, flow: Flow[Any]) -> str | None: return "custom" provider = MyProvider() assert isinstance(provider, InputProvider) def test_mock_provider_satisfies_protocol(self) -> None: """MockInputProvider satisfies the InputProvider protocol.""" provider = MockInputProvider(["test"]) assert isinstance(provider, InputProvider) # ── Error Handling ─────────────────────────────────────────────── class TestAskErrorHandling: """Tests for error handling in ask().""" def test_ask_returns_none_on_provider_error(self) -> None: """ask() returns None if provider raises an exception.""" class FailingProvider: def request_input(self, message: str, flow: Flow[Any]) -> str | None: raise RuntimeError("Provider failed") class TestFlow(Flow): input_provider = FailingProvider() @start() def my_method(self): return self.ask("Question?") flow = TestFlow() result = flow.kickoff() assert result is None def test_ask_in_async_method_returns_none_on_provider_error(self) -> None: """ask() returns None if provider raises in an async method.""" class FailingProvider: def request_input(self, message: str, flow: Flow[Any]) -> str | None: raise RuntimeError("Provider failed") class TestFlow(Flow): input_provider = FailingProvider() @start() async def my_method(self): return self.ask("Question?") flow = TestFlow() result = flow.kickoff() assert result is None # ── Metadata ───────────────────────────────────────────────────── class TestAskMetadata: """Tests for bidirectional metadata support in ask().""" def test_ask_passes_metadata_to_provider(self) -> None: """Provider receives the metadata dict from ask().""" provider = MockInputProvider(["answer"]) class TestFlow(Flow): input_provider = provider @start() def my_method(self): return self.ask("Q?", metadata={"user_id": "u123"}) flow = TestFlow() flow.kickoff() assert provider.received_metadata == [{"user_id": "u123"}] def test_ask_metadata_none_by_default(self) -> None: """Provider receives None metadata when not provided.""" provider = MockInputProvider(["answer"]) class TestFlow(Flow): input_provider = provider @start() def my_method(self): return self.ask("Q?") flow = TestFlow() flow.kickoff() assert provider.received_metadata == [None] def test_ask_provider_returns_input_response(self) -> None: """Provider returns InputResponse with response metadata.""" class MetadataProvider: def request_input( self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None ) -> InputResponse: return InputResponse( text="the answer", metadata={"responded_by": "u456", "thread_id": "t789"}, ) class TestFlow(Flow): input_provider = MetadataProvider() @start() def my_method(self): return self.ask("Q?", metadata={"user_id": "u123"}) flow = TestFlow() result = flow.kickoff() # ask() still returns plain string assert result == "the answer" # History has both metadata dicts assert len(flow._input_history) == 1 entry = flow._input_history[0] assert entry["metadata"] == {"user_id": "u123"} assert entry["response_metadata"] == {"responded_by": "u456", "thread_id": "t789"} def test_ask_provider_returns_string_with_metadata_sent(self) -> None: """Provider returns plain string; history has metadata but no response_metadata.""" class TestFlow(Flow): input_provider = MockInputProvider(["answer"]) @start() def my_method(self): return self.ask("Q?", metadata={"channel": "#research"}) flow = TestFlow() flow.kickoff() entry = flow._input_history[0] assert entry["metadata"] == {"channel": "#research"} assert entry["response_metadata"] is None def test_ask_metadata_in_requested_event(self) -> None: """FlowInputRequestedEvent carries metadata.""" from crewai.events.event_bus import crewai_event_bus from crewai.events.types.flow_events import FlowInputRequestedEvent events_captured: list[FlowInputRequestedEvent] = [] class TestFlow(Flow): input_provider = MockInputProvider(["answer"]) @start() def my_method(self): return self.ask("Q?", metadata={"user_id": "u123"}) flow = TestFlow() original_emit = crewai_event_bus.emit def capture_emit(source: Any, event: Any) -> Any: if isinstance(event, FlowInputRequestedEvent): events_captured.append(event) return original_emit(source, event) with patch.object(crewai_event_bus, "emit", side_effect=capture_emit): flow.kickoff() assert len(events_captured) == 1 assert events_captured[0].metadata == {"user_id": "u123"} def test_ask_metadata_in_received_event(self) -> None: """FlowInputReceivedEvent carries both metadata and response_metadata.""" from crewai.events.event_bus import crewai_event_bus from crewai.events.types.flow_events import FlowInputReceivedEvent events_captured: list[FlowInputReceivedEvent] = [] class MetadataProvider: def request_input( self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None ) -> InputResponse: return InputResponse(text="answer", metadata={"responded_by": "u456"}) class TestFlow(Flow): input_provider = MetadataProvider() @start() def my_method(self): return self.ask("Q?", metadata={"user_id": "u123"}) flow = TestFlow() original_emit = crewai_event_bus.emit def capture_emit(source: Any, event: Any) -> Any: if isinstance(event, FlowInputReceivedEvent): events_captured.append(event) return original_emit(source, event) with patch.object(crewai_event_bus, "emit", side_effect=capture_emit): flow.kickoff() assert len(events_captured) == 1 assert events_captured[0].metadata == {"user_id": "u123"} assert events_captured[0].response_metadata == {"responded_by": "u456"} assert events_captured[0].response == "answer" def test_ask_input_response_with_none_text(self) -> None: """Provider returns InputResponse with text=None.""" class NoneTextProvider: def request_input( self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None ) -> InputResponse: return InputResponse(text=None, metadata={"reason": "user_declined"}) class TestFlow(Flow): input_provider = NoneTextProvider() @start() def my_method(self): return self.ask("Q?") flow = TestFlow() result = flow.kickoff() assert result is None entry = flow._input_history[0] assert entry["response"] is None assert entry["response_metadata"] == {"reason": "user_declined"} def test_ask_metadata_thread_safe(self) -> None: """Concurrent ask() calls with different metadata don't cross-contaminate.""" import threading call_log: list[dict[str, Any]] = [] log_lock = threading.Lock() class TrackingProvider: def request_input( self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None ) -> InputResponse: # Small delay to increase chance of interleaving time.sleep(0.05) with log_lock: call_log.append({"message": message, "metadata": metadata}) user = metadata.get("user", "unknown") if metadata else "unknown" return InputResponse( text=f"answer from {user}", metadata={"responded_by": user}, ) class TestFlow(Flow): input_provider = TrackingProvider() @start() def trigger(self): return "go" @listen("trigger") def listener_a(self): return self.ask("Question A?", metadata={"user": "alice"}) @listen("trigger") def listener_b(self): return self.ask("Question B?", metadata={"user": "bob"}) flow = TestFlow() flow.kickoff() # Both calls should have recorded their own metadata assert len(flow._input_history) == 2 alice_entry = next( (e for e in flow._input_history if e["metadata"] and e["metadata"].get("user") == "alice"), None, ) bob_entry = next( (e for e in flow._input_history if e["metadata"] and e["metadata"].get("user") == "bob"), None, ) assert alice_entry is not None assert alice_entry["response"] == "answer from alice" assert alice_entry["response_metadata"] == {"responded_by": "alice"} assert bob_entry is not None assert bob_entry["response"] == "answer from bob" assert bob_entry["response_metadata"] == {"responded_by": "bob"}