# Mirror of https://github.com/crewAIInc/crewAI.git
# (synced 2026-04-07 19:48:13 +00:00; original file: 1153 lines, 39 KiB, Python)
"""Tests for Flow.ask() user input method.

This module tests the ask() method on Flow, including basic usage,
timeout behavior, provider resolution, event emission, auto-checkpoint
durability, input history tracking, and integration with flow machinery,
as well as ConsoleProvider input, the InputProvider protocol, provider
error handling, and request/response metadata.
"""

from __future__ import annotations

import time
from datetime import datetime
from typing import Any
from unittest.mock import MagicMock, patch

from crewai.flow import Flow, flow_config, listen, start
from crewai.flow.async_feedback.providers import ConsoleProvider
from crewai.flow.flow import FlowState
from crewai.flow.input_provider import InputProvider, InputResponse


# ── Test helpers ─────────────────────────────────────────────────


class MockInputProvider:
    """Mock input provider that returns pre-configured responses.

    Responses are handed out in order, one per ``request_input`` call. Once
    the list is exhausted, every further call returns ``None``. Every prompt
    and metadata dict received is recorded so tests can assert on them.
    """

    def __init__(self, responses: list[str | None]) -> None:
        self.responses = responses
        self._call_count = 0
        self.messages: list[str] = []
        self.received_metadata: list[dict[str, Any] | None] = []

    def request_input(
        self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
    ) -> str | None:
        """Record the request and return the next scripted response.

        Args:
            message: Prompt text passed to ``Flow.ask()``.
            flow: The flow requesting input (unused here).
            metadata: Optional metadata passed to ``Flow.ask()``.

        Returns:
            The next configured response, or ``None`` once all configured
            responses have been consumed.
        """
        # Record before checking exhaustion so calls past the end are logged too.
        self.messages.append(message)
        self.received_metadata.append(metadata)
        if self._call_count >= len(self.responses):
            return None
        response = self.responses[self._call_count]
        self._call_count += 1
        return response


class SlowMockProvider:
    """Mock provider that delays before returning, for timeout tests.

    Blocks the calling thread for ``delay`` seconds, then returns
    ``response`` regardless of the prompt or metadata it was given.
    """

    def __init__(self, delay: float, response: str = "delayed") -> None:
        self.delay = delay
        self.response = response

    def request_input(
        self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
    ) -> str | None:
        """Sleep for ``self.delay`` seconds, then return ``self.response``."""
        time.sleep(self.delay)
        return self.response


# ── Basic Functionality ──────────────────────────────────────────


class TestAskBasic:
    """Tests for basic ask() functionality."""

    def test_ask_returns_user_input(self) -> None:
        """ask() returns the string from the input provider."""

        class TestFlow(Flow):
            input_provider = MockInputProvider(["hello"])

            @start()
            def my_method(self):
                return self.ask("Say something:")

        flow = TestFlow()
        result = flow.kickoff()
        assert result == "hello"

    def test_ask_in_async_method(self) -> None:
        """ask() works inside an async flow method."""

        class TestFlow(Flow):
            input_provider = MockInputProvider(["async hello"])

            @start()
            async def my_method(self):
                return self.ask("Say something:")

        flow = TestFlow()
        result = flow.kickoff()
        assert result == "async hello"

    def test_ask_in_start_method(self) -> None:
        """ask() works inside a @start() method, flow completes normally."""
        execution_log: list[str] = []

        class TestFlow(Flow):
            input_provider = MockInputProvider(["AI"])

            @start()
            def gather(self):
                topic = self.ask("Topic?")
                execution_log.append(f"got:{topic}")
                return topic

        flow = TestFlow()
        result = flow.kickoff()
        assert result == "AI"
        assert execution_log == ["got:AI"]

    def test_ask_in_listen_method(self) -> None:
        """ask() works inside a @listen() method."""

        class TestFlow(Flow):
            input_provider = MockInputProvider(["detailed"])

            @start()
            def step1(self):
                return "topic"

            @listen("step1")
            def step2(self):
                depth = self.ask("How deep?")
                return f"researching at {depth} level"

        flow = TestFlow()
        result = flow.kickoff()
        assert result == "researching at detailed level"

    def test_ask_multiple_calls(self) -> None:
        """Multiple ask() calls in one method return correct values in order."""

        class TestFlow(Flow):
            input_provider = MockInputProvider(["AI", "detailed", "english"])

            @start()
            def gather(self):
                topic = self.ask("Topic?")
                depth = self.ask("Depth?")
                lang = self.ask("Language?")
                return {"topic": topic, "depth": depth, "lang": lang}

        flow = TestFlow()
        result = flow.kickoff()
        assert result == {"topic": "AI", "depth": "detailed", "lang": "english"}

    def test_ask_conditional(self) -> None:
        """ask() called conditionally based on previous answer."""

        class TestFlow(Flow):
            input_provider = MockInputProvider(["AI", "LLMs"])

            @start()
            def gather(self):
                topic = self.ask("Topic?")
                if topic == "AI":
                    focus = self.ask("Specific area?")
                else:
                    focus = "general"
                return {"topic": topic, "focus": focus}

        flow = TestFlow()
        result = flow.kickoff()
        assert result == {"topic": "AI", "focus": "LLMs"}

    def test_ask_returns_empty_string_on_enter(self) -> None:
        """Empty string means user pressed Enter (intentional empty input)."""

        class TestFlow(Flow):
            input_provider = MockInputProvider([""])

            @start()
            def my_method(self):
                result = self.ask("Optional input:")
                return result

        flow = TestFlow()
        result = flow.kickoff()
        assert result == ""
        assert result is not None  # Explicitly not None


