mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-02-18 03:48:14 +00:00
Compare commits
2 Commits
lg-isolate
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
84d57c7a24 | ||
|
|
4aedd58829 |
@@ -1,7 +1,5 @@
|
||||
import os
|
||||
|
||||
from crewai.context import get_platform_integration_token as _get_context_token
|
||||
|
||||
|
||||
def get_platform_api_base_url() -> str:
|
||||
"""Get the platform API base URL from environment or use default."""
|
||||
@@ -10,16 +8,10 @@ def get_platform_api_base_url() -> str:
|
||||
|
||||
|
||||
def get_platform_integration_token() -> str:
|
||||
"""Get the platform integration token from the context.
|
||||
Fallback to the environment variable if no token has been set in the context.
|
||||
|
||||
Raises:
|
||||
ValueError: If no token has been set in the context.
|
||||
"""
|
||||
token = _get_context_token() or os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")
|
||||
"""Get the platform API base URL from environment or use default."""
|
||||
token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN") or ""
|
||||
if not token:
|
||||
raise ValueError(
|
||||
"No platform integration token found. "
|
||||
"Set it via platform_integration_context() or set_platform_integration_token()."
|
||||
"No platform integration token found, please set the CREWAI_PLATFORM_INTEGRATION_TOKEN environment variable"
|
||||
)
|
||||
return token
|
||||
return token # TODO: Use context manager to get token
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
"""Tests for platform tools misc functionality."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from crewai.context import platform_integration_context
|
||||
from crewai_tools.tools.crewai_platform_tools.misc import (
|
||||
get_platform_integration_token,
|
||||
)
|
||||
|
||||
|
||||
|
||||
class TestTokenRetrievalWithFallback:
|
||||
"""Test token retrieval logic with environment fallback."""
|
||||
|
||||
def test_context_token_takes_precedence(self, clean_context):
|
||||
"""Test that context token takes precedence over environment variable."""
|
||||
context_token = "context-token"
|
||||
env_token = "env-token"
|
||||
|
||||
with patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": env_token}):
|
||||
with platform_integration_context(context_token):
|
||||
token = get_platform_integration_token()
|
||||
assert token == context_token
|
||||
|
||||
def test_environment_fallback_when_no_context(self, clean_context):
|
||||
"""Test fallback to environment variable when no context token."""
|
||||
env_token = "env-fallback-token"
|
||||
|
||||
with patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": env_token}):
|
||||
token = get_platform_integration_token()
|
||||
assert token == env_token
|
||||
|
||||
@pytest.mark.parametrize("empty_value", ["", None])
|
||||
def test_missing_token_raises_error(self, clean_context, empty_value):
|
||||
"""Test that missing tokens raise appropriate errors."""
|
||||
env_dict = {"CREWAI_PLATFORM_INTEGRATION_TOKEN": empty_value} if empty_value is not None else {}
|
||||
|
||||
with patch.dict(os.environ, env_dict, clear=True):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_platform_integration_token()
|
||||
|
||||
assert "No platform integration token found" in str(exc_info.value)
|
||||
assert "platform_integration_context()" in str(exc_info.value)
|
||||
@@ -20117,6 +20117,18 @@
|
||||
"humanized_name": "Web Automation Tool",
|
||||
"init_params_schema": {
|
||||
"$defs": {
|
||||
"AvailableModel": {
|
||||
"enum": [
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"claude-3-5-sonnet-latest",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"computer-use-preview",
|
||||
"gemini-2.0-flash"
|
||||
],
|
||||
"title": "AvailableModel",
|
||||
"type": "string"
|
||||
},
|
||||
"EnvVar": {
|
||||
"properties": {
|
||||
"default": {
|
||||
@@ -20194,6 +20206,17 @@
|
||||
"default": null,
|
||||
"title": "Model Api Key"
|
||||
},
|
||||
"model_name": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/$defs/AvailableModel"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"default": "claude-3-7-sonnet-latest"
|
||||
},
|
||||
"project_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
|
||||
@@ -2,7 +2,30 @@ import subprocess
|
||||
|
||||
import click
|
||||
|
||||
from crewai.cli.utils import get_crews
|
||||
from crewai.cli.utils import get_crews, get_flows
|
||||
from crewai.flow import Flow
|
||||
|
||||
|
||||
def _reset_flow_memory(flow: Flow) -> None:
|
||||
"""Reset memory for a single flow instance.
|
||||
|
||||
Handles Memory, MemoryScope (both have .reset()), and MemorySlice
|
||||
(delegates to the underlying ._memory). Silently succeeds when the
|
||||
storage directory does not exist yet (nothing to reset).
|
||||
|
||||
Args:
|
||||
flow: The flow instance whose memory should be reset.
|
||||
"""
|
||||
mem = flow.memory
|
||||
if mem is None:
|
||||
return
|
||||
try:
|
||||
if hasattr(mem, "reset"):
|
||||
mem.reset()
|
||||
elif hasattr(mem, "_memory") and hasattr(mem._memory, "reset"):
|
||||
mem._memory.reset()
|
||||
except (FileNotFoundError, OSError):
|
||||
pass
|
||||
|
||||
|
||||
def reset_memories_command(
|
||||
@@ -12,7 +35,7 @@ def reset_memories_command(
|
||||
kickoff_outputs: bool,
|
||||
all: bool,
|
||||
) -> None:
|
||||
"""Reset the crew memories.
|
||||
"""Reset the crew and flow memories.
|
||||
|
||||
Args:
|
||||
memory: Whether to reset the unified memory.
|
||||
@@ -29,8 +52,11 @@ def reset_memories_command(
|
||||
return
|
||||
|
||||
crews = get_crews()
|
||||
if not crews:
|
||||
raise ValueError("No crew found.")
|
||||
flows = get_flows()
|
||||
|
||||
if not crews and not flows:
|
||||
raise ValueError("No crew or flow found.")
|
||||
|
||||
for crew in crews:
|
||||
if all:
|
||||
crew.reset_memories(command_type="all")
|
||||
@@ -59,6 +85,20 @@ def reset_memories_command(
|
||||
f"[Crew ({crew.name if crew.name else crew.id})] Agents knowledge has been reset."
|
||||
)
|
||||
|
||||
for flow in flows:
|
||||
flow_name = flow.name or flow.__class__.__name__
|
||||
if all:
|
||||
_reset_flow_memory(flow)
|
||||
click.echo(
|
||||
f"[Flow ({flow_name})] Reset memories command has been completed."
|
||||
)
|
||||
continue
|
||||
if memory:
|
||||
_reset_flow_memory(flow)
|
||||
click.echo(
|
||||
f"[Flow ({flow_name})] Memory has been reset."
|
||||
)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
|
||||
click.echo(e.output, err=True)
|
||||
|
||||
@@ -386,6 +386,109 @@ def fetch_crews(module_attr: Any) -> list[Crew]:
|
||||
return crew_instances
|
||||
|
||||
|
||||
def get_flow_instance(module_attr: Any) -> Flow | None:
|
||||
"""Check if a module attribute is a user-defined Flow subclass and return an instance.
|
||||
|
||||
Args:
|
||||
module_attr: An attribute from a loaded module.
|
||||
|
||||
Returns:
|
||||
A Flow instance if the attribute is a valid user-defined Flow subclass,
|
||||
None otherwise.
|
||||
"""
|
||||
if (
|
||||
isinstance(module_attr, type)
|
||||
and issubclass(module_attr, Flow)
|
||||
and module_attr is not Flow
|
||||
):
|
||||
try:
|
||||
return module_attr()
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
_SKIP_DIRS = frozenset(
|
||||
{".venv", "venv", ".git", "__pycache__", "node_modules", ".tox", ".nox"}
|
||||
)
|
||||
|
||||
|
||||
def get_flows(flow_path: str = "main.py") -> list[Flow]:
|
||||
"""Get the flow instances from project files.
|
||||
|
||||
Walks the project directory looking for files matching ``flow_path``
|
||||
(default ``main.py``), loads each module, and extracts Flow subclass
|
||||
instances. Directories that are clearly not user source code (virtual
|
||||
environments, ``.git``, etc.) are pruned to avoid noisy import errors.
|
||||
|
||||
Args:
|
||||
flow_path: Filename to search for (default ``main.py``).
|
||||
|
||||
Returns:
|
||||
A list of discovered Flow instances.
|
||||
"""
|
||||
flow_instances: list[Flow] = []
|
||||
try:
|
||||
current_dir = os.getcwd()
|
||||
if current_dir not in sys.path:
|
||||
sys.path.insert(0, current_dir)
|
||||
|
||||
src_dir = os.path.join(current_dir, "src")
|
||||
if os.path.isdir(src_dir) and src_dir not in sys.path:
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
search_paths = [".", "src"] if os.path.isdir("src") else ["."]
|
||||
|
||||
for search_path in search_paths:
|
||||
for root, dirs, files in os.walk(search_path):
|
||||
dirs[:] = [
|
||||
d
|
||||
for d in dirs
|
||||
if d not in _SKIP_DIRS and not d.startswith(".")
|
||||
]
|
||||
if flow_path in files and "cli/templates" not in root:
|
||||
file_os_path = os.path.join(root, flow_path)
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"flow_module", file_os_path
|
||||
)
|
||||
if not spec or not spec.loader:
|
||||
continue
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[spec.name] = module
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
for attr_name in dir(module):
|
||||
module_attr = getattr(module, attr_name)
|
||||
try:
|
||||
if flow_instance := get_flow_instance(
|
||||
module_attr
|
||||
):
|
||||
flow_instances.append(flow_instance)
|
||||
except Exception: # noqa: S112
|
||||
continue
|
||||
|
||||
if flow_instances:
|
||||
break
|
||||
|
||||
except Exception: # noqa: S112
|
||||
continue
|
||||
|
||||
except (ImportError, AttributeError):
|
||||
continue
|
||||
|
||||
if flow_instances:
|
||||
break
|
||||
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
return flow_instances
|
||||
|
||||
|
||||
def is_valid_tool(obj: Any) -> bool:
|
||||
from crewai.tools.base_tool import Tool
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from contextlib import contextmanager
|
||||
import contextvars
|
||||
from typing import Any, ContextManager
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
|
||||
_platform_integration_token: contextvars.ContextVar[str | None] = (
|
||||
@@ -9,50 +10,40 @@ _platform_integration_token: contextvars.ContextVar[str | None] = (
|
||||
)
|
||||
|
||||
|
||||
def set_platform_integration_token(integration_token: str) -> contextvars.Token[str | None]:
|
||||
def set_platform_integration_token(integration_token: str) -> None:
|
||||
"""Set the platform integration token in the current context.
|
||||
|
||||
Args:
|
||||
integration_token: The integration token to set.
|
||||
"""
|
||||
return _platform_integration_token.set(integration_token)
|
||||
|
||||
|
||||
def reset_platform_integration_token(token: contextvars.Token[str | None]) -> None:
|
||||
"""Reset the platform integration token to its previous value."""
|
||||
_platform_integration_token.reset(token)
|
||||
_platform_integration_token.set(integration_token)
|
||||
|
||||
|
||||
def get_platform_integration_token() -> str | None:
|
||||
"""Get the platform integration token from the current context.
|
||||
"""Get the platform integration token from the current context or environment.
|
||||
|
||||
Returns:
|
||||
The integration token if set, otherwise None.
|
||||
"""
|
||||
return _platform_integration_token.get()
|
||||
token = _platform_integration_token.get()
|
||||
if token is None:
|
||||
token = os.getenv("CREWAI_PLATFORM_INTEGRATION_TOKEN")
|
||||
return token
|
||||
|
||||
|
||||
def platform_integration_context(integration_token: str | None) -> ContextManager[None]:
|
||||
@contextmanager
|
||||
def platform_context(integration_token: str) -> Generator[None, Any, None]:
|
||||
"""Context manager to temporarily set the platform integration token.
|
||||
|
||||
Args:
|
||||
integration_token: The integration token to set within the context.
|
||||
If None or falsy, returns nullcontext (no-op).
|
||||
|
||||
Returns:
|
||||
A context manager that either sets the token or does nothing.
|
||||
"""
|
||||
if not integration_token:
|
||||
return nullcontext()
|
||||
token = _platform_integration_token.set(integration_token)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_platform_integration_token.reset(token)
|
||||
|
||||
@contextmanager
|
||||
def _token_context() -> Generator[None, Any, None]:
|
||||
token = set_platform_integration_token(integration_token)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
reset_platform_integration_token(token)
|
||||
|
||||
return _token_context()
|
||||
|
||||
_current_task_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
||||
"current_task_id", default=None
|
||||
|
||||
@@ -120,6 +120,52 @@ class FlowPlotEvent(FlowEvent):
|
||||
type: str = "flow_plot"
|
||||
|
||||
|
||||
class FlowInputRequestedEvent(FlowEvent):
|
||||
"""Event emitted when a flow requests user input via ``Flow.ask()``.
|
||||
|
||||
This event is emitted before the flow suspends waiting for user input,
|
||||
allowing UI frameworks and observability tools to know when a flow
|
||||
needs user interaction.
|
||||
|
||||
Attributes:
|
||||
flow_name: Name of the flow requesting input.
|
||||
method_name: Name of the flow method that called ``ask()``.
|
||||
message: The question or prompt being shown to the user.
|
||||
metadata: Optional metadata sent with the question (e.g., user ID,
|
||||
channel, session context).
|
||||
"""
|
||||
|
||||
method_name: str
|
||||
message: str
|
||||
metadata: dict[str, Any] | None = None
|
||||
type: str = "flow_input_requested"
|
||||
|
||||
|
||||
class FlowInputReceivedEvent(FlowEvent):
|
||||
"""Event emitted when user input is received after ``Flow.ask()``.
|
||||
|
||||
This event is emitted after the user provides input (or the request
|
||||
times out), allowing UI frameworks and observability tools to track
|
||||
input collection.
|
||||
|
||||
Attributes:
|
||||
flow_name: Name of the flow that received input.
|
||||
method_name: Name of the flow method that called ``ask()``.
|
||||
message: The original question or prompt.
|
||||
response: The user's response, or None if timed out / unavailable.
|
||||
metadata: Optional metadata sent with the question.
|
||||
response_metadata: Optional metadata from the provider about the
|
||||
response (e.g., who responded, thread ID, timestamps).
|
||||
"""
|
||||
|
||||
method_name: str
|
||||
message: str
|
||||
response: str | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
response_metadata: dict[str, Any] | None = None
|
||||
type: str = "flow_input_received"
|
||||
|
||||
|
||||
class HumanFeedbackRequestedEvent(FlowEvent):
|
||||
"""Event emitted when human feedback is requested.
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from crewai.flow.async_feedback import (
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
from crewai.flow.flow_config import flow_config
|
||||
from crewai.flow.human_feedback import HumanFeedbackResult, human_feedback
|
||||
from crewai.flow.input_provider import InputProvider, InputResponse
|
||||
from crewai.flow.persistence import persist
|
||||
from crewai.flow.visualization import (
|
||||
FlowStructure,
|
||||
@@ -22,6 +23,8 @@ __all__ = [
|
||||
"HumanFeedbackPending",
|
||||
"HumanFeedbackProvider",
|
||||
"HumanFeedbackResult",
|
||||
"InputProvider",
|
||||
"InputResponse",
|
||||
"PendingFeedbackContext",
|
||||
"and_",
|
||||
"build_flow_structure",
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Default provider implementations for human feedback.
|
||||
"""Default provider implementations for human feedback and user input.
|
||||
|
||||
This module provides the ConsoleProvider, which is the default synchronous
|
||||
provider that collects feedback via console input.
|
||||
provider that collects both feedback (for ``@human_feedback``) and user input
|
||||
(for ``Flow.ask()``) via console.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -16,20 +17,23 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ConsoleProvider:
|
||||
"""Default synchronous console-based feedback provider.
|
||||
"""Default synchronous console-based provider for feedback and input.
|
||||
|
||||
This provider blocks execution and waits for console input from the user.
|
||||
It displays the method output with formatting and prompts for feedback.
|
||||
It serves two purposes:
|
||||
|
||||
- **Feedback** (``request_feedback``): Used by ``@human_feedback`` to
|
||||
display method output and collect review feedback.
|
||||
- **Input** (``request_input``): Used by ``Flow.ask()`` to prompt the
|
||||
user with a question and collect a response.
|
||||
|
||||
This is the default provider used when no custom provider is specified
|
||||
in the @human_feedback decorator.
|
||||
in the ``@human_feedback`` decorator or on the Flow's ``input_provider``.
|
||||
|
||||
Example:
|
||||
Example (feedback):
|
||||
```python
|
||||
from crewai.flow.async_feedback import ConsoleProvider
|
||||
|
||||
|
||||
# Explicitly use console provider
|
||||
@human_feedback(
|
||||
message="Review this:",
|
||||
provider=ConsoleProvider(),
|
||||
@@ -37,9 +41,20 @@ class ConsoleProvider:
|
||||
def my_method(self):
|
||||
return "Content to review"
|
||||
```
|
||||
|
||||
Example (input):
|
||||
```python
|
||||
from crewai.flow import Flow, start
|
||||
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
def gather_info(self):
|
||||
topic = self.ask("What topic should we research?")
|
||||
return topic
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, verbose: bool = True):
|
||||
def __init__(self, verbose: bool = True) -> None:
|
||||
"""Initialize the console provider.
|
||||
|
||||
Args:
|
||||
@@ -124,3 +139,55 @@ class ConsoleProvider:
|
||||
finally:
|
||||
# Resume live updates
|
||||
formatter.resume_live_updates()
|
||||
|
||||
def request_input(
|
||||
self,
|
||||
message: str,
|
||||
flow: Flow[Any],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str | None:
|
||||
"""Request user input via console (blocking).
|
||||
|
||||
Displays the prompt message with formatting and waits for the user
|
||||
to type their response. Used by ``Flow.ask()``.
|
||||
|
||||
Unlike ``request_feedback``, this method does not display an
|
||||
"OUTPUT FOR REVIEW" panel or emit feedback-specific events (those
|
||||
are handled by ``ask()`` itself).
|
||||
|
||||
Args:
|
||||
message: The question or prompt to display to the user.
|
||||
flow: The Flow instance requesting input.
|
||||
metadata: Optional metadata from the caller. Ignored by the
|
||||
console provider (console has no concept of user routing).
|
||||
|
||||
Returns:
|
||||
The user's input as a stripped string. Returns empty string
|
||||
if user presses Enter without input. Never returns None
|
||||
(console input is always available).
|
||||
"""
|
||||
from crewai.events.event_listener import event_listener
|
||||
|
||||
# Pause live updates during human input
|
||||
formatter = event_listener.formatter
|
||||
formatter.pause_live_updates()
|
||||
|
||||
try:
|
||||
console = formatter.console
|
||||
|
||||
if self.verbose:
|
||||
console.print()
|
||||
console.print(message, style="yellow")
|
||||
console.print()
|
||||
|
||||
response = input(">>> \n").strip()
|
||||
else:
|
||||
response = input(f"{message} ").strip()
|
||||
|
||||
# Add line break after input so formatter output starts clean
|
||||
console.print()
|
||||
|
||||
return response
|
||||
finally:
|
||||
# Resume live updates
|
||||
formatter.resume_live_updates()
|
||||
|
||||
@@ -77,7 +77,7 @@ from crewai.flow.flow_wrappers import (
|
||||
StartMethod,
|
||||
)
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.types import FlowExecutionData, FlowMethodName, PendingListenerKey
|
||||
from crewai.flow.types import FlowExecutionData, FlowMethodName, InputHistoryEntry, PendingListenerKey
|
||||
from crewai.flow.utils import (
|
||||
_extract_all_methods,
|
||||
_extract_all_methods_recursive,
|
||||
@@ -738,6 +738,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
tracing: bool | None = None
|
||||
stream: bool = False
|
||||
memory: Any = None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
|
||||
input_provider: Any = None # InputProvider | None; per-flow override for self.ask()
|
||||
|
||||
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
|
||||
class _FlowGeneric(cls): # type: ignore
|
||||
@@ -784,6 +785,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._pending_feedback_context: PendingFeedbackContext | None = None
|
||||
self.suppress_flow_events: bool = suppress_flow_events
|
||||
|
||||
# User input history (for self.ask())
|
||||
self._input_history: list[InputHistoryEntry] = []
|
||||
|
||||
# Initialize state with initial values
|
||||
self._state = self._create_initial_state()
|
||||
self.tracing = tracing
|
||||
@@ -2119,15 +2123,24 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if future:
|
||||
self._event_futures.append(future)
|
||||
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
result = await method(*args, **kwargs)
|
||||
else:
|
||||
# Run sync methods in thread pool for isolation
|
||||
# This allows Agent.kickoff() to work synchronously inside Flow methods
|
||||
import contextvars
|
||||
# Set method name in context so ask() can read it without
|
||||
# stack inspection. Must happen before copy_context() so the
|
||||
# value propagates into the thread pool for sync methods.
|
||||
from crewai.flow.flow_context import current_flow_method_name
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
result = await asyncio.to_thread(ctx.run, method, *args, **kwargs)
|
||||
method_name_token = current_flow_method_name.set(method_name)
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
result = await method(*args, **kwargs)
|
||||
else:
|
||||
# Run sync methods in thread pool for isolation
|
||||
# This allows Agent.kickoff() to work synchronously inside Flow methods
|
||||
import contextvars
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
result = await asyncio.to_thread(ctx.run, method, *args, **kwargs)
|
||||
finally:
|
||||
current_flow_method_name.reset(method_name_token)
|
||||
|
||||
# Auto-await coroutines returned from sync methods (enables AgentExecutor pattern)
|
||||
if asyncio.iscoroutine(result):
|
||||
@@ -2582,6 +2595,201 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
logger.error(f"Error executing listener {listener_name}: {e}")
|
||||
raise
|
||||
|
||||
# ── User Input (self.ask) ────────────────────────────────────────
|
||||
|
||||
def _resolve_input_provider(self) -> Any:
|
||||
"""Resolve the input provider using the priority chain.
|
||||
|
||||
Resolution order:
|
||||
1. ``self.input_provider`` (per-flow override)
|
||||
2. ``flow_config.input_provider`` (global default)
|
||||
3. ``ConsoleInputProvider()`` (built-in fallback)
|
||||
|
||||
Returns:
|
||||
An object implementing the ``InputProvider`` protocol.
|
||||
"""
|
||||
from crewai.flow.async_feedback.providers import ConsoleProvider
|
||||
from crewai.flow.flow_config import flow_config
|
||||
|
||||
if self.input_provider is not None:
|
||||
return self.input_provider
|
||||
if flow_config.input_provider is not None:
|
||||
return flow_config.input_provider
|
||||
return ConsoleProvider()
|
||||
|
||||
def _checkpoint_state_for_ask(self) -> None:
|
||||
"""Auto-checkpoint flow state before waiting for user input.
|
||||
|
||||
If persistence is configured, saves the current state so that
|
||||
``self.state`` is recoverable even if the process crashes while
|
||||
waiting for input.
|
||||
|
||||
This is best-effort: if persistence is not configured, this is a no-op.
|
||||
"""
|
||||
if self._persistence is None:
|
||||
return
|
||||
try:
|
||||
state_data = (
|
||||
self._state
|
||||
if isinstance(self._state, dict)
|
||||
else self._state.model_dump()
|
||||
)
|
||||
self._persistence.save_state(
|
||||
flow_uuid=self.flow_id,
|
||||
method_name="_ask_checkpoint",
|
||||
state_data=state_data,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("Failed to checkpoint state before ask()", exc_info=True)
|
||||
|
||||
def ask(
|
||||
self,
|
||||
message: str,
|
||||
timeout: float | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str | None:
|
||||
"""Request input from the user during flow execution.
|
||||
|
||||
Blocks the current thread until the user provides input or the
|
||||
timeout expires. Works in both sync and async flow methods (the
|
||||
flow framework runs sync methods in a thread pool via
|
||||
``asyncio.to_thread``, so the event loop stays free).
|
||||
|
||||
Timeout ensures flows always terminate. When timeout expires,
|
||||
``None`` is returned, enabling the pattern::
|
||||
|
||||
while (msg := self.ask("You: ", timeout=300)) is not None:
|
||||
process(msg)
|
||||
|
||||
Before waiting for input, the current ``self.state`` is automatically
|
||||
checkpointed to persistence (if configured) for durability.
|
||||
|
||||
Args:
|
||||
message: The question or prompt to display to the user.
|
||||
timeout: Maximum seconds to wait for input. ``None`` means
|
||||
wait indefinitely. When timeout expires, returns ``None``.
|
||||
Note: timeout is best-effort for the provider call --
|
||||
``ask()`` returns ``None`` promptly, but the underlying
|
||||
``request_input()`` may continue running in a background
|
||||
thread until it completes naturally. Network providers
|
||||
should implement their own internal timeouts.
|
||||
metadata: Optional metadata to send to the input provider,
|
||||
such as user ID, channel, session context. The provider
|
||||
can use this to route the question to the right recipient.
|
||||
|
||||
Returns:
|
||||
The user's input as a string, or ``None`` on timeout, disconnect,
|
||||
or provider error. Empty string ``""`` means the user pressed
|
||||
Enter without typing (intentional empty input).
|
||||
|
||||
Example:
|
||||
```python
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
def gather_info(self):
|
||||
topic = self.ask(
|
||||
"What topic should we research?",
|
||||
metadata={"user_id": "u123", "channel": "#research"},
|
||||
)
|
||||
if topic is None:
|
||||
return "No input received"
|
||||
return topic
|
||||
```
|
||||
"""
|
||||
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
|
||||
from datetime import datetime
|
||||
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowInputReceivedEvent,
|
||||
FlowInputRequestedEvent,
|
||||
)
|
||||
from crewai.flow.flow_context import current_flow_method_name
|
||||
from crewai.flow.input_provider import InputResponse
|
||||
|
||||
method_name = current_flow_method_name.get("unknown")
|
||||
|
||||
# Emit input requested event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
FlowInputRequestedEvent(
|
||||
type="flow_input_requested",
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
method_name=method_name,
|
||||
message=message,
|
||||
metadata=metadata,
|
||||
),
|
||||
)
|
||||
|
||||
# Auto-checkpoint state before waiting
|
||||
self._checkpoint_state_for_ask()
|
||||
|
||||
provider = self._resolve_input_provider()
|
||||
raw: str | InputResponse | None = None
|
||||
|
||||
try:
|
||||
if timeout is not None:
|
||||
# Manual executor management to avoid shutdown(wait=True)
|
||||
# deadlock when the provider call outlives the timeout.
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
future = executor.submit(
|
||||
provider.request_input, message, self, metadata
|
||||
)
|
||||
try:
|
||||
raw = future.result(timeout=timeout)
|
||||
except FuturesTimeoutError:
|
||||
future.cancel()
|
||||
raw = None
|
||||
finally:
|
||||
# wait=False so we don't block if the provider is still
|
||||
# running (e.g. input() stuck waiting for user).
|
||||
# cancel_futures=True cleans up any queued-but-not-started tasks.
|
||||
executor.shutdown(wait=False, cancel_futures=True)
|
||||
else:
|
||||
raw = provider.request_input(message, self, metadata=metadata)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception:
|
||||
logger.debug("Input provider error in ask()", exc_info=True)
|
||||
raw = None
|
||||
|
||||
# Normalize provider response: str, InputResponse, or None
|
||||
response: str | None = None
|
||||
response_metadata: dict[str, Any] | None = None
|
||||
|
||||
if isinstance(raw, InputResponse):
|
||||
response = raw.text
|
||||
response_metadata = raw.metadata
|
||||
elif isinstance(raw, str):
|
||||
response = raw
|
||||
else:
|
||||
response = None
|
||||
|
||||
# Record in history
|
||||
self._input_history.append({
|
||||
"message": message,
|
||||
"response": response,
|
||||
"method_name": method_name,
|
||||
"timestamp": datetime.now(),
|
||||
"metadata": metadata,
|
||||
"response_metadata": response_metadata,
|
||||
})
|
||||
|
||||
# Emit input received event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
FlowInputReceivedEvent(
|
||||
type="flow_input_received",
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
method_name=method_name,
|
||||
message=message,
|
||||
response=response,
|
||||
metadata=metadata,
|
||||
response_metadata=response_metadata,
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _request_human_feedback(
|
||||
self,
|
||||
message: str,
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.async_feedback.types import HumanFeedbackProvider
|
||||
from crewai.flow.input_provider import InputProvider
|
||||
|
||||
|
||||
class FlowConfig:
|
||||
@@ -20,10 +21,15 @@ class FlowConfig:
|
||||
hitl_provider: The human-in-the-loop feedback provider.
|
||||
Defaults to None (uses console input).
|
||||
Can be overridden by deployments at startup.
|
||||
input_provider: The input provider used by ``Flow.ask()``.
|
||||
Defaults to None (uses ``ConsoleProvider``).
|
||||
Can be overridden by
|
||||
deployments at startup.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._hitl_provider: HumanFeedbackProvider | None = None
|
||||
self._input_provider: InputProvider | None = None
|
||||
|
||||
@property
|
||||
def hitl_provider(self) -> Any:
|
||||
@@ -35,6 +41,32 @@ class FlowConfig:
|
||||
"""Set the HITL provider."""
|
||||
self._hitl_provider = provider
|
||||
|
||||
@property
|
||||
def input_provider(self) -> Any:
|
||||
"""Get the configured input provider for ``Flow.ask()``.
|
||||
|
||||
Returns:
|
||||
The configured InputProvider instance, or None if not set
|
||||
(in which case ``ConsoleInputProvider`` is used as default).
|
||||
"""
|
||||
return self._input_provider
|
||||
|
||||
@input_provider.setter
|
||||
def input_provider(self, provider: Any) -> None:
|
||||
"""Set the input provider for ``Flow.ask()``.
|
||||
|
||||
Args:
|
||||
provider: An object implementing the ``InputProvider`` protocol.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from crewai.flow import flow_config
|
||||
|
||||
flow_config.input_provider = WebSocketInputProvider(...)
|
||||
```
|
||||
"""
|
||||
self._input_provider = provider
|
||||
|
||||
|
||||
# Singleton instance
|
||||
flow_config = FlowConfig()
|
||||
|
||||
@@ -14,3 +14,7 @@ current_flow_request_id: contextvars.ContextVar[str | None] = contextvars.Contex
|
||||
current_flow_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
||||
"flow_id", default=None
|
||||
)
|
||||
|
||||
current_flow_method_name: contextvars.ContextVar[str] = contextvars.ContextVar(
|
||||
"flow_method_name", default="unknown"
|
||||
)
|
||||
|
||||
151
lib/crewai/src/crewai/flow/input_provider.py
Normal file
151
lib/crewai/src/crewai/flow/input_provider.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Input provider protocol for Flow.ask().
|
||||
|
||||
This module provides the InputProvider protocol and InputResponse dataclass
|
||||
used by Flow.ask() to request input from users during flow execution.
|
||||
|
||||
The default implementation is ``ConsoleProvider`` (from
|
||||
``crewai.flow.async_feedback.providers``), which serves both feedback
|
||||
and input collection via console.
|
||||
|
||||
Example (default console input):
|
||||
```python
|
||||
from crewai.flow import Flow, start
|
||||
|
||||
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
def gather_info(self):
|
||||
topic = self.ask("What topic should we research?")
|
||||
return topic
|
||||
```
|
||||
|
||||
Example (custom provider with metadata):
|
||||
```python
|
||||
from crewai.flow import Flow, start
|
||||
from crewai.flow.input_provider import InputProvider, InputResponse
|
||||
|
||||
|
||||
class SlackProvider:
|
||||
def request_input(self, message, flow, metadata=None):
|
||||
channel = metadata.get("channel", "#general") if metadata else "#general"
|
||||
thread = self.post_question(channel, message)
|
||||
reply = self.wait_for_reply(thread)
|
||||
return InputResponse(
|
||||
text=reply.text,
|
||||
metadata={"responded_by": reply.user_id, "thread_id": thread.id},
|
||||
)
|
||||
|
||||
|
||||
class MyFlow(Flow):
|
||||
input_provider = SlackProvider()
|
||||
|
||||
@start()
|
||||
def gather_info(self):
|
||||
topic = self.ask("What topic?", metadata={"channel": "#research"})
|
||||
return topic
|
||||
```
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputResponse:
|
||||
"""Response from an InputProvider, optionally carrying metadata.
|
||||
|
||||
Simple providers can just return a string from ``request_input()``.
|
||||
Providers that need to send metadata back (e.g., who responded,
|
||||
thread ID, external timestamps) return an ``InputResponse`` instead.
|
||||
|
||||
``ask()`` normalizes both cases -- callers always get ``str | None``.
|
||||
The response metadata is stored in ``_input_history`` and emitted
|
||||
in ``FlowInputReceivedEvent``.
|
||||
|
||||
Attributes:
|
||||
text: The user's input text, or None if unavailable.
|
||||
metadata: Optional metadata from the provider about the response
|
||||
(e.g., who responded, thread ID, timestamps).
|
||||
|
||||
Example:
|
||||
```python
|
||||
class MyProvider:
|
||||
def request_input(self, message, flow, metadata=None):
|
||||
response = get_response_from_external_system(message)
|
||||
return InputResponse(
|
||||
text=response.text,
|
||||
metadata={"responded_by": response.user_id},
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
text: str | None
|
||||
metadata: dict[str, Any] | None = field(default=None)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class InputProvider(Protocol):
|
||||
"""Protocol for user input collection strategies.
|
||||
|
||||
Implement this protocol to create custom input providers that integrate
|
||||
with external systems like websockets, web UIs, Slack, or custom APIs.
|
||||
|
||||
The default provider is ``ConsoleProvider``, which blocks waiting for
|
||||
console input via Python's built-in ``input()`` function.
|
||||
|
||||
Providers are always synchronous. The flow framework runs sync methods
|
||||
in a thread pool (via ``asyncio.to_thread``), so ``ask()`` never blocks
|
||||
the event loop even inside async flow methods.
|
||||
|
||||
Providers can return either:
|
||||
- ``str | None`` for simple cases (no response metadata)
|
||||
- ``InputResponse`` when they need to send metadata back with the answer
|
||||
|
||||
Example (simple):
|
||||
```python
|
||||
class SimpleProvider:
|
||||
def request_input(self, message: str, flow: Flow) -> str | None:
|
||||
return input(message)
|
||||
```
|
||||
|
||||
Example (with metadata):
|
||||
```python
|
||||
class SlackProvider:
|
||||
def request_input(self, message, flow, metadata=None):
|
||||
channel = metadata.get("channel") if metadata else "#general"
|
||||
reply = self.post_and_wait(channel, message)
|
||||
return InputResponse(
|
||||
text=reply.text,
|
||||
metadata={"responded_by": reply.user_id},
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def request_input(
|
||||
self,
|
||||
message: str,
|
||||
flow: Flow[Any],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str | InputResponse | None:
|
||||
"""Request input from the user.
|
||||
|
||||
Args:
|
||||
message: The question or prompt to display to the user.
|
||||
flow: The Flow instance requesting input. Can be used to
|
||||
access flow state, name, or other context.
|
||||
metadata: Optional metadata from the caller, such as user ID,
|
||||
channel, session context, etc. Providers can use this to
|
||||
route the question to the right recipient.
|
||||
|
||||
Returns:
|
||||
The user's input as a string, an ``InputResponse`` with text
|
||||
and optional response metadata, or None if input is unavailable
|
||||
(e.g., user cancelled, connection dropped).
|
||||
"""
|
||||
...
|
||||
@@ -4,6 +4,7 @@ This module contains TypedDict definitions and type aliases used throughout
|
||||
the Flow system.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
@@ -101,6 +102,30 @@ class FlowData(TypedDict):
|
||||
flow_methods_attributes: list[FlowMethodData]
|
||||
|
||||
|
||||
class InputHistoryEntry(TypedDict):
|
||||
"""A single entry in the flow's input history from ``self.ask()``.
|
||||
|
||||
Each call to ``Flow.ask()`` appends one entry recording the question,
|
||||
the user's response, which method asked, and any metadata exchanged
|
||||
between the caller and the input provider.
|
||||
|
||||
Attributes:
|
||||
message: The question or prompt that was displayed to the user.
|
||||
response: The user's response, or None on timeout/error.
|
||||
method_name: The flow method that called ``ask()``.
|
||||
timestamp: When the input was received.
|
||||
metadata: Metadata sent with the question (caller to provider).
|
||||
response_metadata: Metadata received with the answer (provider to caller).
|
||||
"""
|
||||
|
||||
message: str
|
||||
response: str | None
|
||||
method_name: str
|
||||
timestamp: datetime
|
||||
metadata: dict[str, Any] | None
|
||||
response_metadata: dict[str, Any] | None
|
||||
|
||||
|
||||
class FlowExecutionData(TypedDict):
|
||||
"""Flow execution data.
|
||||
|
||||
|
||||
@@ -66,7 +66,9 @@ def mock_crew():
|
||||
def mock_get_crews(mock_crew):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
|
||||
) as mock_get_crew:
|
||||
) as mock_get_crew, mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[]
|
||||
):
|
||||
yield mock_get_crew
|
||||
|
||||
|
||||
@@ -193,6 +195,79 @@ def test_reset_memory_from_many_crews(mock_get_crews, runner):
|
||||
assert call_count == 2, "reset_memories should have been called twice"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flow():
|
||||
_mock = mock.Mock()
|
||||
_mock.name = "TestFlow"
|
||||
_mock.memory = mock.Mock()
|
||||
_mock.memory.reset = mock.Mock()
|
||||
return _mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_flows(mock_flow):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
|
||||
) as mock_get_flow, mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[]
|
||||
):
|
||||
yield mock_get_flow
|
||||
|
||||
|
||||
def test_reset_flow_memory(mock_get_flows, mock_flow, runner):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
mock_flow.memory.reset.assert_called_once()
|
||||
assert "[Flow (TestFlow)] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_flow_all_memories(mock_get_flows, mock_flow, runner):
|
||||
result = runner.invoke(reset_memories, ["-a"])
|
||||
mock_flow.memory.reset.assert_called_once()
|
||||
assert "[Flow (TestFlow)] Reset memories command has been completed." in result.output
|
||||
|
||||
|
||||
def test_reset_flow_knowledge_no_effect(mock_get_flows, mock_flow, runner):
|
||||
result = runner.invoke(reset_memories, ["--knowledge"])
|
||||
mock_flow.memory.reset.assert_not_called()
|
||||
assert "[Flow (TestFlow)]" not in result.output
|
||||
|
||||
|
||||
def test_reset_no_crew_or_flow_found(runner):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[]
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
assert "No crew or flow found." in result.output
|
||||
|
||||
|
||||
def test_reset_crew_and_flow_memory(mock_crew, mock_flow, runner):
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
|
||||
), mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
|
||||
mock_flow.memory.reset.assert_called_once()
|
||||
assert f"[Crew ({mock_crew.name})] Memory has been reset." in result.output
|
||||
assert "[Flow (TestFlow)] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_flow_memory_none(runner):
|
||||
mock_flow = mock.Mock()
|
||||
mock_flow.name = "NoMemFlow"
|
||||
mock_flow.memory = None
|
||||
with mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_crews", return_value=[]
|
||||
), mock.patch(
|
||||
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
|
||||
):
|
||||
result = runner.invoke(reset_memories, ["-m"])
|
||||
assert "[Flow (NoMemFlow)] Memory has been reset." in result.output
|
||||
|
||||
|
||||
def test_reset_no_memory_flags(runner):
|
||||
result = runner.invoke(
|
||||
reset_memories,
|
||||
|
||||
@@ -7,139 +7,215 @@ import pytest
|
||||
from crewai.context import (
|
||||
_platform_integration_token,
|
||||
get_platform_integration_token,
|
||||
platform_integration_context,
|
||||
reset_platform_integration_token,
|
||||
platform_context,
|
||||
set_platform_integration_token,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clean_context():
|
||||
"""Fixture to ensure clean context state for each test."""
|
||||
_platform_integration_token.set(None)
|
||||
yield
|
||||
_platform_integration_token.set(None)
|
||||
class TestPlatformIntegrationToken:
|
||||
def setup_method(self):
|
||||
_platform_integration_token.set(None)
|
||||
|
||||
def teardown_method(self):
|
||||
_platform_integration_token.set(None)
|
||||
|
||||
class TestContextVariableCore:
|
||||
"""Test core context variable functionality (set/get/reset)."""
|
||||
|
||||
def test_set_and_get_token(self, clean_context):
|
||||
"""Test basic token setting and retrieval."""
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_set_platform_integration_token(self):
|
||||
test_token = "test-token-123"
|
||||
|
||||
assert get_platform_integration_token() is None
|
||||
|
||||
context_token = set_platform_integration_token(test_token)
|
||||
set_platform_integration_token(test_token)
|
||||
|
||||
assert get_platform_integration_token() == test_token
|
||||
assert context_token is not None
|
||||
|
||||
def test_reset_token_restores_previous_state(self, clean_context):
|
||||
"""Test that reset properly restores previous context state."""
|
||||
token1 = "token-1"
|
||||
token2 = "token-2"
|
||||
def test_get_platform_integration_token_from_context_var(self):
|
||||
test_token = "context-var-token"
|
||||
|
||||
context_token1 = set_platform_integration_token(token1)
|
||||
assert get_platform_integration_token() == token1
|
||||
_platform_integration_token.set(test_token)
|
||||
|
||||
context_token2 = set_platform_integration_token(token2)
|
||||
assert get_platform_integration_token() == token2
|
||||
assert get_platform_integration_token() == test_token
|
||||
|
||||
reset_platform_integration_token(context_token2)
|
||||
assert get_platform_integration_token() == token1
|
||||
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-token-456"})
|
||||
def test_get_platform_integration_token_from_env_var(self):
|
||||
assert _platform_integration_token.get() is None
|
||||
|
||||
reset_platform_integration_token(context_token1)
|
||||
assert get_platform_integration_token() is None
|
||||
|
||||
def test_nested_token_management(self, clean_context):
|
||||
"""Test proper token management with deeply nested contexts."""
|
||||
tokens = ["token-1", "token-2", "token-3"]
|
||||
context_tokens = []
|
||||
|
||||
for token in tokens:
|
||||
context_tokens.append(set_platform_integration_token(token))
|
||||
assert get_platform_integration_token() == token
|
||||
|
||||
for i in range(len(tokens) - 1, 0, -1):
|
||||
reset_platform_integration_token(context_tokens[i])
|
||||
assert get_platform_integration_token() == tokens[i - 1]
|
||||
|
||||
reset_platform_integration_token(context_tokens[0])
|
||||
assert get_platform_integration_token() is None
|
||||
assert get_platform_integration_token() == "env-token-456"
|
||||
|
||||
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-token"})
|
||||
def test_context_module_ignores_environment_variables(self, clean_context):
|
||||
"""Test that context module only returns context values, not env vars."""
|
||||
# Context module should not read environment variables
|
||||
assert get_platform_integration_token() is None
|
||||
def test_context_var_takes_precedence_over_env_var(self):
|
||||
context_token = "context-token"
|
||||
|
||||
# Only context variable should be returned
|
||||
set_platform_integration_token("context-token")
|
||||
assert get_platform_integration_token() == "context-token"
|
||||
set_platform_integration_token(context_token)
|
||||
|
||||
assert get_platform_integration_token() == context_token
|
||||
|
||||
class TestPlatformIntegrationContext:
|
||||
"""Test platform integration context manager behavior."""
|
||||
|
||||
def test_basic_context_manager_usage(self, clean_context):
|
||||
"""Test basic context manager functionality."""
|
||||
test_token = "context-token"
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_get_platform_integration_token_returns_none_when_not_set(self):
|
||||
assert _platform_integration_token.get() is None
|
||||
|
||||
assert get_platform_integration_token() is None
|
||||
|
||||
with platform_integration_context(test_token):
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_platform_context_manager_basic_usage(self):
|
||||
test_token = "context-manager-token"
|
||||
|
||||
assert get_platform_integration_token() is None
|
||||
|
||||
with platform_context(test_token):
|
||||
assert get_platform_integration_token() == test_token
|
||||
|
||||
assert get_platform_integration_token() is None
|
||||
|
||||
@pytest.mark.parametrize("falsy_value", [None, "", False, 0])
|
||||
def test_falsy_values_return_nullcontext(self, clean_context, falsy_value):
|
||||
"""Test that falsy values return nullcontext (no-op)."""
|
||||
# Set initial token to verify nullcontext doesn't affect it
|
||||
initial_token = "initial-token"
|
||||
initial_context_token = set_platform_integration_token(initial_token)
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_platform_context_manager_nested_contexts(self):
|
||||
"""Test nested platform_context context managers."""
|
||||
outer_token = "outer-token"
|
||||
inner_token = "inner-token"
|
||||
|
||||
try:
|
||||
with platform_integration_context(falsy_value):
|
||||
# Should preserve existing context (nullcontext behavior)
|
||||
assert get_platform_integration_token() == initial_token
|
||||
|
||||
# Should still have initial token after nullcontext
|
||||
assert get_platform_integration_token() == initial_token
|
||||
finally:
|
||||
reset_platform_integration_token(initial_context_token)
|
||||
|
||||
@pytest.mark.parametrize("truthy_value", ["token", "123", " ", "0"])
|
||||
def test_truthy_values_create_context(self, clean_context, truthy_value):
|
||||
"""Test that truthy values create proper context."""
|
||||
with platform_integration_context(truthy_value):
|
||||
assert get_platform_integration_token() == truthy_value
|
||||
|
||||
# Should be cleaned up
|
||||
assert get_platform_integration_token() is None
|
||||
|
||||
def test_context_preserves_existing_token(self, clean_context):
|
||||
"""Test that context manager preserves existing token when exiting."""
|
||||
existing_token = "existing-token"
|
||||
with platform_context(outer_token):
|
||||
assert get_platform_integration_token() == outer_token
|
||||
|
||||
with platform_context(inner_token):
|
||||
assert get_platform_integration_token() == inner_token
|
||||
|
||||
assert get_platform_integration_token() == outer_token
|
||||
|
||||
assert get_platform_integration_token() is None
|
||||
|
||||
def test_platform_context_manager_preserves_existing_token(self):
|
||||
"""Test that platform_context preserves existing token when exiting."""
|
||||
initial_token = "initial-token"
|
||||
context_token = "context-token"
|
||||
|
||||
existing_context_token = set_platform_integration_token(existing_token)
|
||||
set_platform_integration_token(initial_token)
|
||||
assert get_platform_integration_token() == initial_token
|
||||
|
||||
try:
|
||||
with platform_integration_context(context_token):
|
||||
with platform_context(context_token):
|
||||
assert get_platform_integration_token() == context_token
|
||||
|
||||
assert get_platform_integration_token() == initial_token
|
||||
|
||||
def test_platform_context_manager_exception_handling(self):
|
||||
"""Test that platform_context properly resets token even when exception occurs."""
|
||||
initial_token = "initial-token"
|
||||
context_token = "context-token"
|
||||
|
||||
set_platform_integration_token(initial_token)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with platform_context(context_token):
|
||||
assert get_platform_integration_token() == context_token
|
||||
raise ValueError("Test exception")
|
||||
|
||||
assert get_platform_integration_token() == existing_token
|
||||
finally:
|
||||
reset_platform_integration_token(existing_context_token)
|
||||
assert get_platform_integration_token() == initial_token
|
||||
|
||||
def test_context_manager_return_type(self, clean_context):
|
||||
"""Test that context manager returns proper types for both cases."""
|
||||
# Both should be usable as context managers
|
||||
valid_ctx = platform_integration_context("token")
|
||||
none_ctx = platform_integration_context(None)
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_platform_context_manager_with_none_initial_state(self):
|
||||
"""Test platform_context when initial state is None."""
|
||||
context_token = "context-token"
|
||||
|
||||
assert hasattr(valid_ctx, '__enter__')
|
||||
assert hasattr(valid_ctx, '__exit__')
|
||||
assert hasattr(none_ctx, '__enter__')
|
||||
assert hasattr(none_ctx, '__exit__')
|
||||
assert get_platform_integration_token() is None
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
with platform_context(context_token):
|
||||
assert get_platform_integration_token() == context_token
|
||||
raise RuntimeError("Test exception")
|
||||
|
||||
assert get_platform_integration_token() is None
|
||||
|
||||
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "env-backup"})
|
||||
def test_platform_context_with_env_fallback(self):
|
||||
"""Test platform_context interaction with environment variable fallback."""
|
||||
context_token = "context-token"
|
||||
|
||||
assert get_platform_integration_token() == "env-backup"
|
||||
|
||||
with platform_context(context_token):
|
||||
assert get_platform_integration_token() == context_token
|
||||
|
||||
assert get_platform_integration_token() == "env-backup"
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_multiple_sequential_context_managers(self):
|
||||
"""Test multiple sequential uses of platform_context."""
|
||||
token1 = "token-1"
|
||||
token2 = "token-2"
|
||||
token3 = "token-3"
|
||||
|
||||
with platform_context(token1):
|
||||
assert get_platform_integration_token() == token1
|
||||
|
||||
assert get_platform_integration_token() is None
|
||||
|
||||
with platform_context(token2):
|
||||
assert get_platform_integration_token() == token2
|
||||
|
||||
assert get_platform_integration_token() is None
|
||||
|
||||
with platform_context(token3):
|
||||
assert get_platform_integration_token() == token3
|
||||
|
||||
assert get_platform_integration_token() is None
|
||||
|
||||
def test_empty_string_token(self):
|
||||
empty_token = ""
|
||||
|
||||
set_platform_integration_token(empty_token)
|
||||
assert get_platform_integration_token() == ""
|
||||
|
||||
with platform_context(empty_token):
|
||||
assert get_platform_integration_token() == ""
|
||||
|
||||
def test_special_characters_in_token(self):
|
||||
special_token = "token-with-!@#$%^&*()_+-={}[]|\\:;\"'<>?,./"
|
||||
|
||||
set_platform_integration_token(special_token)
|
||||
assert get_platform_integration_token() == special_token
|
||||
|
||||
with platform_context(special_token):
|
||||
assert get_platform_integration_token() == special_token
|
||||
|
||||
def test_very_long_token(self):
|
||||
long_token = "a" * 10000
|
||||
|
||||
set_platform_integration_token(long_token)
|
||||
assert get_platform_integration_token() == long_token
|
||||
|
||||
with platform_context(long_token):
|
||||
assert get_platform_integration_token() == long_token
|
||||
|
||||
@patch.dict(os.environ, {"CREWAI_PLATFORM_INTEGRATION_TOKEN": ""})
|
||||
def test_empty_env_var(self):
|
||||
assert _platform_integration_token.get() is None
|
||||
assert get_platform_integration_token() == ""
|
||||
|
||||
@patch("crewai.context.os.getenv")
|
||||
def test_env_var_access_error_handling(self, mock_getenv):
|
||||
mock_getenv.side_effect = OSError("Environment access error")
|
||||
|
||||
with pytest.raises(OSError):
|
||||
get_platform_integration_token()
|
||||
|
||||
@patch.dict(os.environ, {}, clear=True)
|
||||
def test_context_var_isolation_between_tests(self):
|
||||
"""Test that context variable changes don't leak between test methods."""
|
||||
test_token = "isolation-test-token"
|
||||
|
||||
assert get_platform_integration_token() is None
|
||||
|
||||
set_platform_integration_token(test_token)
|
||||
assert get_platform_integration_token() == test_token
|
||||
|
||||
def test_context_manager_return_value(self):
|
||||
"""Test that platform_context can be used in with statement with return value."""
|
||||
test_token = "return-value-token"
|
||||
|
||||
with platform_context(test_token):
|
||||
assert get_platform_integration_token() == test_token
|
||||
|
||||
with platform_context(test_token) as ctx:
|
||||
assert ctx is None
|
||||
assert get_platform_integration_token() == test_token
|
||||
|
||||
1152
lib/crewai/tests/test_flow_ask.py
Normal file
1152
lib/crewai/tests/test_flow_ask.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -14,7 +14,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow import Flow, HumanFeedbackResult, human_feedback, listen, start
|
||||
from crewai.flow import Flow, HumanFeedbackResult, human_feedback, listen, or_, start
|
||||
from crewai.flow.flow import FlowState
|
||||
|
||||
|
||||
@@ -271,6 +271,182 @@ class TestMultiStepFlows:
|
||||
assert len(flow.human_feedback_history) == 1
|
||||
assert flow.human_feedback_history[0].outcome == "rejected"
|
||||
|
||||
def test_hitl_self_loop_routes_back_to_same_method(self):
|
||||
"""Test that a HITL router can loop back to itself via its own emit outcome.
|
||||
|
||||
Pattern: review_work listens to or_("do_work", "review") and emits
|
||||
["review", "approved"]. When the human rejects (outcome="review"),
|
||||
the method should re-execute. When approved, the flow should continue
|
||||
to the approve_work listener.
|
||||
"""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class SelfLoopFlow(Flow):
|
||||
@start()
|
||||
def initial_func(self):
|
||||
execution_order.append("initial_func")
|
||||
return "initial"
|
||||
|
||||
@listen(initial_func)
|
||||
def do_work(self):
|
||||
execution_order.append("do_work")
|
||||
return "work output"
|
||||
|
||||
@human_feedback(
|
||||
message="Do you approve this content?",
|
||||
emit=["review", "approved"],
|
||||
llm="gpt-4o-mini",
|
||||
default_outcome="approved",
|
||||
)
|
||||
@listen(or_("do_work", "review"))
|
||||
def review_work(self):
|
||||
execution_order.append("review_work")
|
||||
return "content for review"
|
||||
|
||||
@listen("approved")
|
||||
def approve_work(self):
|
||||
execution_order.append("approve_work")
|
||||
return "published"
|
||||
|
||||
flow = SelfLoopFlow()
|
||||
|
||||
# First call: human rejects (outcome="review") -> self-loop
|
||||
# Second call: human approves (outcome="approved") -> continue
|
||||
with (
|
||||
patch.object(
|
||||
flow,
|
||||
"_request_human_feedback",
|
||||
side_effect=["needs changes", "looks good"],
|
||||
),
|
||||
patch.object(
|
||||
flow,
|
||||
"_collapse_to_outcome",
|
||||
side_effect=["review", "approved"],
|
||||
),
|
||||
):
|
||||
result = flow.kickoff()
|
||||
|
||||
assert execution_order == [
|
||||
"initial_func",
|
||||
"do_work",
|
||||
"review_work", # first review -> rejected (review)
|
||||
"review_work", # second review -> approved
|
||||
"approve_work",
|
||||
]
|
||||
assert result == "published"
|
||||
assert len(flow.human_feedback_history) == 2
|
||||
assert flow.human_feedback_history[0].outcome == "review"
|
||||
assert flow.human_feedback_history[1].outcome == "approved"
|
||||
|
||||
def test_hitl_self_loop_multiple_rejections(self):
|
||||
"""Test that a HITL router can loop back multiple times before approving.
|
||||
|
||||
Verifies the self-loop works for more than one rejection cycle.
|
||||
"""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class MultiRejectFlow(Flow):
|
||||
@start()
|
||||
def generate(self):
|
||||
execution_order.append("generate")
|
||||
return "draft"
|
||||
|
||||
@human_feedback(
|
||||
message="Review this content:",
|
||||
emit=["revise", "approved"],
|
||||
llm="gpt-4o-mini",
|
||||
default_outcome="approved",
|
||||
)
|
||||
@listen(or_("generate", "revise"))
|
||||
def review(self):
|
||||
execution_order.append("review")
|
||||
return "content v" + str(execution_order.count("review"))
|
||||
|
||||
@listen("approved")
|
||||
def publish(self):
|
||||
execution_order.append("publish")
|
||||
return "published"
|
||||
|
||||
flow = MultiRejectFlow()
|
||||
|
||||
# Three rejections, then approval
|
||||
with (
|
||||
patch.object(
|
||||
flow,
|
||||
"_request_human_feedback",
|
||||
side_effect=["bad", "still bad", "not yet", "great"],
|
||||
),
|
||||
patch.object(
|
||||
flow,
|
||||
"_collapse_to_outcome",
|
||||
side_effect=["revise", "revise", "revise", "approved"],
|
||||
),
|
||||
):
|
||||
result = flow.kickoff()
|
||||
|
||||
assert execution_order == [
|
||||
"generate",
|
||||
"review", # 1st review -> revise
|
||||
"review", # 2nd review -> revise
|
||||
"review", # 3rd review -> revise
|
||||
"review", # 4th review -> approved
|
||||
"publish",
|
||||
]
|
||||
assert result == "published"
|
||||
assert len(flow.human_feedback_history) == 4
|
||||
assert [r.outcome for r in flow.human_feedback_history] == [
|
||||
"revise", "revise", "revise", "approved"
|
||||
]
|
||||
|
||||
def test_hitl_self_loop_immediate_approval(self):
|
||||
"""Test that a HITL self-loop flow works when approved on the first try.
|
||||
|
||||
No looping occurs -- the flow should proceed straight through.
|
||||
"""
|
||||
execution_order: list[str] = []
|
||||
|
||||
class ImmediateApprovalFlow(Flow):
|
||||
@start()
|
||||
def generate(self):
|
||||
execution_order.append("generate")
|
||||
return "perfect draft"
|
||||
|
||||
@human_feedback(
|
||||
message="Review:",
|
||||
emit=["revise", "approved"],
|
||||
llm="gpt-4o-mini",
|
||||
)
|
||||
@listen(or_("generate", "revise"))
|
||||
def review(self):
|
||||
execution_order.append("review")
|
||||
return "content"
|
||||
|
||||
@listen("approved")
|
||||
def publish(self):
|
||||
execution_order.append("publish")
|
||||
return "published"
|
||||
|
||||
flow = ImmediateApprovalFlow()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
flow,
|
||||
"_request_human_feedback",
|
||||
return_value="perfect",
|
||||
),
|
||||
patch.object(
|
||||
flow,
|
||||
"_collapse_to_outcome",
|
||||
return_value="approved",
|
||||
),
|
||||
):
|
||||
result = flow.kickoff()
|
||||
|
||||
assert execution_order == ["generate", "review", "publish"]
|
||||
assert result == "published"
|
||||
assert len(flow.human_feedback_history) == 1
|
||||
assert flow.human_feedback_history[0].outcome == "approved"
|
||||
|
||||
def test_router_and_non_router_listeners_for_same_outcome(self):
|
||||
"""Test that both router and non-router listeners fire for the same outcome."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
Reference in New Issue
Block a user