Compare commits


2 Commits

Author SHA1 Message Date
João Moura
84d57c7a24 Implement user input handling in Flows (#4490)
* Implement user input handling in Flow class
2026-02-16 18:41:03 -03:00
João Moura
4aedd58829 Enhance HITL self-loop functionality in human feedback integration tests (#4493)
- Added tests to verify self-loop behavior in HITL routers, ensuring they can handle multiple rejections and immediate approvals.
- Implemented `test_hitl_self_loop_routes_back_to_same_method`, `test_hitl_self_loop_multiple_rejections`, and `test_hitl_self_loop_immediate_approval` to validate the expected execution order and outcomes.
- Updated the `or_()` listener to support looping back to the same method based on human feedback outcomes, improving flow control in complex scenarios.
2026-02-15 21:54:42 -05:00
18 changed files with 2325 additions and 206 deletions

View File

@@ -1,7 +1,5 @@
import os
from crewai.context import get_platform_integration_token as _get_context_token
def get_platform_api_base_url() -> str:
"""Get the platform API base URL from environment or use default."""
@@ -10,16 +8,10 @@ def get_platform_api_base_url() -> str:
def get_platform_integration_token() -> str:
"""Get the platform integration token from the context.
Fallback to the environment variable if no token has been set in the context.
Raises:
ValueError: If no token has been set in the context.
"""
token = _get_context_token() or os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")
"""Get the platform API base URL from environment or use default."""
token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN") or ""
if not token:
raise ValueError(
"No platform integration token found. "
"Set it via platform_integration_context() or set_platform_integration_token()."
"No platform integration token found, please set the CREWAI_PLATFORM_INTEGRATION_TOKEN environment variable"
)
return token
return token # TODO: Use context manager to get token
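For reference, a minimal usage sketch of the behavior after this change, assuming the added lines are the ones keeping the environment-only lookup (the context fallback moves into `crewai.context`, per the hunks below):

```python
import os

import pytest

from crewai_tools.tools.crewai_platform_tools.misc import (
    get_platform_integration_token,
)

# With the variable set, the token comes straight from the environment.
os.environ["CREWAI_PLATFORM_INTEGRATION_TOKEN"] = "my-token"
assert get_platform_integration_token() == "my-token"

# Without it, the new ValueError from the diff is raised.
del os.environ["CREWAI_PLATFORM_INTEGRATION_TOKEN"]
with pytest.raises(ValueError, match="No platform integration token found"):
    get_platform_integration_token()
```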

View File

@@ -1,45 +0,0 @@
"""Tests for platform tools misc functionality."""
import os
from unittest.mock import patch
import pytest
from crewai.context import platform_integration_context
from crewai_tools.tools.crewai_platform_tools.misc import (
get_platform_integration_token,
)
class TestTokenRetrievalWithFallback:
"""Test token retrieval logic with environment fallback."""
def test_context_token_takes_precedence(self, clean_context):
"""Test that context token takes precedence over environment variable."""
context_token = "context-token"
env_token = "env-token"
with patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": env_token}):
with platform_integration_context(context_token):
token = get_platform_integration_token()
assert token == context_token
def test_environment_fallback_when_no_context(self, clean_context):
"""Test fallback to environment variable when no context token."""
env_token = "env-fallback-token"
with patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": env_token}):
token = get_platform_integration_token()
assert token == env_token
@pytest.mark.parametrize("empty_value", ["", None])
def test_missing_token_raises_error(self, clean_context, empty_value):
"""Test that missing tokens raise appropriate errors."""
env_dict = {"CREWAI_PLATFORM_INTEGRATION_TOKEN": empty_value} if empty_value is not None else {}
with patch.dict(os.environ, env_dict, clear=True):
with pytest.raises(ValueError) as exc_info:
get_platform_integration_token()
assert "No platform integration token found" in str(exc_info.value)
assert "platform_integration_context()" in str(exc_info.value)

View File

@@ -20117,6 +20117,18 @@
"humanized_name": "Web Automation Tool",
"init_params_schema": {
"$defs": {
"AvailableModel": {
"enum": [
"gpt-4o",
"gpt-4o-mini",
"claude-3-5-sonnet-latest",
"claude-3-7-sonnet-latest",
"computer-use-preview",
"gemini-2.0-flash"
],
"title": "AvailableModel",
"type": "string"
},
"EnvVar": {
"properties": {
"default": {
@@ -20194,6 +20206,17 @@
"default": null,
"title": "Model Api Key"
},
"model_name": {
"anyOf": [
{
"$ref": "#/$defs/AvailableModel"
},
{
"type": "null"
}
],
"default": "claude-3-7-sonnet-latest"
},
"project_id": {
"anyOf": [
{

View File

@@ -2,7 +2,30 @@ import subprocess
import click
from crewai.cli.utils import get_crews
from crewai.cli.utils import get_crews, get_flows
from crewai.flow import Flow
def _reset_flow_memory(flow: Flow) -> None:
"""Reset memory for a single flow instance.
Handles Memory, MemoryScope (both have .reset()), and MemorySlice
(delegates to the underlying ._memory). Silently succeeds when the
storage directory does not exist yet (nothing to reset).
Args:
flow: The flow instance whose memory should be reset.
"""
mem = flow.memory
if mem is None:
return
try:
if hasattr(mem, "reset"):
mem.reset()
elif hasattr(mem, "_memory") and hasattr(mem._memory, "reset"):
mem._memory.reset()
except (FileNotFoundError, OSError):
pass
def reset_memories_command(
@@ -12,7 +35,7 @@ def reset_memories_command(
kickoff_outputs: bool,
all: bool,
) -> None:
"""Reset the crew memories.
"""Reset the crew and flow memories.
Args:
memory: Whether to reset the unified memory.
@@ -29,8 +52,11 @@ def reset_memories_command(
return
crews = get_crews()
if not crews:
raise ValueError("No crew found.")
flows = get_flows()
if not crews and not flows:
raise ValueError("No crew or flow found.")
for crew in crews:
if all:
crew.reset_memories(command_type="all")
@@ -59,6 +85,20 @@ def reset_memories_command(
f"[Crew ({crew.name if crew.name else crew.id})] Agents knowledge has been reset."
)
for flow in flows:
flow_name = flow.name or flow.__class__.__name__
if all:
_reset_flow_memory(flow)
click.echo(
f"[Flow ({flow_name})] Reset memories command has been completed."
)
continue
if memory:
_reset_flow_memory(flow)
click.echo(
f"[Flow ({flow_name})] Memory has been reset."
)
except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
click.echo(e.output, err=True)
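A small sketch of the dispatch `_reset_flow_memory` performs, using hypothetical stand-in classes for the memory shapes it handles (the real `Memory`/`MemoryScope`/`MemorySlice` live in crewai; the import path matches the test mocks further down):

```python
from types import SimpleNamespace

from crewai.cli.reset_memories_command import _reset_flow_memory

class FakeMemory:
    """Stand-in for Memory/MemoryScope: exposes .reset() directly."""
    def __init__(self) -> None:
        self.was_reset = False

    def reset(self) -> None:
        self.was_reset = True

class FakeSlice:
    """Stand-in for MemorySlice: no .reset(), delegates via ._memory."""
    def __init__(self) -> None:
        self._memory = FakeMemory()

flow = SimpleNamespace(memory=FakeSlice())
_reset_flow_memory(flow)
assert flow.memory._memory.was_reset  # reset reached the delegate
```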

View File

@@ -386,6 +386,109 @@ def fetch_crews(module_attr: Any) -> list[Crew]:
return crew_instances
def get_flow_instance(module_attr: Any) -> Flow | None:
"""Check if a module attribute is a user-defined Flow subclass and return an instance.
Args:
module_attr: An attribute from a loaded module.
Returns:
A Flow instance if the attribute is a valid user-defined Flow subclass,
None otherwise.
"""
if (
isinstance(module_attr, type)
and issubclass(module_attr, Flow)
and module_attr is not Flow
):
try:
return module_attr()
except Exception:
return None
return None
_SKIP_DIRS = frozenset(
{".venv", "venv", ".git", "__pycache__", "node_modules", ".tox", ".nox"}
)
def get_flows(flow_path: str = "main.py") -> list[Flow]:
"""Get the flow instances from project files.
Walks the project directory looking for files matching ``flow_path``
(default ``main.py``), loads each module, and extracts Flow subclass
instances. Directories that are clearly not user source code (virtual
environments, ``.git``, etc.) are pruned to avoid noisy import errors.
Args:
flow_path: Filename to search for (default ``main.py``).
Returns:
A list of discovered Flow instances.
"""
flow_instances: list[Flow] = []
try:
current_dir = os.getcwd()
if current_dir not in sys.path:
sys.path.insert(0, current_dir)
src_dir = os.path.join(current_dir, "src")
if os.path.isdir(src_dir) and src_dir not in sys.path:
sys.path.insert(0, src_dir)
search_paths = [".", "src"] if os.path.isdir("src") else ["."]
for search_path in search_paths:
for root, dirs, files in os.walk(search_path):
dirs[:] = [
d
for d in dirs
if d not in _SKIP_DIRS and not d.startswith(".")
]
if flow_path in files and "cli/templates" not in root:
file_os_path = os.path.join(root, flow_path)
try:
spec = importlib.util.spec_from_file_location(
"flow_module", file_os_path
)
if not spec or not spec.loader:
continue
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
try:
spec.loader.exec_module(module)
for attr_name in dir(module):
module_attr = getattr(module, attr_name)
try:
if flow_instance := get_flow_instance(
module_attr
):
flow_instances.append(flow_instance)
except Exception: # noqa: S112
continue
if flow_instances:
break
except Exception: # noqa: S112
continue
except (ImportError, AttributeError):
continue
if flow_instances:
break
except Exception: # noqa: S110
pass
return flow_instances
def is_valid_tool(obj: Any) -> bool:
from crewai.tools.base_tool import Tool
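For orientation, a minimal sketch of a project `main.py` that this discovery walk would pick up and instantiate (class body hypothetical):

```python
# main.py at the project root (or under src/): get_flows() imports the
# module and returns an instance of any user-defined Flow subclass
# found at module level.
from crewai.flow import Flow, start

class ResearchFlow(Flow):
    @start()
    def begin(self) -> str:
        return "hello"
```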

View File

@@ -1,7 +1,8 @@
from collections.abc import Generator
from contextlib import contextmanager, nullcontext
from contextlib import contextmanager
import contextvars
from typing import Any, ContextManager
import os
from typing import Any
_platform_integration_token: contextvars.ContextVar[str | None] = (
@@ -9,50 +10,40 @@ _platform_integration_token: contextvars.ContextVar[str | None] = (
)
def set_platform_integration_token(integration_token: str) -> contextvars.Token[str | None]:
def set_platform_integration_token(integration_token: str) -> None:
"""Set the platform integration token in the current context.
Args:
integration_token: The integration token to set.
"""
return _platform_integration_token.set(integration_token)
def reset_platform_integration_token(token: contextvars.Token[str | None]) -> None:
"""Reset the platform integration token to its previous value."""
_platform_integration_token.reset(token)
_platform_integration_token.set(integration_token)
def get_platform_integration_token() -> str | None:
"""Get the platform integration token from the current context.
"""Get the platform integration token from the current context or environment.
Returns:
The integration token if set, otherwise None.
"""
return _platform_integration_token.get()
token = _platform_integration_token.get()
if token is None:
token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")
return token
def platform_integration_context(integration_token: str | None) -> ContextManager[None]:
@contextmanager
def platform_context(integration_token: str) -> Generator[None, Any, None]:
"""Context manager to temporarily set the platform integration token.
Args:
integration_token: The integration token to set within the context.
If None or falsy, returns nullcontext (no-op).
Returns:
A context manager that either sets the token or does nothing.
"""
if not integration_token:
return nullcontext()
token = _platform_integration_token.set(integration_token)
try:
yield
finally:
_platform_integration_token.reset(token)
@contextmanager
def _token_context() -> Generator[None, Any, None]:
token = set_platform_integration_token(integration_token)
try:
yield
finally:
reset_platform_integration_token(token)
return _token_context()
_current_task_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"current_task_id", default=None

View File

@@ -120,6 +120,52 @@ class FlowPlotEvent(FlowEvent):
type: str = "flow_plot"
class FlowInputRequestedEvent(FlowEvent):
"""Event emitted when a flow requests user input via ``Flow.ask()``.
This event is emitted before the flow suspends waiting for user input,
allowing UI frameworks and observability tools to know when a flow
needs user interaction.
Attributes:
flow_name: Name of the flow requesting input.
method_name: Name of the flow method that called ``ask()``.
message: The question or prompt being shown to the user.
metadata: Optional metadata sent with the question (e.g., user ID,
channel, session context).
"""
method_name: str
message: str
metadata: dict[str, Any] | None = None
type: str = "flow_input_requested"
class FlowInputReceivedEvent(FlowEvent):
"""Event emitted when user input is received after ``Flow.ask()``.
This event is emitted after the user provides input (or the request
times out), allowing UI frameworks and observability tools to track
input collection.
Attributes:
flow_name: Name of the flow that received input.
method_name: Name of the flow method that called ``ask()``.
message: The original question or prompt.
response: The user's response, or None if timed out / unavailable.
metadata: Optional metadata sent with the question.
response_metadata: Optional metadata from the provider about the
response (e.g., who responded, thread ID, timestamps).
"""
method_name: str
message: str
response: str | None = None
metadata: dict[str, Any] | None = None
response_metadata: dict[str, Any] | None = None
type: str = "flow_input_received"
class HumanFeedbackRequestedEvent(FlowEvent):
"""Event emitted when human feedback is requested.

View File

@@ -7,6 +7,7 @@ from crewai.flow.async_feedback import (
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.flow.flow_config import flow_config
from crewai.flow.human_feedback import HumanFeedbackResult, human_feedback
from crewai.flow.input_provider import InputProvider, InputResponse
from crewai.flow.persistence import persist
from crewai.flow.visualization import (
FlowStructure,
@@ -22,6 +23,8 @@ __all__ = [
"HumanFeedbackPending",
"HumanFeedbackProvider",
"HumanFeedbackResult",
"InputProvider",
"InputResponse",
"PendingFeedbackContext",
"and_",
"build_flow_structure",

View File

@@ -1,7 +1,8 @@
"""Default provider implementations for human feedback.
"""Default provider implementations for human feedback and user input.
This module provides the ConsoleProvider, which is the default synchronous
provider that collects feedback via console input.
provider that collects both feedback (for ``@human_feedback``) and user input
(for ``Flow.ask()``) via console.
"""
from __future__ import annotations
@@ -16,20 +17,23 @@ if TYPE_CHECKING:
class ConsoleProvider:
"""Default synchronous console-based feedback provider.
"""Default synchronous console-based provider for feedback and input.
This provider blocks execution and waits for console input from the user.
It displays the method output with formatting and prompts for feedback.
It serves two purposes:
- **Feedback** (``request_feedback``): Used by ``@human_feedback`` to
display method output and collect review feedback.
- **Input** (``request_input``): Used by ``Flow.ask()`` to prompt the
user with a question and collect a response.
This is the default provider used when no custom provider is specified
in the @human_feedback decorator.
in the ``@human_feedback`` decorator or on the Flow's ``input_provider``.
Example:
Example (feedback):
```python
from crewai.flow.async_feedback import ConsoleProvider
# Explicitly use console provider
@human_feedback(
message="Review this:",
provider=ConsoleProvider(),
@@ -37,9 +41,20 @@ class ConsoleProvider:
def my_method(self):
return "Content to review"
```
Example (input):
```python
from crewai.flow import Flow, start
class MyFlow(Flow):
@start()
def gather_info(self):
topic = self.ask("What topic should we research?")
return topic
```
"""
def __init__(self, verbose: bool = True):
def __init__(self, verbose: bool = True) -> None:
"""Initialize the console provider.
Args:
@@ -124,3 +139,55 @@ class ConsoleProvider:
finally:
# Resume live updates
formatter.resume_live_updates()
def request_input(
self,
message: str,
flow: Flow[Any],
metadata: dict[str, Any] | None = None,
) -> str | None:
"""Request user input via console (blocking).
Displays the prompt message with formatting and waits for the user
to type their response. Used by ``Flow.ask()``.
Unlike ``request_feedback``, this method does not display an
"OUTPUT FOR REVIEW" panel or emit feedback-specific events (those
are handled by ``ask()`` itself).
Args:
message: The question or prompt to display to the user.
flow: The Flow instance requesting input.
metadata: Optional metadata from the caller. Ignored by the
console provider (console has no concept of user routing).
Returns:
The user's input as a stripped string. Returns empty string
if user presses Enter without input. Never returns None
(console input is always available).
"""
from crewai.events.event_listener import event_listener
# Pause live updates during human input
formatter = event_listener.formatter
formatter.pause_live_updates()
try:
console = formatter.console
if self.verbose:
console.print()
console.print(message, style="yellow")
console.print()
response = input(">>> \n").strip()
else:
response = input(f"{message} ").strip()
# Add line break after input so formatter output starts clean
console.print()
return response
finally:
# Resume live updates
formatter.resume_live_updates()

View File

@@ -77,7 +77,7 @@ from crewai.flow.flow_wrappers import (
StartMethod,
)
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData, FlowMethodName, PendingListenerKey
from crewai.flow.types import FlowExecutionData, FlowMethodName, InputHistoryEntry, PendingListenerKey
from crewai.flow.utils import (
_extract_all_methods,
_extract_all_methods_recursive,
@@ -738,6 +738,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
tracing: bool | None = None
stream: bool = False
memory: Any = None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
input_provider: Any = None # InputProvider | None; per-flow override for self.ask()
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
class _FlowGeneric(cls): # type: ignore
@@ -784,6 +785,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._pending_feedback_context: PendingFeedbackContext | None = None
self.suppress_flow_events: bool = suppress_flow_events
# User input history (for self.ask())
self._input_history: list[InputHistoryEntry] = []
# Initialize state with initial values
self._state = self._create_initial_state()
self.tracing = tracing
@@ -2119,15 +2123,24 @@ class Flow(Generic[T], metaclass=FlowMeta):
if future:
self._event_futures.append(future)
if asyncio.iscoroutinefunction(method):
result = await method(*args, **kwargs)
else:
# Run sync methods in thread pool for isolation
# This allows Agent.kickoff() to work synchronously inside Flow methods
import contextvars
# Set method name in context so ask() can read it without
# stack inspection. Must happen before copy_context() so the
# value propagates into the thread pool for sync methods.
from crewai.flow.flow_context import current_flow_method_name
ctx = contextvars.copy_context()
result = await asyncio.to_thread(ctx.run, method, *args, **kwargs)
method_name_token = current_flow_method_name.set(method_name)
try:
if asyncio.iscoroutinefunction(method):
result = await method(*args, **kwargs)
else:
# Run sync methods in thread pool for isolation
# This allows Agent.kickoff() to work synchronously inside Flow methods
import contextvars
ctx = contextvars.copy_context()
result = await asyncio.to_thread(ctx.run, method, *args, **kwargs)
finally:
current_flow_method_name.reset(method_name_token)
# Auto-await coroutines returned from sync methods (enables AgentExecutor pattern)
if asyncio.iscoroutine(result):
@@ -2582,6 +2595,201 @@ class Flow(Generic[T], metaclass=FlowMeta):
logger.error(f"Error executing listener {listener_name}: {e}")
raise
# ── User Input (self.ask) ────────────────────────────────────────
def _resolve_input_provider(self) -> Any:
"""Resolve the input provider using the priority chain.
Resolution order:
1. ``self.input_provider`` (per-flow override)
2. ``flow_config.input_provider`` (global default)
3. ``ConsoleProvider()`` (built-in fallback)
Returns:
An object implementing the ``InputProvider`` protocol.
"""
from crewai.flow.async_feedback.providers import ConsoleProvider
from crewai.flow.flow_config import flow_config
if self.input_provider is not None:
return self.input_provider
if flow_config.input_provider is not None:
return flow_config.input_provider
return ConsoleProvider()
def _checkpoint_state_for_ask(self) -> None:
"""Auto-checkpoint flow state before waiting for user input.
If persistence is configured, saves the current state so that
``self.state`` is recoverable even if the process crashes while
waiting for input.
This is best-effort: if persistence is not configured, this is a no-op.
"""
if self._persistence is None:
return
try:
state_data = (
self._state
if isinstance(self._state, dict)
else self._state.model_dump()
)
self._persistence.save_state(
flow_uuid=self.flow_id,
method_name="_ask_checkpoint",
state_data=state_data,
)
except Exception:
logger.debug("Failed to checkpoint state before ask()", exc_info=True)
def ask(
self,
message: str,
timeout: float | None = None,
metadata: dict[str, Any] | None = None,
) -> str | None:
"""Request input from the user during flow execution.
Blocks the current thread until the user provides input or the
timeout expires. Works in both sync and async flow methods (the
flow framework runs sync methods in a thread pool via
``asyncio.to_thread``, so the event loop stays free).
Timeout ensures flows always terminate. When timeout expires,
``None`` is returned, enabling the pattern::
while (msg := self.ask("You: ", timeout=300)) is not None:
process(msg)
Before waiting for input, the current ``self.state`` is automatically
checkpointed to persistence (if configured) for durability.
Args:
message: The question or prompt to display to the user.
timeout: Maximum seconds to wait for input. ``None`` means
wait indefinitely. When timeout expires, returns ``None``.
Note: timeout is best-effort for the provider call --
``ask()`` returns ``None`` promptly, but the underlying
``request_input()`` may continue running in a background
thread until it completes naturally. Network providers
should implement their own internal timeouts.
metadata: Optional metadata to send to the input provider,
such as user ID, channel, session context. The provider
can use this to route the question to the right recipient.
Returns:
The user's input as a string, or ``None`` on timeout, disconnect,
or provider error. Empty string ``""`` means the user pressed
Enter without typing (intentional empty input).
Example:
```python
class MyFlow(Flow):
@start()
def gather_info(self):
topic = self.ask(
"What topic should we research?",
metadata={"user_id": "u123", "channel": "#research"},
)
if topic is None:
return "No input received"
return topic
```
"""
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
from datetime import datetime
from crewai.events.types.flow_events import (
FlowInputReceivedEvent,
FlowInputRequestedEvent,
)
from crewai.flow.flow_context import current_flow_method_name
from crewai.flow.input_provider import InputResponse
method_name = current_flow_method_name.get("unknown")
# Emit input requested event
crewai_event_bus.emit(
self,
FlowInputRequestedEvent(
type="flow_input_requested",
flow_name=self.name or self.__class__.__name__,
method_name=method_name,
message=message,
metadata=metadata,
),
)
# Auto-checkpoint state before waiting
self._checkpoint_state_for_ask()
provider = self._resolve_input_provider()
raw: str | InputResponse | None = None
try:
if timeout is not None:
# Manual executor management to avoid shutdown(wait=True)
# deadlock when the provider call outlives the timeout.
executor = ThreadPoolExecutor(max_workers=1)
future = executor.submit(
provider.request_input, message, self, metadata
)
try:
raw = future.result(timeout=timeout)
except FuturesTimeoutError:
future.cancel()
raw = None
finally:
# wait=False so we don't block if the provider is still
# running (e.g. input() stuck waiting for user).
# cancel_futures=True cleans up any queued-but-not-started tasks.
executor.shutdown(wait=False, cancel_futures=True)
else:
raw = provider.request_input(message, self, metadata=metadata)
except KeyboardInterrupt:
raise
except Exception:
logger.debug("Input provider error in ask()", exc_info=True)
raw = None
# Normalize provider response: str, InputResponse, or None
response: str | None = None
response_metadata: dict[str, Any] | None = None
if isinstance(raw, InputResponse):
response = raw.text
response_metadata = raw.metadata
elif isinstance(raw, str):
response = raw
else:
response = None
# Record in history
self._input_history.append({
"message": message,
"response": response,
"method_name": method_name,
"timestamp": datetime.now(),
"metadata": metadata,
"response_metadata": response_metadata,
})
# Emit input received event
crewai_event_bus.emit(
self,
FlowInputReceivedEvent(
type="flow_input_received",
flow_name=self.name or self.__class__.__name__,
method_name=method_name,
message=message,
response=response,
metadata=metadata,
response_metadata=response_metadata,
),
)
return response
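The timeout branch above hinges on manual executor management; a standalone sketch of the same pattern under hypothetical names, showing why `shutdown(wait=False)` matters:

```python
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError

def call_with_timeout(fn, timeout: float):
    """Run fn() in a worker thread, returning None after `timeout` seconds.

    shutdown(wait=False) is the key detail: the default shutdown(wait=True)
    would block until fn() returns, defeating the timeout whenever fn is
    stuck on something uncancellable like input().
    """
    executor = ThreadPoolExecutor(max_workers=1)
    future = executor.submit(fn)
    try:
        return future.result(timeout=timeout)
    except FuturesTimeoutError:
        future.cancel()  # no-op if already running, but harmless
        return None
    finally:
        executor.shutdown(wait=False, cancel_futures=True)
```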
def _request_human_feedback(
self,
message: str,

View File

@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from crewai.flow.async_feedback.types import HumanFeedbackProvider
from crewai.flow.input_provider import InputProvider
class FlowConfig:
@@ -20,10 +21,15 @@ class FlowConfig:
hitl_provider: The human-in-the-loop feedback provider.
Defaults to None (uses console input).
Can be overridden by deployments at startup.
input_provider: The input provider used by ``Flow.ask()``.
Defaults to None (uses ``ConsoleProvider``).
Can be overridden by deployments at startup.
"""
def __init__(self) -> None:
self._hitl_provider: HumanFeedbackProvider | None = None
self._input_provider: InputProvider | None = None
@property
def hitl_provider(self) -> Any:
@@ -35,6 +41,32 @@ class FlowConfig:
"""Set the HITL provider."""
self._hitl_provider = provider
@property
def input_provider(self) -> Any:
"""Get the configured input provider for ``Flow.ask()``.
Returns:
The configured InputProvider instance, or None if not set
(in which case ``ConsoleProvider`` is used as default).
"""
return self._input_provider
@input_provider.setter
def input_provider(self, provider: Any) -> None:
"""Set the input provider for ``Flow.ask()``.
Args:
provider: An object implementing the ``InputProvider`` protocol.
Example:
```python
from crewai.flow import flow_config
flow_config.input_provider = WebSocketInputProvider(...)
```
"""
self._input_provider = provider
# Singleton instance
flow_config = FlowConfig()

View File

@@ -14,3 +14,7 @@ current_flow_request_id: contextvars.ContextVar[str | None] = contextvars.Contex
current_flow_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"flow_id", default=None
)
current_flow_method_name: contextvars.ContextVar[str] = contextvars.ContextVar(
"flow_method_name", default="unknown"
)

View File

@@ -0,0 +1,151 @@
"""Input provider protocol for Flow.ask().
This module provides the InputProvider protocol and InputResponse dataclass
used by Flow.ask() to request input from users during flow execution.
The default implementation is ``ConsoleProvider`` (from
``crewai.flow.async_feedback.providers``), which serves both feedback
and input collection via console.
Example (default console input):
```python
from crewai.flow import Flow, start
class MyFlow(Flow):
@start()
def gather_info(self):
topic = self.ask("What topic should we research?")
return topic
```
Example (custom provider with metadata):
```python
from crewai.flow import Flow, start
from crewai.flow.input_provider import InputProvider, InputResponse
class SlackProvider:
def request_input(self, message, flow, metadata=None):
channel = metadata.get("channel", "#general") if metadata else "#general"
thread = self.post_question(channel, message)
reply = self.wait_for_reply(thread)
return InputResponse(
text=reply.text,
metadata={"responded_by": reply.user_id, "thread_id": thread.id},
)
class MyFlow(Flow):
input_provider = SlackProvider()
@start()
def gather_info(self):
topic = self.ask("What topic?", metadata={"channel": "#research"})
return topic
```
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
if TYPE_CHECKING:
from crewai.flow.flow import Flow
@dataclass
class InputResponse:
"""Response from an InputProvider, optionally carrying metadata.
Simple providers can just return a string from ``request_input()``.
Providers that need to send metadata back (e.g., who responded,
thread ID, external timestamps) return an ``InputResponse`` instead.
``ask()`` normalizes both cases -- callers always get ``str | None``.
The response metadata is stored in ``_input_history`` and emitted
in ``FlowInputReceivedEvent``.
Attributes:
text: The user's input text, or None if unavailable.
metadata: Optional metadata from the provider about the response
(e.g., who responded, thread ID, timestamps).
Example:
```python
class MyProvider:
def request_input(self, message, flow, metadata=None):
response = get_response_from_external_system(message)
return InputResponse(
text=response.text,
metadata={"responded_by": response.user_id},
)
```
"""
text: str | None
metadata: dict[str, Any] | None = field(default=None)
@runtime_checkable
class InputProvider(Protocol):
"""Protocol for user input collection strategies.
Implement this protocol to create custom input providers that integrate
with external systems like websockets, web UIs, Slack, or custom APIs.
The default provider is ``ConsoleProvider``, which blocks waiting for
console input via Python's built-in ``input()`` function.
Providers are always synchronous. The flow framework runs sync methods
in a thread pool (via ``asyncio.to_thread``), so ``ask()`` never blocks
the event loop even inside async flow methods.
Providers can return either:
- ``str | None`` for simple cases (no response metadata)
- ``InputResponse`` when they need to send metadata back with the answer
Example (simple):
```python
class SimpleProvider:
def request_input(self, message: str, flow: Flow) -> str | None:
return input(message)
```
Example (with metadata):
```python
class SlackProvider:
def request_input(self, message, flow, metadata=None):
channel = metadata.get("channel") if metadata else "#general"
reply = self.post_and_wait(channel, message)
return InputResponse(
text=reply.text,
metadata={"responded_by": reply.user_id},
)
```
"""
def request_input(
self,
message: str,
flow: Flow[Any],
metadata: dict[str, Any] | None = None,
) -> str | InputResponse | None:
"""Request input from the user.
Args:
message: The question or prompt to display to the user.
flow: The Flow instance requesting input. Can be used to
access flow state, name, or other context.
metadata: Optional metadata from the caller, such as user ID,
channel, session context, etc. Providers can use this to
route the question to the right recipient.
Returns:
The user's input as a string, an ``InputResponse`` with text
and optional response metadata, or None if input is unavailable
(e.g., user cancelled, connection dropped).
"""
...

View File

@@ -4,6 +4,7 @@ This module contains TypedDict definitions and type aliases used throughout
the Flow system.
"""
from datetime import datetime
from typing import (
Annotated,
Any,
@@ -101,6 +102,30 @@ class FlowData(TypedDict):
flow_methods_attributes: list[FlowMethodData]
class InputHistoryEntry(TypedDict):
"""A single entry in the flow's input history from ``self.ask()``.
Each call to ``Flow.ask()`` appends one entry recording the question,
the user's response, which method asked, and any metadata exchanged
between the caller and the input provider.
Attributes:
message: The question or prompt that was displayed to the user.
response: The user's response, or None on timeout/error.
method_name: The flow method that called ``ask()``.
timestamp: When the input was received.
metadata: Metadata sent with the question (caller to provider).
response_metadata: Metadata received with the answer (provider to caller).
"""
message: str
response: str | None
method_name: str
timestamp: datetime
metadata: dict[str, Any] | None
response_metadata: dict[str, Any] | None
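A small sketch of reading the history after a run; `_input_history` is the private list added in the flow.py hunk above, so this is introspection rather than public API:

```python
from crewai.flow import Flow, start

class AskOnceFlow(Flow):
    @start()
    def gather(self) -> str:
        return self.ask("Topic?", timeout=60) or "no answer"

flow = AskOnceFlow()
flow.kickoff()

# One entry per self.ask() call, with the fields documented above.
for entry in flow._input_history:
    print(entry["method_name"], entry["message"], entry["response"])
```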
class FlowExecutionData(TypedDict):
"""Flow execution data.

View File

@@ -66,7 +66,9 @@ def mock_crew():
def mock_get_crews(mock_crew):
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
) as mock_get_crew:
) as mock_get_crew, mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[]
):
yield mock_get_crew
@@ -193,6 +195,79 @@ def test_reset_memory_from_many_crews(mock_get_crews, runner):
assert call_count == 2, "reset_memories should have been called twice"
@pytest.fixture
def mock_flow():
_mock = mock.Mock()
_mock.name = "TestFlow"
_mock.memory = mock.Mock()
_mock.memory.reset = mock.Mock()
return _mock
@pytest.fixture
def mock_get_flows(mock_flow):
with mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
) as mock_get_flow, mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[]
):
yield mock_get_flow
def test_reset_flow_memory(mock_get_flows, mock_flow, runner):
result = runner.invoke(reset_memories, ["-m"])
mock_flow.memory.reset.assert_called_once()
assert "[Flow (TestFlow)] Memory has been reset." in result.output
def test_reset_flow_all_memories(mock_get_flows, mock_flow, runner):
result = runner.invoke(reset_memories, ["-a"])
mock_flow.memory.reset.assert_called_once()
assert "[Flow (TestFlow)] Reset memories command has been completed." in result.output
def test_reset_flow_knowledge_no_effect(mock_get_flows, mock_flow, runner):
result = runner.invoke(reset_memories, ["--knowledge"])
mock_flow.memory.reset.assert_not_called()
assert "[Flow (TestFlow)]" not in result.output
def test_reset_no_crew_or_flow_found(runner):
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[]
), mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[]
):
result = runner.invoke(reset_memories, ["-m"])
assert "No crew or flow found." in result.output
def test_reset_crew_and_flow_memory(mock_crew, mock_flow, runner):
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
), mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
):
result = runner.invoke(reset_memories, ["-m"])
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
mock_flow.memory.reset.assert_called_once()
assert f"[Crew ({mock_crew.name})] Memory has been reset." in result.output
assert "[Flow (TestFlow)] Memory has been reset." in result.output
def test_reset_flow_memory_none(runner):
mock_flow = mock.Mock()
mock_flow.name = "NoMemFlow"
mock_flow.memory = None
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[]
), mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
):
result = runner.invoke(reset_memories, ["-m"])
assert "[Flow (NoMemFlow)] Memory has been reset." in result.output
def test_reset_no_memory_flags(runner):
result = runner.invoke(
reset_memories,

View File

@@ -7,139 +7,215 @@ import pytest
from crewai.context import (
_platform_integration_token,
get_platform_integration_token,
platform_integration_context,
reset_platform_integration_token,
platform_context,
set_platform_integration_token,
)
@pytest.fixture
def clean_context():
"""Fixture to ensure clean context state for each test."""
_platform_integration_token.set(None)
yield
_platform_integration_token.set(None)
class TestPlatformIntegrationToken:
def setup_method(self):
_platform_integration_token.set(None)
def teardown_method(self):
_platform_integration_token.set(None)
class TestContextVariableCore:
"""Test core context variable functionality (set/get/reset)."""
def test_set_and_get_token(self, clean_context):
"""Test basic token setting and retrieval."""
@patch.dict(os.environ, {}, clear=True)
def test_set_platform_integration_token(self):
test_token = "test-token-123"
assert get_platform_integration_token() is None
context_token = set_platform_integration_token(test_token)
set_platform_integration_token(test_token)
assert get_platform_integration_token() == test_token
assert context_token is not None
def test_reset_token_restores_previous_state(self, clean_context):
"""Test that reset properly restores previous context state."""
token1 = "token-1"
token2 = "token-2"
def test_get_platform_integration_token_from_context_var(self):
test_token = "context-var-token"
context_token1 = set_platform_integration_token(token1)
assert get_platform_integration_token() == token1
_platform_integration_token.set(test_token)
context_token2 = set_platform_integration_token(token2)
assert get_platform_integration_token() == token2
assert get_platform_integration_token() == test_token
reset_platform_integration_token(context_token2)
assert get_platform_integration_token() == token1
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-token-456"})
def test_get_platform_integration_token_from_env_var(self):
assert _platform_integration_token.get() is None
reset_platform_integration_token(context_token1)
assert get_platform_integration_token() is None
def test_nested_token_management(self, clean_context):
"""Test proper token management with deeply nested contexts."""
tokens = ["token-1", "token-2", "token-3"]
context_tokens = []
for token in tokens:
context_tokens.append(set_platform_integration_token(token))
assert get_platform_integration_token() == token
for i in range(len(tokens) - 1, 0, -1):
reset_platform_integration_token(context_tokens[i])
assert get_platform_integration_token() == tokens[i - 1]
reset_platform_integration_token(context_tokens[0])
assert get_platform_integration_token() is None
assert get_platform_integration_token() == "env-token-456"
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-token"})
def test_context_module_ignores_environment_variables(self, clean_context):
"""Test that context module only returns context values, not env vars."""
# Context module should not read environment variables
assert get_platform_integration_token() is None
def test_context_var_takes_precedence_over_env_var(self):
context_token = "context-token"
# Only context variable should be returned
set_platform_integration_token("context-token")
assert get_platform_integration_token() == "context-token"
set_platform_integration_token(context_token)
assert get_platform_integration_token() == context_token
class TestPlatformIntegrationContext:
"""Test platform integration context manager behavior."""
def test_basic_context_manager_usage(self, clean_context):
"""Test basic context manager functionality."""
test_token = "context-token"
@patch.dict(os.environ, {}, clear=True)
def test_get_platform_integration_token_returns_none_when_not_set(self):
assert _platform_integration_token.get() is None
assert get_platform_integration_token() is None
with platform_integration_context(test_token):
@patch.dict(os.environ, {}, clear=True)
def test_platform_context_manager_basic_usage(self):
test_token = "context-manager-token"
assert get_platform_integration_token() is None
with platform_context(test_token):
assert get_platform_integration_token() == test_token
assert get_platform_integration_token() is None
@pytest.mark.parametrize("falsy_value", [None, "", False, 0])
def test_falsy_values_return_nullcontext(self, clean_context, falsy_value):
"""Test that falsy values return nullcontext (no-op)."""
# Set initial token to verify nullcontext doesn't affect it
initial_token = "initial-token"
initial_context_token = set_platform_integration_token(initial_token)
@patch.dict(os.environ, {}, clear=True)
def test_platform_context_manager_nested_contexts(self):
"""Test nested platform_context context managers."""
outer_token = "outer-token"
inner_token = "inner-token"
try:
with platform_integration_context(falsy_value):
# Should preserve existing context (nullcontext behavior)
assert get_platform_integration_token() == initial_token
# Should still have initial token after nullcontext
assert get_platform_integration_token() == initial_token
finally:
reset_platform_integration_token(initial_context_token)
@pytest.mark.parametrize("truthy_value", ["token", "123", " ", "0"])
def test_truthy_values_create_context(self, clean_context, truthy_value):
"""Test that truthy values create proper context."""
with platform_integration_context(truthy_value):
assert get_platform_integration_token() == truthy_value
# Should be cleaned up
assert get_platform_integration_token() is None
def test_context_preserves_existing_token(self, clean_context):
"""Test that context manager preserves existing token when exiting."""
existing_token = "existing-token"
with platform_context(outer_token):
assert get_platform_integration_token() == outer_token
with platform_context(inner_token):
assert get_platform_integration_token() == inner_token
assert get_platform_integration_token() == outer_token
assert get_platform_integration_token() is None
def test_platform_context_manager_preserves_existing_token(self):
"""Test that platform_context preserves existing token when exiting."""
initial_token = "initial-token"
context_token = "context-token"
existing_context_token = set_platform_integration_token(existing_token)
set_platform_integration_token(initial_token)
assert get_platform_integration_token() == initial_token
try:
with platform_integration_context(context_token):
with platform_context(context_token):
assert get_platform_integration_token() == context_token
assert get_platform_integration_token() == initial_token
def test_platform_context_manager_exception_handling(self):
"""Test that platform_context properly resets token even when exception occurs."""
initial_token = "initial-token"
context_token = "context-token"
set_platform_integration_token(initial_token)
with pytest.raises(ValueError):
with platform_context(context_token):
assert get_platform_integration_token() == context_token
raise ValueError("Test exception")
assert get_platform_integration_token() == existing_token
finally:
reset_platform_integration_token(existing_context_token)
assert get_platform_integration_token() == initial_token
def test_context_manager_return_type(self, clean_context):
"""Test that context manager returns proper types for both cases."""
# Both should be usable as context managers
valid_ctx = platform_integration_context("token")
none_ctx = platform_integration_context(None)
@patch.dict(os.environ, {}, clear=True)
def test_platform_context_manager_with_none_initial_state(self):
"""Test platform_context when initial state is None."""
context_token = "context-token"
assert hasattr(valid_ctx, '__enter__')
assert hasattr(valid_ctx, '__exit__')
assert hasattr(none_ctx, '__enter__')
assert hasattr(none_ctx, '__exit__')
assert get_platform_integration_token() is None
with pytest.raises(RuntimeError):
with platform_context(context_token):
assert get_platform_integration_token() == context_token
raise RuntimeError("Test exception")
assert get_platform_integration_token() is None
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-backup"})
def test_platform_context_with_env_fallback(self):
"""Test platform_context interaction with environment variable fallback."""
context_token = "context-token"
assert get_platform_integration_token() == "env-backup"
with platform_context(context_token):
assert get_platform_integration_token() == context_token
assert get_platform_integration_token() == "env-backup"
@patch.dict(os.environ, {}, clear=True)
def test_multiple_sequential_context_managers(self):
"""Test multiple sequential uses of platform_context."""
token1 = "token-1"
token2 = "token-2"
token3 = "token-3"
with platform_context(token1):
assert get_platform_integration_token() == token1
assert get_platform_integration_token() is None
with platform_context(token2):
assert get_platform_integration_token() == token2
assert get_platform_integration_token() is None
with platform_context(token3):
assert get_platform_integration_token() == token3
assert get_platform_integration_token() is None
def test_empty_string_token(self):
empty_token = ""
set_platform_integration_token(empty_token)
assert get_platform_integration_token() == ""
with platform_context(empty_token):
assert get_platform_integration_token() == ""
def test_special_characters_in_token(self):
special_token = "token-with-!@#$%^&*()_+-={}[]|\\:;\"'<>?,./"
set_platform_integration_token(special_token)
assert get_platform_integration_token() == special_token
with platform_context(special_token):
assert get_platform_integration_token() == special_token
def test_very_long_token(self):
long_token = "a" * 10000
set_platform_integration_token(long_token)
assert get_platform_integration_token() == long_token
with platform_context(long_token):
assert get_platform_integration_token() == long_token
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": ""})
def test_empty_env_var(self):
assert _platform_integration_token.get() is None
assert get_platform_integration_token() == ""
@patch("crewai.context.os.getenv")
def test_env_var_access_error_handling(self, mock_getenv):
mock_getenv.side_effect = OSError("Environment access error")
with pytest.raises(OSError):
get_platform_integration_token()
@patch.dict(os.environ, {}, clear=True)
def test_context_var_isolation_between_tests(self):
"""Test that context variable changes don't leak between test methods."""
test_token = "isolation-test-token"
assert get_platform_integration_token() is None
set_platform_integration_token(test_token)
assert get_platform_integration_token() == test_token
def test_context_manager_return_value(self):
"""Test that platform_context can be used in with statement with return value."""
test_token = "return-value-token"
with platform_context(test_token):
assert get_platform_integration_token() == test_token
with platform_context(test_token) as ctx:
assert ctx is None
assert get_platform_integration_token() == test_token

File diff suppressed because it is too large

View File

@@ -14,7 +14,7 @@ from unittest.mock import MagicMock, patch
import pytest
from pydantic import BaseModel
from crewai.flow import Flow, HumanFeedbackResult, human_feedback, listen, start
from crewai.flow import Flow, HumanFeedbackResult, human_feedback, listen, or_, start
from crewai.flow.flow import FlowState
@@ -271,6 +271,182 @@ class TestMultiStepFlows:
assert len(flow.human_feedback_history) == 1
assert flow.human_feedback_history[0].outcome == "rejected"
def test_hitl_self_loop_routes_back_to_same_method(self):
"""Test that a HITL router can loop back to itself via its own emit outcome.
Pattern: review_work listens to or_("do_work", "review") and emits
["review", "approved"]. When the human rejects (outcome="review"),
the method should re-execute. When approved, the flow should continue
to the approve_work listener.
"""
execution_order: list[str] = []
class SelfLoopFlow(Flow):
@start()
def initial_func(self):
execution_order.append("initial_func")
return "initial"
@listen(initial_func)
def do_work(self):
execution_order.append("do_work")
return "work output"
@human_feedback(
message="Do you approve this content?",
emit=["review", "approved"],
llm="gpt-4o-mini",
default_outcome="approved",
)
@listen(or_("do_work", "review"))
def review_work(self):
execution_order.append("review_work")
return "content for review"
@listen("approved")
def approve_work(self):
execution_order.append("approve_work")
return "published"
flow = SelfLoopFlow()
# First call: human rejects (outcome="review") -> self-loop
# Second call: human approves (outcome="approved") -> continue
with (
patch.object(
flow,
"_request_human_feedback",
side_effect=["needs changes", "looks good"],
),
patch.object(
flow,
"_collapse_to_outcome",
side_effect=["review", "approved"],
),
):
result = flow.kickoff()
assert execution_order == [
"initial_func",
"do_work",
"review_work", # first review -> rejected (review)
"review_work", # second review -> approved
"approve_work",
]
assert result == "published"
assert len(flow.human_feedback_history) == 2
assert flow.human_feedback_history[0].outcome == "review"
assert flow.human_feedback_history[1].outcome == "approved"
def test_hitl_self_loop_multiple_rejections(self):
"""Test that a HITL router can loop back multiple times before approving.
Verifies the self-loop works for more than one rejection cycle.
"""
execution_order: list[str] = []
class MultiRejectFlow(Flow):
@start()
def generate(self):
execution_order.append("generate")
return "draft"
@human_feedback(
message="Review this content:",
emit=["revise", "approved"],
llm="gpt-4o-mini",
default_outcome="approved",
)
@listen(or_("generate", "revise"))
def review(self):
execution_order.append("review")
return "content v" + str(execution_order.count("review"))
@listen("approved")
def publish(self):
execution_order.append("publish")
return "published"
flow = MultiRejectFlow()
# Three rejections, then approval
with (
patch.object(
flow,
"_request_human_feedback",
side_effect=["bad", "still bad", "not yet", "great"],
),
patch.object(
flow,
"_collapse_to_outcome",
side_effect=["revise", "revise", "revise", "approved"],
),
):
result = flow.kickoff()
assert execution_order == [
"generate",
"review", # 1st review -> revise
"review", # 2nd review -> revise
"review", # 3rd review -> revise
"review", # 4th review -> approved
"publish",
]
assert result == "published"
assert len(flow.human_feedback_history) == 4
assert [r.outcome for r in flow.human_feedback_history] == [
"revise", "revise", "revise", "approved"
]
def test_hitl_self_loop_immediate_approval(self):
"""Test that a HITL self-loop flow works when approved on the first try.
No looping occurs -- the flow should proceed straight through.
"""
execution_order: list[str] = []
class ImmediateApprovalFlow(Flow):
@start()
def generate(self):
execution_order.append("generate")
return "perfect draft"
@human_feedback(
message="Review:",
emit=["revise", "approved"],
llm="gpt-4o-mini",
)
@listen(or_("generate", "revise"))
def review(self):
execution_order.append("review")
return "content"
@listen("approved")
def publish(self):
execution_order.append("publish")
return "published"
flow = ImmediateApprovalFlow()
with (
patch.object(
flow,
"_request_human_feedback",
return_value="perfect",
),
patch.object(
flow,
"_collapse_to_outcome",
return_value="approved",
),
):
result = flow.kickoff()
assert execution_order == ["generate", "review", "publish"]
assert result == "published"
assert len(flow.human_feedback_history) == 1
assert flow.human_feedback_history[0].outcome == "approved"
def test_router_and_non_router_listeners_for_same_outcome(self):
"""Test that both router and non-router listeners fire for the same outcome."""
execution_order: list[str] = []