# ── Timeout ──────────────────────────────────────────────────────


class TestAskTimeout:
    """Tests for timeout behavior."""

    def test_ask_timeout_returns_none(self) -> None:
        """ask() returns None when timeout expires."""

        class TestFlow(Flow):
            # The provider's 5s delay far exceeds the 0.1s timeout below.
            input_provider = SlowMockProvider(delay=5.0)

            @start()
            def my_method(self):
                return self.ask("Question?", timeout=0.1)

        flow = TestFlow()
        result = flow.kickoff()
        assert result is None

    def test_ask_timeout_in_async_method(self) -> None:
        """ask() timeout works inside an async flow method."""

        class TestFlow(Flow):
            # The provider's 5s delay far exceeds the 0.1s timeout below.
            input_provider = SlowMockProvider(delay=5.0)

            @start()
            async def my_method(self):
                return self.ask("Question?", timeout=0.1)

        flow = TestFlow()
        result = flow.kickoff()
        assert result is None

    def test_ask_loop_with_timeout_termination(self) -> None:
        """while (msg := ask(...)) is not None pattern terminates on timeout."""
        messages_received: list[str] = []

        class TestFlow(Flow):
            # The trailing None stands in for a timeout and ends the loop.
            input_provider = MockInputProvider(["hello", "world", None])

            @start()
            def chat(self):
                while (msg := self.ask("You:")) is not None:
                    messages_received.append(msg)
                return len(messages_received)

        flow = TestFlow()
        result = flow.kickoff()
        assert result == 2
        assert messages_received == ["hello", "world"]

    def test_ask_no_timeout_waits_indefinitely(self) -> None:
        """ask() with no timeout blocks until provider returns."""

        class TestFlow(Flow):
            input_provider = MockInputProvider(["answer"])

            @start()
            def my_method(self):
                return self.ask("Question?")  # no timeout

        flow = TestFlow()
        result = flow.kickoff()
        assert result == "answer"


# ── Provider Resolution ──────────────────────────────────────────


