mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 21:58:11 +00:00
- Removed unnecessary variable assignments for the result of `flow.kickoff()` in two test cases, improving code clarity. - Updated assertions to ensure the expected execution log entries are present after the flow kickoff, enhancing test reliability.
1153 lines
39 KiB
Python
1153 lines
39 KiB
Python
"""Tests for Flow.ask() user input method.
|
|
|
|
This module tests the ask() method on Flow, including basic usage,
|
|
timeout behavior, provider resolution, event emission, auto-checkpoint
|
|
durability, input history tracking, and integration with flow machinery.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
from datetime import datetime
|
|
from typing import Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from crewai.flow import Flow, flow_config, listen, start
|
|
from crewai.flow.async_feedback.providers import ConsoleProvider
|
|
from crewai.flow.flow import FlowState
|
|
from crewai.flow.input_provider import InputProvider, InputResponse
|
|
|
|
|
|
# ── Test helpers ─────────────────────────────────────────────────
|
|
|
|
|
|
class MockInputProvider:
|
|
"""Mock input provider that returns pre-configured responses."""
|
|
|
|
def __init__(self, responses: list[str | None]) -> None:
|
|
self.responses = responses
|
|
self._call_count = 0
|
|
self.messages: list[str] = []
|
|
self.received_metadata: list[dict[str, Any] | None] = []
|
|
|
|
def request_input(
|
|
self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
|
|
) -> str | None:
|
|
self.messages.append(message)
|
|
self.received_metadata.append(metadata)
|
|
if self._call_count >= len(self.responses):
|
|
return None
|
|
response = self.responses[self._call_count]
|
|
self._call_count += 1
|
|
return response
|
|
|
|
|
|
class SlowMockProvider:
|
|
"""Mock provider that delays before returning, for timeout tests."""
|
|
|
|
def __init__(self, delay: float, response: str = "delayed") -> None:
|
|
self.delay = delay
|
|
self.response = response
|
|
|
|
def request_input(
|
|
self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
|
|
) -> str | None:
|
|
time.sleep(self.delay)
|
|
return self.response
|
|
|
|
|
|
# ── Basic Functionality ──────────────────────────────────────────
|
|
|
|
|
|
class TestAskBasic:
|
|
"""Tests for basic ask() functionality."""
|
|
|
|
def test_ask_returns_user_input(self) -> None:
|
|
"""ask() returns the string from the input provider."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["hello"])
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Say something:")
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result == "hello"
|
|
|
|
def test_ask_in_async_method(self) -> None:
|
|
"""ask() works inside an async flow method."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["async hello"])
|
|
|
|
@start()
|
|
async def my_method(self):
|
|
return self.ask("Say something:")
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result == "async hello"
|
|
|
|
def test_ask_in_start_method(self) -> None:
|
|
"""ask() works inside a @start() method, flow completes normally."""
|
|
execution_log: list[str] = []
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["AI"])
|
|
|
|
@start()
|
|
def gather(self):
|
|
topic = self.ask("Topic?")
|
|
execution_log.append(f"got:{topic}")
|
|
return topic
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result == "AI"
|
|
assert execution_log == ["got:AI"]
|
|
|
|
def test_ask_in_listen_method(self) -> None:
|
|
"""ask() works inside a @listen() method."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["detailed"])
|
|
|
|
@start()
|
|
def step1(self):
|
|
return "topic"
|
|
|
|
@listen("step1")
|
|
def step2(self):
|
|
depth = self.ask("How deep?")
|
|
return f"researching at {depth} level"
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result == "researching at detailed level"
|
|
|
|
def test_ask_multiple_calls(self) -> None:
|
|
"""Multiple ask() calls in one method return correct values in order."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["AI", "detailed", "english"])
|
|
|
|
@start()
|
|
def gather(self):
|
|
topic = self.ask("Topic?")
|
|
depth = self.ask("Depth?")
|
|
lang = self.ask("Language?")
|
|
return {"topic": topic, "depth": depth, "lang": lang}
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result == {"topic": "AI", "depth": "detailed", "lang": "english"}
|
|
|
|
def test_ask_conditional(self) -> None:
|
|
"""ask() called conditionally based on previous answer."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["AI", "LLMs"])
|
|
|
|
@start()
|
|
def gather(self):
|
|
topic = self.ask("Topic?")
|
|
if topic == "AI":
|
|
focus = self.ask("Specific area?")
|
|
else:
|
|
focus = "general"
|
|
return {"topic": topic, "focus": focus}
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result == {"topic": "AI", "focus": "LLMs"}
|
|
|
|
def test_ask_returns_empty_string_on_enter(self) -> None:
|
|
"""Empty string means user pressed Enter (intentional empty input)."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider([""])
|
|
|
|
@start()
|
|
def my_method(self):
|
|
result = self.ask("Optional input:")
|
|
return result
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result == ""
|
|
assert result is not None # Explicitly not None
|
|
|
|
|
|
# ── Timeout ──────────────────────────────────────────────────────
|
|
|
|
|
|
class TestAskTimeout:
|
|
"""Tests for timeout behavior."""
|
|
|
|
def test_ask_timeout_returns_none(self) -> None:
|
|
"""ask() returns None when timeout expires."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = SlowMockProvider(delay=5.0)
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Question?", timeout=0.1)
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result is None
|
|
|
|
def test_ask_timeout_in_async_method(self) -> None:
|
|
"""ask() timeout works inside an async flow method."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = SlowMockProvider(delay=5.0)
|
|
|
|
@start()
|
|
async def my_method(self):
|
|
return self.ask("Question?", timeout=0.1)
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result is None
|
|
|
|
def test_ask_loop_with_timeout_termination(self) -> None:
|
|
"""while (msg := ask(...)) is not None pattern terminates on timeout."""
|
|
messages_received: list[str] = []
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["hello", "world", None])
|
|
|
|
@start()
|
|
def chat(self):
|
|
while (msg := self.ask("You:")) is not None:
|
|
messages_received.append(msg)
|
|
return len(messages_received)
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result == 2
|
|
assert messages_received == ["hello", "world"]
|
|
|
|
def test_ask_no_timeout_waits_indefinitely(self) -> None:
|
|
"""ask() with no timeout blocks until provider returns."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["answer"])
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Question?") # no timeout
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result == "answer"
|
|
|
|
|
|
# ── Provider Resolution ──────────────────────────────────────────
|
|
|
|
|
|
class TestProviderResolution:
|
|
"""Tests for provider resolution priority chain."""
|
|
|
|
def test_ask_uses_flow_level_provider(self) -> None:
|
|
"""Per-flow input_provider is used when set."""
|
|
provider = MockInputProvider(["from flow"])
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = provider
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Q?")
|
|
|
|
flow = TestFlow()
|
|
flow.kickoff()
|
|
assert provider.messages == ["Q?"]
|
|
|
|
def test_ask_uses_global_config_provider(self) -> None:
|
|
"""flow_config.input_provider is used as fallback."""
|
|
provider = MockInputProvider(["from config"])
|
|
|
|
original = flow_config.input_provider
|
|
try:
|
|
flow_config.input_provider = provider
|
|
|
|
class TestFlow(Flow):
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Q?")
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result == "from config"
|
|
assert provider.messages == ["Q?"]
|
|
finally:
|
|
flow_config.input_provider = original
|
|
|
|
def test_ask_defaults_to_console_provider(self) -> None:
|
|
"""When no provider configured, ConsoleProvider is used."""
|
|
original = flow_config.input_provider
|
|
try:
|
|
flow_config.input_provider = None
|
|
|
|
class TestFlow(Flow):
|
|
# No input_provider set
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Q?")
|
|
|
|
flow = TestFlow()
|
|
resolved = flow._resolve_input_provider()
|
|
assert isinstance(resolved, ConsoleProvider)
|
|
finally:
|
|
flow_config.input_provider = original
|
|
|
|
def test_flow_provider_overrides_global(self) -> None:
|
|
"""Per-flow provider takes precedence over global config."""
|
|
flow_provider = MockInputProvider(["from flow"])
|
|
global_provider = MockInputProvider(["from global"])
|
|
|
|
original = flow_config.input_provider
|
|
try:
|
|
flow_config.input_provider = global_provider
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = flow_provider
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Q?")
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result == "from flow"
|
|
assert flow_provider.messages == ["Q?"]
|
|
assert global_provider.messages == [] # not called
|
|
finally:
|
|
flow_config.input_provider = original
|
|
|
|
|
|
# ── Events ───────────────────────────────────────────────────────
|
|
|
|
|
|
class TestAskEvents:
|
|
"""Tests for event emission during ask()."""
|
|
|
|
def test_ask_emits_input_requested_event(self) -> None:
|
|
"""FlowInputRequestedEvent is emitted when ask() is called."""
|
|
from crewai.events.event_bus import crewai_event_bus
|
|
from crewai.events.types.flow_events import FlowInputRequestedEvent
|
|
|
|
events_captured: list[FlowInputRequestedEvent] = []
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["answer"])
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("What topic?")
|
|
|
|
flow = TestFlow()
|
|
|
|
original_emit = crewai_event_bus.emit
|
|
|
|
def capture_emit(source: Any, event: Any) -> Any:
|
|
if isinstance(event, FlowInputRequestedEvent):
|
|
events_captured.append(event)
|
|
return original_emit(source, event)
|
|
|
|
with patch.object(crewai_event_bus, "emit", side_effect=capture_emit):
|
|
flow.kickoff()
|
|
|
|
assert len(events_captured) == 1
|
|
assert events_captured[0].message == "What topic?"
|
|
assert events_captured[0].type == "flow_input_requested"
|
|
|
|
def test_ask_emits_input_received_event(self) -> None:
|
|
"""FlowInputReceivedEvent is emitted after input is received."""
|
|
from crewai.events.event_bus import crewai_event_bus
|
|
from crewai.events.types.flow_events import FlowInputReceivedEvent
|
|
|
|
events_captured: list[FlowInputReceivedEvent] = []
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["my answer"])
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Question?")
|
|
|
|
flow = TestFlow()
|
|
|
|
original_emit = crewai_event_bus.emit
|
|
|
|
def capture_emit(source: Any, event: Any) -> Any:
|
|
if isinstance(event, FlowInputReceivedEvent):
|
|
events_captured.append(event)
|
|
return original_emit(source, event)
|
|
|
|
with patch.object(crewai_event_bus, "emit", side_effect=capture_emit):
|
|
flow.kickoff()
|
|
|
|
assert len(events_captured) == 1
|
|
assert events_captured[0].message == "Question?"
|
|
assert events_captured[0].response == "my answer"
|
|
assert events_captured[0].type == "flow_input_received"
|
|
|
|
def test_ask_timeout_emits_received_with_none(self) -> None:
|
|
"""FlowInputReceivedEvent has response=None on timeout."""
|
|
from crewai.events.event_bus import crewai_event_bus
|
|
from crewai.events.types.flow_events import FlowInputReceivedEvent
|
|
|
|
events_captured: list[FlowInputReceivedEvent] = []
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = SlowMockProvider(delay=5.0)
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Question?", timeout=0.1)
|
|
|
|
flow = TestFlow()
|
|
|
|
original_emit = crewai_event_bus.emit
|
|
|
|
def capture_emit(source: Any, event: Any) -> Any:
|
|
if isinstance(event, FlowInputReceivedEvent):
|
|
events_captured.append(event)
|
|
return original_emit(source, event)
|
|
|
|
with patch.object(crewai_event_bus, "emit", side_effect=capture_emit):
|
|
flow.kickoff()
|
|
|
|
assert len(events_captured) == 1
|
|
assert events_captured[0].response is None
|
|
|
|
|
|
# ── Auto-checkpoint (Durability) ─────────────────────────────────
|
|
|
|
|
|
class TestAskCheckpoint:
|
|
"""Tests for auto-checkpoint durability before ask() waits."""
|
|
|
|
def test_ask_checkpoints_state_before_waiting(self) -> None:
|
|
"""State is saved to persistence before waiting for input."""
|
|
mock_persistence = MagicMock()
|
|
mock_persistence.load_state.return_value = None
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["answer"])
|
|
|
|
@start()
|
|
def my_method(self):
|
|
self.state["important"] = "data"
|
|
return self.ask("Question?")
|
|
|
|
flow = TestFlow(persistence=mock_persistence)
|
|
flow.kickoff()
|
|
|
|
# Find the _ask_checkpoint call among save_state calls
|
|
checkpoint_calls = [
|
|
c for c in mock_persistence.save_state.call_args_list
|
|
if c.kwargs.get("method_name") == "_ask_checkpoint"
|
|
or (len(c.args) >= 2 and c.args[1] == "_ask_checkpoint")
|
|
]
|
|
assert len(checkpoint_calls) >= 1
|
|
|
|
def test_ask_no_checkpoint_without_persistence(self) -> None:
|
|
"""No error when persistence is not configured."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["answer"])
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Question?")
|
|
|
|
flow = TestFlow() # No persistence
|
|
result = flow.kickoff()
|
|
assert result == "answer" # Works fine without persistence
|
|
|
|
def test_state_recoverable_after_checkpoint(self) -> None:
|
|
"""State set before ask() is checkpointed and recoverable.
|
|
|
|
The auto-checkpoint happens *before* the provider is called, so
|
|
state values set prior to ask() are persisted. This means if the
|
|
server crashes while waiting for input, previously gathered data
|
|
is safe.
|
|
"""
|
|
mock_persistence = MagicMock()
|
|
mock_persistence.load_state.return_value = None
|
|
|
|
class GatherFlow(Flow):
|
|
input_provider = MockInputProvider(["AI", "detailed"])
|
|
|
|
@start()
|
|
def gather(self):
|
|
# First ask: nothing in state yet
|
|
topic = self.ask("Topic?")
|
|
self.state["topic"] = topic
|
|
# Second ask: state now has topic, checkpoint saves it
|
|
depth = self.ask("Depth?")
|
|
self.state["depth"] = depth
|
|
return {"topic": topic, "depth": depth}
|
|
|
|
flow = GatherFlow(persistence=mock_persistence)
|
|
result = flow.kickoff()
|
|
assert result == {"topic": "AI", "depth": "detailed"}
|
|
|
|
# Find the checkpoint calls
|
|
checkpoint_calls = [
|
|
c for c in mock_persistence.save_state.call_args_list
|
|
if c.kwargs.get("method_name") == "_ask_checkpoint"
|
|
or (len(c.args) >= 2 and c.args[1] == "_ask_checkpoint")
|
|
]
|
|
assert len(checkpoint_calls) == 2
|
|
|
|
# The second checkpoint (before asking "Depth?") should have topic
|
|
second_checkpoint = checkpoint_calls[1]
|
|
# state_data is the third positional arg or keyword arg
|
|
if second_checkpoint.kwargs.get("state_data"):
|
|
state_data = second_checkpoint.kwargs["state_data"]
|
|
else:
|
|
state_data = second_checkpoint.args[2]
|
|
assert state_data.get("topic") == "AI"
|
|
|
|
|
|
# ── Input History ────────────────────────────────────────────────
|
|
|
|
|
|
class TestInputHistory:
|
|
"""Tests for _input_history tracking."""
|
|
|
|
def test_input_history_accumulated(self) -> None:
|
|
"""_input_history tracks all ask/response pairs."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["AI", "detailed"])
|
|
|
|
@start()
|
|
def gather(self):
|
|
self.ask("Topic?")
|
|
self.ask("Depth?")
|
|
return "done"
|
|
|
|
flow = TestFlow()
|
|
flow.kickoff()
|
|
|
|
assert len(flow._input_history) == 2
|
|
assert flow._input_history[0]["message"] == "Topic?"
|
|
assert flow._input_history[0]["response"] == "AI"
|
|
assert flow._input_history[1]["message"] == "Depth?"
|
|
assert flow._input_history[1]["response"] == "detailed"
|
|
|
|
def test_input_history_includes_method_name(self) -> None:
|
|
"""Input history records which method called ask()."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["AI"])
|
|
|
|
@start()
|
|
def gather_info(self):
|
|
self.ask("Topic?")
|
|
return "done"
|
|
|
|
flow = TestFlow()
|
|
flow.kickoff()
|
|
|
|
assert len(flow._input_history) == 1
|
|
assert flow._input_history[0]["method_name"] == "gather_info"
|
|
|
|
def test_input_history_includes_timestamp(self) -> None:
|
|
"""Input history records timestamps."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["AI"])
|
|
|
|
@start()
|
|
def my_method(self):
|
|
self.ask("Topic?")
|
|
return "done"
|
|
|
|
flow = TestFlow()
|
|
before = datetime.now()
|
|
flow.kickoff()
|
|
after = datetime.now()
|
|
|
|
assert len(flow._input_history) == 1
|
|
ts = flow._input_history[0]["timestamp"]
|
|
assert isinstance(ts, datetime)
|
|
assert before <= ts <= after
|
|
|
|
def test_input_history_records_none_on_timeout(self) -> None:
|
|
"""Input history records None response on timeout."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = SlowMockProvider(delay=5.0)
|
|
|
|
@start()
|
|
def my_method(self):
|
|
self.ask("Question?", timeout=0.1)
|
|
return "done"
|
|
|
|
flow = TestFlow()
|
|
flow.kickoff()
|
|
|
|
assert len(flow._input_history) == 1
|
|
assert flow._input_history[0]["response"] is None
|
|
|
|
|
|
# ── Integration ──────────────────────────────────────────────────
|
|
|
|
|
|
class TestAskIntegration:
|
|
"""Integration tests for ask() with other flow features."""
|
|
|
|
def test_ask_works_with_listen_chain(self) -> None:
|
|
"""ask() in a start method, result flows to listener."""
|
|
execution_log: list[str] = []
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["AI agents"])
|
|
|
|
@start()
|
|
def gather(self):
|
|
topic = self.ask("Topic?")
|
|
execution_log.append(f"gathered:{topic}")
|
|
return topic
|
|
|
|
@listen("gather")
|
|
def process(self):
|
|
execution_log.append("processing")
|
|
return "processed"
|
|
|
|
flow = TestFlow()
|
|
flow.kickoff()
|
|
assert "gathered:AI agents" in execution_log
|
|
assert "processing" in execution_log
|
|
|
|
def test_ask_with_structured_state(self) -> None:
|
|
"""ask() works with Pydantic-based flow state."""
|
|
|
|
class ResearchState(FlowState):
|
|
topic: str = ""
|
|
depth: str = ""
|
|
|
|
class TestFlow(Flow[ResearchState]):
|
|
initial_state = ResearchState
|
|
input_provider = MockInputProvider(["AI", "detailed"])
|
|
|
|
@start()
|
|
def gather(self):
|
|
self.state.topic = self.ask("Topic?")
|
|
self.state.depth = self.ask("Depth?")
|
|
return {"topic": self.state.topic, "depth": self.state.depth}
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result == {"topic": "AI", "depth": "detailed"}
|
|
assert flow.state.topic == "AI"
|
|
assert flow.state.depth == "detailed"
|
|
|
|
def test_ask_in_async_method_with_listen_chain(self) -> None:
|
|
"""ask() in an async start method, result flows to listener."""
|
|
execution_log: list[str] = []
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["async topic"])
|
|
|
|
@start()
|
|
async def gather(self):
|
|
topic = self.ask("Topic?")
|
|
execution_log.append(f"gathered:{topic}")
|
|
return topic
|
|
|
|
@listen("gather")
|
|
def process(self):
|
|
execution_log.append("processing")
|
|
return "processed"
|
|
|
|
flow = TestFlow()
|
|
flow.kickoff()
|
|
assert "gathered:async topic" in execution_log
|
|
assert "processing" in execution_log
|
|
|
|
def test_ask_with_state_persistence_recovery(self) -> None:
|
|
"""Ask checkpoints state so previously gathered values survive."""
|
|
mock_persistence = MagicMock()
|
|
mock_persistence.load_state.return_value = None
|
|
|
|
class RecoverableFlow(Flow):
|
|
input_provider = MockInputProvider(["AI", "detailed"])
|
|
|
|
@start()
|
|
def gather(self):
|
|
if not self.state.get("topic"):
|
|
self.state["topic"] = self.ask("Topic?")
|
|
if not self.state.get("depth"):
|
|
self.state["depth"] = self.ask("Depth?")
|
|
return {
|
|
"topic": self.state["topic"],
|
|
"depth": self.state["depth"],
|
|
}
|
|
|
|
flow = RecoverableFlow(persistence=mock_persistence)
|
|
result = flow.kickoff()
|
|
assert result["topic"] == "AI"
|
|
assert result["depth"] == "detailed"
|
|
|
|
# Verify checkpoints were made
|
|
checkpoint_calls = [
|
|
c for c in mock_persistence.save_state.call_args_list
|
|
if c.kwargs.get("method_name") == "_ask_checkpoint"
|
|
or (len(c.args) >= 2 and c.args[1] == "_ask_checkpoint")
|
|
]
|
|
# Two ask() calls = two checkpoints
|
|
assert len(checkpoint_calls) == 2
|
|
|
|
def test_ask_and_human_feedback_coexist(self) -> None:
|
|
"""ask() and @human_feedback can be used in the same flow."""
|
|
from crewai.flow import human_feedback
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["AI"])
|
|
|
|
@start()
|
|
def gather(self):
|
|
topic = self.ask("Topic?")
|
|
return topic
|
|
|
|
@listen("gather")
|
|
@human_feedback(message="Review this topic:")
|
|
def review(self):
|
|
return f"Researching: {self.state.get('_last_topic', 'unknown')}"
|
|
|
|
flow = TestFlow()
|
|
|
|
with patch.object(flow, "_request_human_feedback", return_value="looks good"):
|
|
flow.kickoff()
|
|
|
|
# Flow completed with both ask and human_feedback
|
|
assert flow.last_human_feedback is not None
|
|
|
|
def test_ask_preserves_flow_lifecycle(self) -> None:
|
|
"""Flow events (started, finished) still fire normally with ask()."""
|
|
from crewai.events.event_bus import crewai_event_bus
|
|
from crewai.events.types.flow_events import (
|
|
FlowFinishedEvent,
|
|
FlowStartedEvent,
|
|
)
|
|
|
|
events_seen: list[str] = []
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["answer"])
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Q?")
|
|
|
|
flow = TestFlow()
|
|
|
|
original_emit = crewai_event_bus.emit
|
|
|
|
def capture_emit(source: Any, event: Any) -> Any:
|
|
if isinstance(event, FlowStartedEvent):
|
|
events_seen.append("started")
|
|
elif isinstance(event, FlowFinishedEvent):
|
|
events_seen.append("finished")
|
|
return original_emit(source, event)
|
|
|
|
with patch.object(crewai_event_bus, "emit", side_effect=capture_emit):
|
|
flow.kickoff()
|
|
|
|
assert "started" in events_seen
|
|
assert "finished" in events_seen
|
|
|
|
|
|
# ── Console Provider ─────────────────────────────────────────────
|
|
|
|
|
|
class TestConsoleProviderInput:
|
|
"""Tests for ConsoleProvider.request_input() (used by Flow.ask())."""
|
|
|
|
def test_console_provider_pauses_live_updates(self) -> None:
|
|
"""ConsoleProvider pauses and resumes formatter live updates."""
|
|
from crewai.events.event_listener import event_listener
|
|
|
|
mock_formatter = MagicMock()
|
|
mock_formatter.console = MagicMock()
|
|
|
|
provider = ConsoleProvider(verbose=True)
|
|
|
|
with (
|
|
patch.object(event_listener, "formatter", mock_formatter),
|
|
patch("builtins.input", return_value="test input"),
|
|
):
|
|
result = provider.request_input("Question?", MagicMock())
|
|
|
|
mock_formatter.pause_live_updates.assert_called_once()
|
|
mock_formatter.resume_live_updates.assert_called_once()
|
|
assert result == "test input"
|
|
|
|
def test_console_provider_displays_message(self) -> None:
|
|
"""ConsoleProvider displays the message with Rich console."""
|
|
from crewai.events.event_listener import event_listener
|
|
|
|
mock_formatter = MagicMock()
|
|
mock_console = MagicMock()
|
|
mock_formatter.console = mock_console
|
|
|
|
provider = ConsoleProvider(verbose=True)
|
|
|
|
with (
|
|
patch.object(event_listener, "formatter", mock_formatter),
|
|
patch("builtins.input", return_value="answer"),
|
|
):
|
|
provider.request_input("What topic?", MagicMock())
|
|
|
|
# Verify the message was printed
|
|
print_calls = [str(c) for c in mock_console.print.call_args_list]
|
|
assert any("What topic?" in c for c in print_calls)
|
|
|
|
def test_console_provider_non_verbose(self) -> None:
|
|
"""ConsoleProvider in non-verbose mode uses plain input."""
|
|
from crewai.events.event_listener import event_listener
|
|
|
|
mock_formatter = MagicMock()
|
|
mock_formatter.console = MagicMock()
|
|
|
|
provider = ConsoleProvider(verbose=False)
|
|
|
|
with (
|
|
patch.object(event_listener, "formatter", mock_formatter),
|
|
patch("builtins.input", return_value="plain answer") as mock_input,
|
|
):
|
|
result = provider.request_input("Q?", MagicMock())
|
|
|
|
assert result == "plain answer"
|
|
mock_input.assert_called_once_with("Q? ")
|
|
|
|
def test_console_provider_strips_response(self) -> None:
|
|
"""ConsoleProvider strips whitespace from response."""
|
|
from crewai.events.event_listener import event_listener
|
|
|
|
mock_formatter = MagicMock()
|
|
mock_formatter.console = MagicMock()
|
|
|
|
provider = ConsoleProvider(verbose=False)
|
|
|
|
with (
|
|
patch.object(event_listener, "formatter", mock_formatter),
|
|
patch("builtins.input", return_value=" spaced answer "),
|
|
):
|
|
result = provider.request_input("Q?", MagicMock())
|
|
|
|
assert result == "spaced answer"
|
|
|
|
def test_console_provider_implements_protocol(self) -> None:
|
|
"""ConsoleProvider satisfies the InputProvider protocol."""
|
|
provider = ConsoleProvider()
|
|
assert isinstance(provider, InputProvider)
|
|
|
|
|
|
# ── InputProvider Protocol ───────────────────────────────────────
|
|
|
|
|
|
class TestInputProviderProtocol:
|
|
"""Tests for the InputProvider protocol."""
|
|
|
|
def test_custom_provider_satisfies_protocol(self) -> None:
|
|
"""A class with request_input satisfies the InputProvider protocol."""
|
|
|
|
class MyProvider:
|
|
def request_input(self, message: str, flow: Flow[Any]) -> str | None:
|
|
return "custom"
|
|
|
|
provider = MyProvider()
|
|
assert isinstance(provider, InputProvider)
|
|
|
|
def test_mock_provider_satisfies_protocol(self) -> None:
|
|
"""MockInputProvider satisfies the InputProvider protocol."""
|
|
provider = MockInputProvider(["test"])
|
|
assert isinstance(provider, InputProvider)
|
|
|
|
|
|
# ── Error Handling ───────────────────────────────────────────────
|
|
|
|
|
|
class TestAskErrorHandling:
|
|
"""Tests for error handling in ask()."""
|
|
|
|
def test_ask_returns_none_on_provider_error(self) -> None:
|
|
"""ask() returns None if provider raises an exception."""
|
|
|
|
class FailingProvider:
|
|
def request_input(self, message: str, flow: Flow[Any]) -> str | None:
|
|
raise RuntimeError("Provider failed")
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = FailingProvider()
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Question?")
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result is None
|
|
|
|
def test_ask_in_async_method_returns_none_on_provider_error(self) -> None:
|
|
"""ask() returns None if provider raises in an async method."""
|
|
|
|
class FailingProvider:
|
|
def request_input(self, message: str, flow: Flow[Any]) -> str | None:
|
|
raise RuntimeError("Provider failed")
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = FailingProvider()
|
|
|
|
@start()
|
|
async def my_method(self):
|
|
return self.ask("Question?")
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result is None
|
|
|
|
|
|
# ── Metadata ─────────────────────────────────────────────────────
|
|
|
|
|
|
class TestAskMetadata:
|
|
"""Tests for bidirectional metadata support in ask()."""
|
|
|
|
def test_ask_passes_metadata_to_provider(self) -> None:
|
|
"""Provider receives the metadata dict from ask()."""
|
|
provider = MockInputProvider(["answer"])
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = provider
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Q?", metadata={"user_id": "u123"})
|
|
|
|
flow = TestFlow()
|
|
flow.kickoff()
|
|
assert provider.received_metadata == [{"user_id": "u123"}]
|
|
|
|
def test_ask_metadata_none_by_default(self) -> None:
|
|
"""Provider receives None metadata when not provided."""
|
|
provider = MockInputProvider(["answer"])
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = provider
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Q?")
|
|
|
|
flow = TestFlow()
|
|
flow.kickoff()
|
|
assert provider.received_metadata == [None]
|
|
|
|
def test_ask_provider_returns_input_response(self) -> None:
|
|
"""Provider returns InputResponse with response metadata."""
|
|
|
|
class MetadataProvider:
|
|
def request_input(
|
|
self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
|
|
) -> InputResponse:
|
|
return InputResponse(
|
|
text="the answer",
|
|
metadata={"responded_by": "u456", "thread_id": "t789"},
|
|
)
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MetadataProvider()
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Q?", metadata={"user_id": "u123"})
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
|
|
# ask() still returns plain string
|
|
assert result == "the answer"
|
|
|
|
# History has both metadata dicts
|
|
assert len(flow._input_history) == 1
|
|
entry = flow._input_history[0]
|
|
assert entry["metadata"] == {"user_id": "u123"}
|
|
assert entry["response_metadata"] == {"responded_by": "u456", "thread_id": "t789"}
|
|
|
|
def test_ask_provider_returns_string_with_metadata_sent(self) -> None:
|
|
"""Provider returns plain string; history has metadata but no response_metadata."""
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["answer"])
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Q?", metadata={"channel": "#research"})
|
|
|
|
flow = TestFlow()
|
|
flow.kickoff()
|
|
|
|
entry = flow._input_history[0]
|
|
assert entry["metadata"] == {"channel": "#research"}
|
|
assert entry["response_metadata"] is None
|
|
|
|
def test_ask_metadata_in_requested_event(self) -> None:
|
|
"""FlowInputRequestedEvent carries metadata."""
|
|
from crewai.events.event_bus import crewai_event_bus
|
|
from crewai.events.types.flow_events import FlowInputRequestedEvent
|
|
|
|
events_captured: list[FlowInputRequestedEvent] = []
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MockInputProvider(["answer"])
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Q?", metadata={"user_id": "u123"})
|
|
|
|
flow = TestFlow()
|
|
original_emit = crewai_event_bus.emit
|
|
|
|
def capture_emit(source: Any, event: Any) -> Any:
|
|
if isinstance(event, FlowInputRequestedEvent):
|
|
events_captured.append(event)
|
|
return original_emit(source, event)
|
|
|
|
with patch.object(crewai_event_bus, "emit", side_effect=capture_emit):
|
|
flow.kickoff()
|
|
|
|
assert len(events_captured) == 1
|
|
assert events_captured[0].metadata == {"user_id": "u123"}
|
|
|
|
def test_ask_metadata_in_received_event(self) -> None:
|
|
"""FlowInputReceivedEvent carries both metadata and response_metadata."""
|
|
from crewai.events.event_bus import crewai_event_bus
|
|
from crewai.events.types.flow_events import FlowInputReceivedEvent
|
|
|
|
events_captured: list[FlowInputReceivedEvent] = []
|
|
|
|
class MetadataProvider:
|
|
def request_input(
|
|
self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
|
|
) -> InputResponse:
|
|
return InputResponse(text="answer", metadata={"responded_by": "u456"})
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = MetadataProvider()
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Q?", metadata={"user_id": "u123"})
|
|
|
|
flow = TestFlow()
|
|
original_emit = crewai_event_bus.emit
|
|
|
|
def capture_emit(source: Any, event: Any) -> Any:
|
|
if isinstance(event, FlowInputReceivedEvent):
|
|
events_captured.append(event)
|
|
return original_emit(source, event)
|
|
|
|
with patch.object(crewai_event_bus, "emit", side_effect=capture_emit):
|
|
flow.kickoff()
|
|
|
|
assert len(events_captured) == 1
|
|
assert events_captured[0].metadata == {"user_id": "u123"}
|
|
assert events_captured[0].response_metadata == {"responded_by": "u456"}
|
|
assert events_captured[0].response == "answer"
|
|
|
|
def test_ask_input_response_with_none_text(self) -> None:
|
|
"""Provider returns InputResponse with text=None."""
|
|
|
|
class NoneTextProvider:
|
|
def request_input(
|
|
self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
|
|
) -> InputResponse:
|
|
return InputResponse(text=None, metadata={"reason": "user_declined"})
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = NoneTextProvider()
|
|
|
|
@start()
|
|
def my_method(self):
|
|
return self.ask("Q?")
|
|
|
|
flow = TestFlow()
|
|
result = flow.kickoff()
|
|
assert result is None
|
|
|
|
entry = flow._input_history[0]
|
|
assert entry["response"] is None
|
|
assert entry["response_metadata"] == {"reason": "user_declined"}
|
|
|
|
def test_ask_metadata_thread_safe(self) -> None:
|
|
"""Concurrent ask() calls with different metadata don't cross-contaminate."""
|
|
import threading
|
|
|
|
call_log: list[dict[str, Any]] = []
|
|
log_lock = threading.Lock()
|
|
|
|
class TrackingProvider:
|
|
def request_input(
|
|
self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None
|
|
) -> InputResponse:
|
|
# Small delay to increase chance of interleaving
|
|
time.sleep(0.05)
|
|
with log_lock:
|
|
call_log.append({"message": message, "metadata": metadata})
|
|
user = metadata.get("user", "unknown") if metadata else "unknown"
|
|
return InputResponse(
|
|
text=f"answer from {user}",
|
|
metadata={"responded_by": user},
|
|
)
|
|
|
|
class TestFlow(Flow):
|
|
input_provider = TrackingProvider()
|
|
|
|
@start()
|
|
def trigger(self):
|
|
return "go"
|
|
|
|
@listen("trigger")
|
|
def listener_a(self):
|
|
return self.ask("Question A?", metadata={"user": "alice"})
|
|
|
|
@listen("trigger")
|
|
def listener_b(self):
|
|
return self.ask("Question B?", metadata={"user": "bob"})
|
|
|
|
flow = TestFlow()
|
|
flow.kickoff()
|
|
|
|
# Both calls should have recorded their own metadata
|
|
assert len(flow._input_history) == 2
|
|
|
|
alice_entry = next(
|
|
(e for e in flow._input_history if e["metadata"] and e["metadata"].get("user") == "alice"),
|
|
None,
|
|
)
|
|
bob_entry = next(
|
|
(e for e in flow._input_history if e["metadata"] and e["metadata"].get("user") == "bob"),
|
|
None,
|
|
)
|
|
|
|
assert alice_entry is not None
|
|
assert alice_entry["response"] == "answer from alice"
|
|
assert alice_entry["response_metadata"] == {"responded_by": "alice"}
|
|
|
|
assert bob_entry is not None
|
|
assert bob_entry["response"] == "answer from bob"
|
|
assert bob_entry["response_metadata"] == {"responded_by": "bob"}
|