class TestProviderResolution:
    """Tests for provider resolution priority chain."""

    def test_ask_uses_flow_level_provider(self) -> None:
        """Per-flow input_provider is used when set."""
        provider = MockInputProvider(["from flow"])

        class TestFlow(Flow):
            input_provider = provider

            @start()
            def my_method(self):
                return self.ask("Q?")

        flow = TestFlow()
        flow.kickoff()
        assert provider.messages == ["Q?"]

    def test_ask_uses_global_config_provider(self) -> None:
        """flow_config.input_provider is used as fallback."""
        provider = MockInputProvider(["from config"])

        # flow_config is process-global; always restore it for other tests.
        original = flow_config.input_provider
        try:
            flow_config.input_provider = provider

            class TestFlow(Flow):
                @start()
                def my_method(self):
                    return self.ask("Q?")

            flow = TestFlow()
            result = flow.kickoff()
            assert result == "from config"
            assert provider.messages == ["Q?"]
        finally:
            flow_config.input_provider = original

    def test_ask_defaults_to_console_provider(self) -> None:
        """When no provider configured, ConsoleProvider is used."""
        original = flow_config.input_provider
        try:
            flow_config.input_provider = None

            class TestFlow(Flow):
                # No input_provider set
                @start()
                def my_method(self):
                    return self.ask("Q?")

            # Resolve without kicking off, so no real console input is read.
            flow = TestFlow()
            resolved = flow._resolve_input_provider()
            assert isinstance(resolved, ConsoleProvider)
        finally:
            flow_config.input_provider = original

    def test_flow_provider_overrides_global(self) -> None:
        """Per-flow provider takes precedence over global config."""
        flow_provider = MockInputProvider(["from flow"])
        global_provider = MockInputProvider(["from global"])

        original = flow_config.input_provider
        try:
            flow_config.input_provider = global_provider

            class TestFlow(Flow):
                input_provider = flow_provider

                @start()
                def my_method(self):
                    return self.ask("Q?")

            flow = TestFlow()
            result = flow.kickoff()
            assert result == "from flow"
            assert flow_provider.messages == ["Q?"]
            assert global_provider.messages == []  # not called
        finally:
            flow_config.input_provider = original


# ── Events ───────────────────────────────────────────────────────


class TestAskEvents:
    """Tests for event emission during ask()."""

    @staticmethod
    def _kickoff_capturing(flow: Flow[Any], event_type: type) -> list[Any]:
        """Kick off ``flow`` and return the emitted events of ``event_type``.

        ``crewai_event_bus.emit`` is patched with a spy that records matching
        events and then delegates to the original ``emit``, so every event is
        still dispatched normally during the run.
        """
        from crewai.events.event_bus import crewai_event_bus

        captured: list[Any] = []
        original_emit = crewai_event_bus.emit

        def capture_emit(source: Any, event: Any) -> Any:
            if isinstance(event, event_type):
                captured.append(event)
            return original_emit(source, event)

        with patch.object(crewai_event_bus, "emit", side_effect=capture_emit):
            flow.kickoff()
        return captured

    def test_ask_emits_input_requested_event(self) -> None:
        """FlowInputRequestedEvent is emitted when ask() is called."""
        from crewai.events.types.flow_events import FlowInputRequestedEvent

        class TestFlow(Flow):
            input_provider = MockInputProvider(["answer"])

            @start()
            def my_method(self):
                return self.ask("What topic?")

        events_captured = self._kickoff_capturing(TestFlow(), FlowInputRequestedEvent)

        assert len(events_captured) == 1
        assert events_captured[0].message == "What topic?"
        assert events_captured[0].type == "flow_input_requested"

    def test_ask_emits_input_received_event(self) -> None:
        """FlowInputReceivedEvent is emitted after input is received."""
        from crewai.events.types.flow_events import FlowInputReceivedEvent

        class TestFlow(Flow):
            input_provider = MockInputProvider(["my answer"])

            @start()
            def my_method(self):
                return self.ask("Question?")

        events_captured = self._kickoff_capturing(TestFlow(), FlowInputReceivedEvent)

        assert len(events_captured) == 1
        assert events_captured[0].message == "Question?"
        assert events_captured[0].response == "my answer"
        assert events_captured[0].type == "flow_input_received"

    def test_ask_timeout_emits_received_with_none(self) -> None:
        """FlowInputReceivedEvent has response=None on timeout."""
        from crewai.events.types.flow_events import FlowInputReceivedEvent

        class TestFlow(Flow):
            # The provider's 5s delay far exceeds the 0.1s timeout below.
            input_provider = SlowMockProvider(delay=5.0)

            @start()
            def my_method(self):
                return self.ask("Question?", timeout=0.1)

        events_captured = self._kickoff_capturing(TestFlow(), FlowInputReceivedEvent)

        assert len(events_captured) == 1
        assert events_captured[0].response is None


# ── Auto-checkpoint (Durability) ─────────────────────────────────


class TestAskCheckpoint:
    """Tests for auto-checkpoint durability before ask() waits."""

    @staticmethod
    def _ask_checkpoint_calls(persistence: MagicMock) -> list[Any]:
        """Return the ``save_state`` calls made by ask()'s auto-checkpoint.

        A call matches when its method name is ``"_ask_checkpoint"``, whether
        that name was passed as the ``method_name`` keyword or as the second
        positional argument.
        """
        return [
            c
            for c in persistence.save_state.call_args_list
            if c.kwargs.get("method_name") == "_ask_checkpoint"
            or (len(c.args) >= 2 and c.args[1] == "_ask_checkpoint")
        ]

    def test_ask_checkpoints_state_before_waiting(self) -> None:
        """State is saved to persistence before waiting for input."""
        mock_persistence = MagicMock()
        mock_persistence.load_state.return_value = None

        class TestFlow(Flow):
            input_provider = MockInputProvider(["answer"])

            @start()
            def my_method(self):
                self.state["important"] = "data"
                return self.ask("Question?")

        flow = TestFlow(persistence=mock_persistence)
        flow.kickoff()

        assert len(self._ask_checkpoint_calls(mock_persistence)) >= 1

    def test_ask_no_checkpoint_without_persistence(self) -> None:
        """No error when persistence is not configured."""

        class TestFlow(Flow):
            input_provider = MockInputProvider(["answer"])

            @start()
            def my_method(self):
                return self.ask("Question?")

        flow = TestFlow()  # No persistence
        result = flow.kickoff()
        assert result == "answer"  # Works fine without persistence

    def test_state_recoverable_after_checkpoint(self) -> None:
        """State set before ask() is checkpointed and recoverable.

        The auto-checkpoint happens *before* the provider is called, so
        state values set prior to ask() are persisted. This means if the
        server crashes while waiting for input, previously gathered data
        is safe.
        """
        mock_persistence = MagicMock()
        mock_persistence.load_state.return_value = None

        class GatherFlow(Flow):
            input_provider = MockInputProvider(["AI", "detailed"])

            @start()
            def gather(self):
                # First ask: nothing in state yet
                topic = self.ask("Topic?")
                self.state["topic"] = topic
                # Second ask: state now has topic, checkpoint saves it
                depth = self.ask("Depth?")
                self.state["depth"] = depth
                return {"topic": topic, "depth": depth}

        flow = GatherFlow(persistence=mock_persistence)
        result = flow.kickoff()
        assert result == {"topic": "AI", "depth": "detailed"}

        checkpoint_calls = self._ask_checkpoint_calls(mock_persistence)
        assert len(checkpoint_calls) == 2

        # The second checkpoint (before asking "Depth?") should have topic
        second_checkpoint = checkpoint_calls[1]
        # state_data is passed either as a keyword or as the third positional
        # argument. Test for the keyword's presence rather than its
        # truthiness, so an empty state dict is not misread as a positional
        # call (which would raise IndexError instead of a clear failure).
        if "state_data" in second_checkpoint.kwargs:
            state_data = second_checkpoint.kwargs["state_data"]
        else:
            state_data = second_checkpoint.args[2]
        assert state_data.get("topic") == "AI"


# ── Input History ────────────────────────────────────────────────


class TestInputHistory:
    """Tests for _input_history tracking."""

    def test_input_history_accumulated(self) -> None:
        """_input_history tracks all ask/response pairs."""

        class TestFlow(Flow):
            input_provider = MockInputProvider(["AI", "detailed"])

            @start()
            def gather(self):
                self.ask("Topic?")
                self.ask("Depth?")
                return "done"

        flow = TestFlow()
        flow.kickoff()

        assert len(flow._input_history) == 2
        assert flow._input_history[0]["message"] == "Topic?"
        assert flow._input_history[0]["response"] == "AI"
        assert flow._input_history[1]["message"] == "Depth?"
        assert flow._input_history[1]["response"] == "detailed"

    def test_input_history_includes_method_name(self) -> None:
        """Input history records which method called ask()."""

        class TestFlow(Flow):
            input_provider = MockInputProvider(["AI"])

            @start()
            def gather_info(self):
                self.ask("Topic?")
                return "done"

        flow = TestFlow()
        flow.kickoff()

        assert len(flow._input_history) == 1
        assert flow._input_history[0]["method_name"] == "gather_info"

    def test_input_history_includes_timestamp(self) -> None:
        """Input history records timestamps."""

        class TestFlow(Flow):
            input_provider = MockInputProvider(["AI"])

            @start()
            def my_method(self):
                self.ask("Topic?")
                return "done"

        flow = TestFlow()
        # Bracket the run so the recorded timestamp must fall inside it.
        before = datetime.now()
        flow.kickoff()
        after = datetime.now()

        assert len(flow._input_history) == 1
        ts = flow._input_history[0]["timestamp"]
        assert isinstance(ts, datetime)
        assert before <= ts <= after

    def test_input_history_records_none_on_timeout(self) -> None:
        """Input history records None response on timeout."""

        class TestFlow(Flow):
            # The provider's 5s delay far exceeds the 0.1s timeout below.
            input_provider = SlowMockProvider(delay=5.0)

            @start()
            def my_method(self):
                self.ask("Question?", timeout=0.1)
                return "done"

        flow = TestFlow()
        flow.kickoff()

        assert len(flow._input_history) == 1
        assert flow._input_history[0]["response"] is None


# ── Integration ──────────────────────────────────────────────────


class TestAskIntegration:
    """Integration tests for ask() with other flow features."""

    def test_ask_works_with_listen_chain(self) -> None:
        """ask() in a start method, result flows to listener."""
        execution_log: list[str] = []

        class TestFlow(Flow):
            input_provider = MockInputProvider(["AI agents"])

            @start()
            def gather(self):
                topic = self.ask("Topic?")
                execution_log.append(f"gathered:{topic}")
                return topic

            @listen("gather")
            def process(self):
                execution_log.append("processing")
                return "processed"

        flow = TestFlow()
        flow.kickoff()
        assert "gathered:AI agents" in execution_log
        assert "processing" in execution_log

    def test_ask_with_structured_state(self) -> None:
        """ask() works with Pydantic-based flow state."""

        class ResearchState(FlowState):
            topic: str = ""
            depth: str = ""

        class TestFlow(Flow[ResearchState]):
            initial_state = ResearchState
            input_provider = MockInputProvider(["AI", "detailed"])

            @start()
            def gather(self):
                self.state.topic = self.ask("Topic?")
                self.state.depth = self.ask("Depth?")
                return {"topic": self.state.topic, "depth": self.state.depth}

        flow = TestFlow()
        result = flow.kickoff()
        assert result == {"topic": "AI", "depth": "detailed"}
        assert flow.state.topic == "AI"
        assert flow.state.depth == "detailed"

    def test_ask_in_async_method_with_listen_chain(self) -> None:
        """ask() in an async start method, result flows to listener."""
        execution_log: list[str] = []

        class TestFlow(Flow):
            input_provider = MockInputProvider(["async topic"])

            @start()
            async def gather(self):
                topic = self.ask("Topic?")
                execution_log.append(f"gathered:{topic}")
                return topic

            @listen("gather")
            def process(self):
                execution_log.append("processing")
                return "processed"

        flow = TestFlow()
        flow.kickoff()
        assert "gathered:async topic" in execution_log
        assert "processing" in execution_log

    def test_ask_with_state_persistence_recovery(self) -> None:
        """Ask checkpoints state so previously gathered values survive."""
        mock_persistence = MagicMock()
        mock_persistence.load_state.return_value = None

        class RecoverableFlow(Flow):
            input_provider = MockInputProvider(["AI", "detailed"])

            @start()
            def gather(self):
                # Only ask for values not already present in (restored) state.
                if not self.state.get("topic"):
                    self.state["topic"] = self.ask("Topic?")
                if not self.state.get("depth"):
                    self.state["depth"] = self.ask("Depth?")
                return {
                    "topic": self.state["topic"],
                    "depth": self.state["depth"],
                }

        flow = RecoverableFlow(persistence=mock_persistence)
        result = flow.kickoff()
        assert result["topic"] == "AI"
        assert result["depth"] == "detailed"

        # Verify checkpoints were made
        checkpoint_calls = [
            c
            for c in mock_persistence.save_state.call_args_list
            if c.kwargs.get("method_name") == "_ask_checkpoint"
            or (len(c.args) >= 2 and c.args[1] == "_ask_checkpoint")
        ]
        # Two ask() calls = two checkpoints
        assert len(checkpoint_calls) == 2

    def test_ask_and_human_feedback_coexist(self) -> None:
        """ask() and @human_feedback can be used in the same flow."""
        from crewai.flow import human_feedback

        class TestFlow(Flow):
            input_provider = MockInputProvider(["AI"])

            @start()
            def gather(self):
                topic = self.ask("Topic?")
                # Store the answer so the reviewed step below can use it.
                self.state["topic"] = topic
                return topic

            @listen("gather")
            @human_feedback(message="Review this topic:")
            def review(self):
                return f"Researching: {self.state.get('topic', 'unknown')}"

        flow = TestFlow()

        with patch.object(flow, "_request_human_feedback", return_value="looks good"):
            flow.kickoff()

        # Flow completed with both ask and human_feedback
        assert flow.state["topic"] == "AI"
        assert flow.last_human_feedback is not None

    def test_ask_preserves_flow_lifecycle(self) -> None:
        """Flow events (started, finished) still fire normally with ask()."""
        from crewai.events.event_bus import crewai_event_bus
        from crewai.events.types.flow_events import (
            FlowFinishedEvent,
            FlowStartedEvent,
        )

        events_seen: list[str] = []

        class TestFlow(Flow):
            input_provider = MockInputProvider(["answer"])

            @start()
            def my_method(self):
                return self.ask("Q?")

        flow = TestFlow()

        original_emit = crewai_event_bus.emit

        # Spy on emit: record lifecycle events, then dispatch as normal.
        def capture_emit(source: Any, event: Any) -> Any:
            if isinstance(event, FlowStartedEvent):
                events_seen.append("started")
            elif isinstance(event, FlowFinishedEvent):
                events_seen.append("finished")
            return original_emit(source, event)

        with patch.object(crewai_event_bus, "emit", side_effect=capture_emit):
            flow.kickoff()

        assert "started" in events_seen
        assert "finished" in events_seen


# ── Console Provider ─────────────────────────────────────────────


class TestConsoleProviderInput:
    """Tests for ConsoleProvider.request_input() (used by Flow.ask())."""

    def test_console_provider_pauses_live_updates(self) -> None:
        """ConsoleProvider pauses and resumes formatter live updates."""
        from crewai.events.event_listener import event_listener

        mock_formatter = MagicMock()
        mock_formatter.console = MagicMock()

        provider = ConsoleProvider(verbose=True)

        with (
            patch.object(event_listener, "formatter", mock_formatter),
            patch("builtins.input", return_value="test input"),
        ):
            result = provider.request_input("Question?", MagicMock())

        mock_formatter.pause_live_updates.assert_called_once()
        mock_formatter.resume_live_updates.assert_called_once()
        assert result == "test input"

    def test_console_provider_displays_message(self) -> None:
        """ConsoleProvider displays the message with Rich console."""
        from crewai.events.event_listener import event_listener

        mock_formatter = MagicMock()
        mock_console = MagicMock()
        mock_formatter.console = mock_console

        provider = ConsoleProvider(verbose=True)

        with (
            patch.object(event_listener, "formatter", mock_formatter),
            patch("builtins.input", return_value="answer"),
        ):
            provider.request_input("What topic?", MagicMock())

        # Verify the message was printed
        print_calls = [str(c) for c in mock_console.print.call_args_list]
        assert any("What topic?" in c for c in print_calls)

    def test_console_provider_non_verbose(self) -> None:
        """ConsoleProvider in non-verbose mode uses plain input."""
        from crewai.events.event_listener import event_listener

        mock_formatter = MagicMock()
        mock_formatter.console = MagicMock()

        provider = ConsoleProvider(verbose=False)

        with (
            patch.object(event_listener, "formatter", mock_formatter),
            patch("builtins.input", return_value="plain answer") as mock_input,
        ):
            result = provider.request_input("Q?", MagicMock())

        assert result == "plain answer"
        mock_input.assert_called_once_with("Q? ")

    def test_console_provider_strips_response(self) -> None:
        """ConsoleProvider strips whitespace from response."""
        from crewai.events.event_listener import event_listener

        mock_formatter = MagicMock()
        mock_formatter.console = MagicMock()

        provider = ConsoleProvider(verbose=False)

        with (
            patch.object(event_listener, "formatter", mock_formatter),
            patch("builtins.input", return_value="  spaced answer  "),
        ):
            result = provider.request_input("Q?", MagicMock())

        assert result == "spaced answer"

    def test_console_provider_implements_protocol(self) -> None:
        """ConsoleProvider satisfies the InputProvider protocol."""
        provider = ConsoleProvider()
        assert isinstance(provider, InputProvider)


# ── InputProvider Protocol ───────────────────────────────────────


class TestInputProviderProtocol:
    """Tests for the InputProvider protocol."""

    def test_custom_provider_satisfies_protocol(self) -> None:
        """A class with request_input satisfies the InputProvider protocol."""

        # Note: this provider omits the optional ``metadata`` parameter.
        # isinstance() against a runtime-checkable Protocol only checks that
        # the method exists, not its signature.
        class MyProvider:
            def request_input(self, message: str, flow: Flow[Any]) -> str | None:
                return "custom"

        provider = MyProvider()
        assert isinstance(provider, InputProvider)

    def test_mock_provider_satisfies_protocol(self) -> None:
        """MockInputProvider satisfies the InputProvider protocol."""
        provider = MockInputProvider(["test"])
        assert isinstance(provider, InputProvider)


# ── Error Handling ───────────────────────────────────────────────


class TestAskErrorHandling:
    """Tests for error handling in ask()."""

    def test_ask_returns_none_on_provider_error(self) -> None:
        """ask() returns None if provider raises an exception."""

        class FailingProvider:
            # Accept ``metadata`` like a real provider, so the test exercises
            # the RuntimeError below regardless of whether ask() forwards
            # metadata. With the narrower signature, a TypeError from the
            # call itself could satisfy the assertion instead.
            def request_input(
                self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
            ) -> str | None:
                raise RuntimeError("Provider failed")

        class TestFlow(Flow):
            input_provider = FailingProvider()

            @start()
            def my_method(self):
                return self.ask("Question?")

        flow = TestFlow()
        result = flow.kickoff()
        assert result is None

    def test_ask_in_async_method_returns_none_on_provider_error(self) -> None:
        """ask() returns None if provider raises in an async method."""

        class FailingProvider:
            # See test_ask_returns_none_on_provider_error for why ``metadata``
            # is accepted here.
            def request_input(
                self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
            ) -> str | None:
                raise RuntimeError("Provider failed")

        class TestFlow(Flow):
            input_provider = FailingProvider()

            @start()
            async def my_method(self):
                return self.ask("Question?")

        flow = TestFlow()
        result = flow.kickoff()
        assert result is None


# ── Metadata ─────────────────────────────────────────────────────


class TestAskMetadata:
    """Tests for bidirectional metadata support in ask()."""

    @staticmethod
    def _kickoff_capturing(flow: Flow[Any], event_type: type) -> list[Any]:
        """Kick off ``flow`` and return the emitted events of ``event_type``.

        ``crewai_event_bus.emit`` is patched with a spy that records matching
        events and then delegates to the original ``emit``, so every event is
        still dispatched normally during the run.
        """
        from crewai.events.event_bus import crewai_event_bus

        captured: list[Any] = []
        original_emit = crewai_event_bus.emit

        def capture_emit(source: Any, event: Any) -> Any:
            if isinstance(event, event_type):
                captured.append(event)
            return original_emit(source, event)

        with patch.object(crewai_event_bus, "emit", side_effect=capture_emit):
            flow.kickoff()
        return captured

    def test_ask_passes_metadata_to_provider(self) -> None:
        """Provider receives the metadata dict from ask()."""
        provider = MockInputProvider(["answer"])

        class TestFlow(Flow):
            input_provider = provider

            @start()
            def my_method(self):
                return self.ask("Q?", metadata={"user_id": "u123"})

        flow = TestFlow()
        flow.kickoff()
        assert provider.received_metadata == [{"user_id": "u123"}]

    def test_ask_metadata_none_by_default(self) -> None:
        """Provider receives None metadata when not provided."""
        provider = MockInputProvider(["answer"])

        class TestFlow(Flow):
            input_provider = provider

            @start()
            def my_method(self):
                return self.ask("Q?")

        flow = TestFlow()
        flow.kickoff()
        assert provider.received_metadata == [None]

    def test_ask_provider_returns_input_response(self) -> None:
        """Provider returns InputResponse with response metadata."""

        class MetadataProvider:
            def request_input(
                self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
            ) -> InputResponse:
                return InputResponse(
                    text="the answer",
                    metadata={"responded_by": "u456", "thread_id": "t789"},
                )

        class TestFlow(Flow):
            input_provider = MetadataProvider()

            @start()
            def my_method(self):
                return self.ask("Q?", metadata={"user_id": "u123"})

        flow = TestFlow()
        result = flow.kickoff()

        # ask() still returns plain string
        assert result == "the answer"

        # History has both metadata dicts
        assert len(flow._input_history) == 1
        entry = flow._input_history[0]
        assert entry["metadata"] == {"user_id": "u123"}
        assert entry["response_metadata"] == {"responded_by": "u456", "thread_id": "t789"}

    def test_ask_provider_returns_string_with_metadata_sent(self) -> None:
        """Provider returns plain string; history has metadata but no response_metadata."""

        class TestFlow(Flow):
            input_provider = MockInputProvider(["answer"])

            @start()
            def my_method(self):
                return self.ask("Q?", metadata={"channel": "#research"})

        flow = TestFlow()
        flow.kickoff()

        entry = flow._input_history[0]
        assert entry["metadata"] == {"channel": "#research"}
        assert entry["response_metadata"] is None

    def test_ask_metadata_in_requested_event(self) -> None:
        """FlowInputRequestedEvent carries metadata."""
        from crewai.events.types.flow_events import FlowInputRequestedEvent

        class TestFlow(Flow):
            input_provider = MockInputProvider(["answer"])

            @start()
            def my_method(self):
                return self.ask("Q?", metadata={"user_id": "u123"})

        events_captured = self._kickoff_capturing(TestFlow(), FlowInputRequestedEvent)

        assert len(events_captured) == 1
        assert events_captured[0].metadata == {"user_id": "u123"}

    def test_ask_metadata_in_received_event(self) -> None:
        """FlowInputReceivedEvent carries both metadata and response_metadata."""
        from crewai.events.types.flow_events import FlowInputReceivedEvent

        class MetadataProvider:
            def request_input(
                self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
            ) -> InputResponse:
                return InputResponse(text="answer", metadata={"responded_by": "u456"})

        class TestFlow(Flow):
            input_provider = MetadataProvider()

            @start()
            def my_method(self):
                return self.ask("Q?", metadata={"user_id": "u123"})

        events_captured = self._kickoff_capturing(TestFlow(), FlowInputReceivedEvent)

        assert len(events_captured) == 1
        assert events_captured[0].metadata == {"user_id": "u123"}
        assert events_captured[0].response_metadata == {"responded_by": "u456"}
        assert events_captured[0].response == "answer"

    def test_ask_input_response_with_none_text(self) -> None:
        """Provider returns InputResponse with text=None."""

        class NoneTextProvider:
            def request_input(
                self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
            ) -> InputResponse:
                return InputResponse(text=None, metadata={"reason": "user_declined"})

        class TestFlow(Flow):
            input_provider = NoneTextProvider()

            @start()
            def my_method(self):
                return self.ask("Q?")

        flow = TestFlow()
        result = flow.kickoff()
        assert result is None

        entry = flow._input_history[0]
        assert entry["response"] is None
        assert entry["response_metadata"] == {"reason": "user_declined"}

    def test_ask_metadata_thread_safe(self) -> None:
        """Concurrent ask() calls with different metadata don't cross-contaminate."""
        import threading

        call_log: list[dict[str, Any]] = []
        log_lock = threading.Lock()

        class TrackingProvider:
            def request_input(
                self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
            ) -> InputResponse:
                # Small delay to increase chance of interleaving
                time.sleep(0.05)
                with log_lock:
                    call_log.append({"message": message, "metadata": metadata})
                user = metadata.get("user", "unknown") if metadata else "unknown"
                return InputResponse(
                    text=f"answer from {user}",
                    metadata={"responded_by": user},
                )

        class TestFlow(Flow):
            input_provider = TrackingProvider()

            @start()
            def trigger(self):
                return "go"

            @listen("trigger")
            def listener_a(self):
                return self.ask("Question A?", metadata={"user": "alice"})

            @listen("trigger")
            def listener_b(self):
                return self.ask("Question B?", metadata={"user": "bob"})

        flow = TestFlow()
        flow.kickoff()

        # Both calls should have recorded their own metadata
        assert len(flow._input_history) == 2

        alice_entry = next(
            (e for e in flow._input_history if e["metadata"] and e["metadata"].get("user") == "alice"),
            None,
        )
        bob_entry = next(
            (e for e in flow._input_history if e["metadata"] and e["metadata"].get("user") == "bob"),
            None,
        )

        assert alice_entry is not None
        assert alice_entry["response"] == "answer from alice"
        assert alice_entry["response_metadata"] == {"responded_by": "alice"}

        assert bob_entry is not None
        assert bob_entry["response"] == "answer from bob"
        assert bob_entry["response_metadata"] == {"responded_by": "bob"}