fix: add async HITL support and chained-router tests
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled

Add asynchronous human-in-the-loop handling and related fixes.

- Extend human_input provider with async support: AsyncExecutorContext, handle_feedback_async, async prompt helpers (_prompt_input_async, _async_readline), and async training/regular feedback loops in SyncHumanInputProvider.
- Add async handler methods in CrewAgentExecutor and AgentExecutor (_ahandle_human_feedback, _ainvoke_loop) to integrate async provider flows.
- Change PlusAPI.get_agent to an async httpx call and adapt caller in agent_utils to run it via asyncio.run.
- Simplify listener execution in flow.Flow to correctly pass HumanFeedbackResult to listeners and unify execution path for router outcomes.
- Remove deprecated types/hitl.py definitions.
- Add tests covering chained router feedback, rejected paths, and mixed router/non-router listeners to prevent regressions.
This commit is contained in:
Greyson LaLonde
2026-02-06 16:29:27 -05:00
committed by GitHub
parent 7d498b29be
commit f6fa04528a
9 changed files with 473 additions and 112 deletions

View File

@@ -1009,7 +1009,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
raise
if self.ask_for_human_input:
formatted_answer = self._handle_human_feedback(formatted_answer)
formatted_answer = await self._ahandle_human_feedback(formatted_answer)
self._create_short_term_memory(formatted_answer)
self._create_long_term_memory(formatted_answer)
@@ -1508,6 +1508,20 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
provider = get_provider()
return provider.handle_feedback(formatted_answer, self)
async def _ahandle_human_feedback(
    self, formatted_answer: AgentFinish
) -> AgentFinish:
    """Run the asynchronous human-feedback flow for the given answer.

    Delegates to the currently configured human-input provider, passing
    this executor as the async context.

    Args:
        formatted_answer: Initial agent result.

    Returns:
        Final answer after feedback.
    """
    return await get_provider().handle_feedback_async(formatted_answer, self)
def _is_training_mode(self) -> bool:
"""Check if training mode is active.

View File

@@ -1,6 +1,8 @@
import os
from typing import Any
from urllib.parse import urljoin
import os
import httpx
import requests
from crewai.cli.config import Settings
@@ -33,7 +35,11 @@ class PlusAPI:
if settings.org_uuid:
self.headers["X-Crewai-Organization-Id"] = settings.org_uuid
self.base_url = os.getenv("CREWAI_PLUS_URL") or str(settings.enterprise_base_url) or DEFAULT_CREWAI_ENTERPRISE_URL
self.base_url = (
os.getenv("CREWAI_PLUS_URL")
or str(settings.enterprise_base_url)
or DEFAULT_CREWAI_ENTERPRISE_URL
)
def _make_request(
self, method: str, endpoint: str, **kwargs: Any
@@ -49,8 +55,10 @@ class PlusAPI:
def get_tool(self, handle: str) -> requests.Response:
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
def get_agent(self, handle: str) -> requests.Response:
return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}")
async def get_agent(self, handle: str) -> httpx.Response:
    """Fetch an agent by handle from the Plus API asynchronously.

    Args:
        handle: Repository handle of the agent.

    Returns:
        The raw HTTP response from the agents endpoint.
    """
    endpoint = f"{self.AGENTS_RESOURCE}/{handle}"
    target = urljoin(self.base_url, endpoint)
    async with httpx.AsyncClient() as client:
        response = await client.get(target, headers=self.headers)
    return response
def publish_tool(
self,

View File

@@ -2,7 +2,9 @@
from __future__ import annotations
import asyncio
from contextvars import ContextVar, Token
import sys
from typing import TYPE_CHECKING, Protocol, runtime_checkable
@@ -46,13 +48,21 @@ class ExecutorContext(Protocol):
...
class AsyncExecutorContext(ExecutorContext, Protocol):
    """Extended context for executors that support async invocation.

    Structural protocol: any executor that exposes an async
    ``_ainvoke_loop`` in addition to the synchronous ``ExecutorContext``
    surface satisfies it.
    """

    async def _ainvoke_loop(self) -> AgentFinish:
        """Invoke the agent loop asynchronously and return the result."""
        ...
@runtime_checkable
class HumanInputProvider(Protocol):
"""Protocol for human input handling.
Implementations handle the full feedback flow:
- Sync: prompt user, loop until satisfied
- Async: raise exception for external handling
- Async: use non-blocking I/O and async invoke loop
"""
def setup_messages(self, context: ExecutorContext) -> bool:
@@ -86,7 +96,7 @@ class HumanInputProvider(Protocol):
formatted_answer: AgentFinish,
context: ExecutorContext,
) -> AgentFinish:
"""Handle the full human feedback flow.
"""Handle the full human feedback flow synchronously.
Args:
formatted_answer: The agent's current answer.
@@ -100,6 +110,25 @@ class HumanInputProvider(Protocol):
"""
...
async def handle_feedback_async(
self,
formatted_answer: AgentFinish,
context: AsyncExecutorContext,
) -> AgentFinish:
"""Handle the full human feedback flow asynchronously.
Uses non-blocking I/O for user prompts and async invoke loop
for agent re-execution.
Args:
formatted_answer: The agent's current answer.
context: Async executor context for callbacks.
Returns:
The final answer after feedback processing.
"""
...
@staticmethod
def _get_output_string(answer: AgentFinish) -> str:
"""Extract output string from answer.
@@ -116,7 +145,7 @@ class HumanInputProvider(Protocol):
class SyncHumanInputProvider(HumanInputProvider):
"""Default synchronous human input via terminal."""
"""Default human input provider with sync and async support."""
def setup_messages(self, context: ExecutorContext) -> bool:
"""Use standard message setup.
@@ -157,6 +186,33 @@ class SyncHumanInputProvider(HumanInputProvider):
return self._handle_regular_feedback(formatted_answer, feedback, context)
async def handle_feedback_async(
    self,
    formatted_answer: AgentFinish,
    context: AsyncExecutorContext,
) -> AgentFinish:
    """Run the non-blocking human-feedback flow.

    Prompts the user via async I/O, then dispatches to the training or
    regular feedback loop depending on the executor's mode.

    Args:
        formatted_answer: The agent's current answer.
        context: Async executor context for callbacks.

    Returns:
        The final answer after feedback processing.
    """
    first_feedback = await self._prompt_input_async(context.crew)
    if context._is_training_mode():
        handler = self._handle_training_feedback_async
    else:
        handler = self._handle_regular_feedback_async
    return await handler(formatted_answer, first_feedback, context)
# ── Sync helpers ──────────────────────────────────────────────────
@staticmethod
def _handle_training_feedback(
initial_answer: AgentFinish,
@@ -209,6 +265,62 @@ class SyncHumanInputProvider(HumanInputProvider):
return answer
# ── Async helpers ─────────────────────────────────────────────────
@staticmethod
async def _handle_training_feedback_async(
    initial_answer: AgentFinish,
    feedback: str,
    context: AsyncExecutorContext,
) -> AgentFinish:
    """Run one async training-feedback iteration.

    Records the initial answer with its feedback, re-invokes the agent
    once, records the improved answer, and disables further prompts.

    Args:
        initial_answer: The agent's initial answer.
        feedback: Human feedback string.
        context: Async executor context for callbacks.

    Returns:
        Improved answer after processing feedback.
    """
    # Persist the original answer alongside the human feedback.
    context._handle_crew_training_output(initial_answer, feedback)
    context.messages.append(context._format_feedback_message(feedback))
    # Single re-run of the agent loop with the feedback in context.
    refined = await context._ainvoke_loop()
    context._handle_crew_training_output(refined)
    # Training mode only ever performs one iteration.
    context.ask_for_human_input = False
    return refined
async def _handle_regular_feedback_async(
    self,
    current_answer: AgentFinish,
    initial_feedback: str,
    context: AsyncExecutorContext,
) -> AgentFinish:
    """Iterate the async feedback loop until the user is satisfied.

    Empty (whitespace-only) feedback ends the loop; any other feedback is
    appended to the conversation and the agent loop is re-run.

    Args:
        current_answer: The agent's current answer.
        initial_feedback: Initial human feedback string.
        context: Async executor context for callbacks.

    Returns:
        Final answer after all feedback iterations.
    """
    answer, feedback = current_answer, initial_feedback
    while context.ask_for_human_input:
        if not feedback.strip():
            # Blank feedback means the user accepts the current answer.
            context.ask_for_human_input = False
            continue
        context.messages.append(context._format_feedback_message(feedback))
        answer = await context._ainvoke_loop()
        feedback = await self._prompt_input_async(context.crew)
    return answer
# ── I/O ───────────────────────────────────────────────────────────
@staticmethod
def _prompt_input(crew: Crew | None) -> str:
"""Show rich panel and prompt for input.
@@ -262,6 +374,79 @@ class SyncHumanInputProvider(HumanInputProvider):
finally:
formatter.resume_live_updates()
@staticmethod
async def _prompt_input_async(crew: Crew | None) -> str:
    """Display the feedback panel and read user input without blocking.

    Args:
        crew: The crew instance for context.

    Returns:
        User input string from terminal.
    """
    from rich.panel import Panel
    from rich.text import Text

    from crewai.events.event_listener import event_listener

    formatter = event_listener.formatter
    formatter.pause_live_updates()
    try:
        training = bool(crew and getattr(crew, "_train", False))
        if training:
            body = (
                "TRAINING MODE: Provide feedback to improve the agent's performance.\n\n"
                "This will be used to train better versions of the agent.\n"
                "Please provide detailed feedback about the result quality and reasoning process."
            )
            heading = "🎓 Training Feedback Required"
        else:
            body = (
                "Provide feedback on the Final Result above.\n\n"
                "• If you are happy with the result, simply hit Enter without typing anything.\n"
                "• Otherwise, provide specific improvement requests.\n"
                "• You can provide multiple rounds of feedback until satisfied."
            )
            heading = "💬 Human Feedback Required"
        styled = Text()
        styled.append(body, style="yellow")
        formatter.console.print(
            Panel(styled, title=heading, border_style="yellow", padding=(1, 2))
        )
        user_reply = await _async_readline()
        if user_reply.strip():
            formatter.console.print("\n[cyan]Processing your feedback...[/cyan]")
        return user_reply
    finally:
        formatter.resume_live_updates()
async def _async_readline() -> str:
    """Read one line from stdin via the event loop's native I/O.

    Falls back to ``asyncio.to_thread(input)`` on platforms where piping
    stdin into the event loop is unsupported.

    Returns:
        The line read from stdin, with trailing newline stripped.
    """
    event_loop = asyncio.get_running_loop()
    try:
        stream = asyncio.StreamReader()
        await event_loop.connect_read_pipe(
            lambda: asyncio.StreamReaderProtocol(stream), sys.stdin
        )
        line = await stream.readline()
    except (OSError, NotImplementedError, ValueError):
        # stdin cannot be wired into the event loop here; block in a thread.
        return await asyncio.to_thread(input)
    return line.decode().rstrip("\n")
_provider: ContextVar[HumanInputProvider | None] = ContextVar(
"human_input_provider",

View File

@@ -258,6 +258,22 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
raise RuntimeError("Agent loop did not produce a final answer")
return answer
async def _ainvoke_loop(self) -> AgentFinish:
    """Reset flow state and run the agent loop asynchronously.

    Required by the AsyncExecutorContext protocol.

    Returns:
        The final answer produced by the loop.

    Raises:
        RuntimeError: If the loop finishes without producing an AgentFinish.
    """
    # Reset per-run state so the flow starts from a clean slate.
    self._state.iterations = 0
    self._state.is_finished = False
    self._state.current_answer = None
    await self.akickoff()
    final = self._state.current_answer
    if isinstance(final, AgentFinish):
        return final
    raise RuntimeError("Agent loop did not produce a final answer")
def _format_feedback_message(self, feedback: str) -> LLMMessage:
"""Format feedback as a message for the LLM.
@@ -1173,7 +1189,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
)
if self.state.ask_for_human_input:
formatted_answer = self._handle_human_feedback(formatted_answer)
formatted_answer = await self._ahandle_human_feedback(formatted_answer)
self._create_short_term_memory(formatted_answer)
self._create_long_term_memory(formatted_answer)
@@ -1390,6 +1406,20 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
provider = get_provider()
return provider.handle_feedback(formatted_answer, self)
async def _ahandle_human_feedback(
    self, formatted_answer: AgentFinish
) -> AgentFinish:
    """Refine the answer through the async human-feedback flow.

    Args:
        formatted_answer: Initial agent result.

    Returns:
        Final answer after feedback.
    """
    return await get_provider().handle_feedback_async(formatted_answer, self)
def _is_training_mode(self) -> bool:
"""Check if training mode is active.

View File

@@ -1934,40 +1934,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
await self._execute_listeners(start_method_name, result, finished_event_id)
# Then execute listeners for the router result (e.g., "approved")
router_result_trigger = FlowMethodName(str(result))
listeners_for_result = self._find_triggered_methods(
router_result_trigger, router_only=False
listener_result = (
self.last_human_feedback
if self.last_human_feedback is not None
else result
)
await self._execute_listeners(
router_result_trigger, listener_result, finished_event_id
)
if listeners_for_result:
# Pass the HumanFeedbackResult if available
listener_result = (
self.last_human_feedback
if self.last_human_feedback is not None
else result
)
racing_group = self._get_racing_group_for_listeners(
listeners_for_result
)
if racing_group:
racing_members, _ = racing_group
other_listeners = [
name
for name in listeners_for_result
if name not in racing_members
]
await self._execute_racing_listeners(
racing_members,
other_listeners,
listener_result,
finished_event_id,
)
else:
tasks = [
self._execute_single_listener(
listener_name, listener_result, finished_event_id
)
for listener_name in listeners_for_result
]
await asyncio.gather(*tasks)
else:
await self._execute_listeners(start_method_name, result, finished_event_id)

View File

@@ -1,37 +0,0 @@
"""Human-in-the-loop (HITL) type definitions.
This module provides type definitions for human-in-the-loop interactions
in crew executions.
"""
from typing import TypedDict
class HITLResumeInfo(TypedDict, total=False):
"""HITL resume information passed from flow to crew.
Attributes:
task_id: Unique identifier for the task.
crew_execution_id: Unique identifier for the crew execution.
task_key: Key identifying the specific task.
task_output: Output from the task before human intervention.
human_feedback: Feedback provided by the human.
previous_messages: History of messages in the conversation.
"""
task_id: str
crew_execution_id: str
task_key: str
task_output: str
human_feedback: str
previous_messages: list[dict[str, str]]
class CrewInputsWithHITL(TypedDict, total=False):
"""Crew inputs that may contain HITL resume information.
Attributes:
_hitl_resume: Optional HITL resume information for continuing execution.
"""
_hitl_resume: HITLResumeInfo

View File

@@ -832,7 +832,7 @@ def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
client = PlusAPI(api_key=get_auth_token())
_print_current_organization()
response = client.get_agent(from_repository)
response = asyncio.run(client.get_agent(from_repository))
if response.status_code == 404:
raise AgentRepositoryError(
f"Agent {from_repository} does not exist, make sure the name is correct or the agent is available on your organization."

View File

@@ -1,6 +1,8 @@
import os
import unittest
from unittest.mock import ANY, MagicMock, patch
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pytest
from crewai.cli.plus_api import PlusAPI
@@ -68,37 +70,6 @@ class TestPlusAPI(unittest.TestCase):
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_get_agent(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.get_agent("test_agent_handle")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/agents/test_agent_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.Settings")
@patch("requests.Session.request")
def test_get_agent_with_org_uuid(self, mock_make_request, mock_settings_class):
mock_settings = MagicMock()
mock_settings.org_uuid = self.org_uuid
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
mock_settings_class.return_value = mock_settings
# re-initialize Client
self.api = PlusAPI(self.api_key)
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.get_agent("test_agent_handle")
self.assert_request_with_org_id(
mock_make_request, "GET", "/crewai_plus/api/v1/agents/test_agent_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_get_tool(self, mock_make_request):
mock_response = MagicMock()
@@ -338,3 +309,49 @@ class TestPlusAPI(unittest.TestCase):
custom_api.base_url,
"https://custom-url-from-env.com",
)
@pytest.mark.asyncio
@patch("httpx.AsyncClient")
async def test_get_agent(mock_async_client_class):
    """get_agent issues an async GET to the agents endpoint with API headers."""
    api = PlusAPI("test_api_key")
    mock_response = MagicMock()
    mock_client_instance = AsyncMock()
    mock_client_instance.get.return_value = mock_response
    # get_agent uses the client as an async context manager; stub __aenter__.
    mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance
    response = await api.get_agent("test_agent_handle")
    mock_client_instance.get.assert_called_once_with(
        f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
        headers=api.headers,
    )
    assert response == mock_response
@pytest.mark.asyncio
@patch("httpx.AsyncClient")
@patch("crewai.cli.plus_api.Settings")
async def test_get_agent_with_org_uuid(mock_settings_class, mock_async_client_class):
    """get_agent forwards the organization header when settings define an org."""
    org_uuid = "test-org-uuid"
    mock_settings = MagicMock()
    mock_settings.org_uuid = org_uuid
    mock_settings.enterprise_base_url = os.getenv("CREWAI_PLUS_URL")
    mock_settings_class.return_value = mock_settings
    # PlusAPI reads Settings at construction time, so patch before creating.
    api = PlusAPI("test_api_key")
    mock_response = MagicMock()
    mock_client_instance = AsyncMock()
    mock_client_instance.get.return_value = mock_response
    mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance
    response = await api.get_agent("test_agent_handle")
    mock_client_instance.get.assert_called_once_with(
        f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
        headers=api.headers,
    )
    assert "X-Crewai-Organization-Id" in api.headers
    assert api.headers["X-Crewai-Organization-Id"] == org_uuid
    assert response == mock_response

View File

@@ -157,6 +157,176 @@ class TestMultiStepFlows:
assert execution_order == ["generate", "review", "finalize"]
def test_chained_router_feedback_steps(self):
    """Test that a router outcome can trigger another router method.

    Regression test: @listen("outcome") combined with @human_feedback(emit=...)
    creates a method that is both a listener and a router. The flow must find
    and execute it when the upstream router emits the matching outcome.
    """
    execution_order: list[str] = []

    class ChainedRouterFlow(Flow):
        @start()
        @human_feedback(
            message="First review:",
            emit=["approved", "rejected"],
            llm="gpt-4o-mini",
        )
        def draft(self):
            execution_order.append("draft")
            return "draft content"

        # Listener AND router: triggered by "approved", emits its own outcomes.
        @listen("approved")
        @human_feedback(
            message="Final review:",
            emit=["publish", "revise"],
            llm="gpt-4o-mini",
        )
        def final_review(self, prev: HumanFeedbackResult):
            execution_order.append("final_review")
            return "final content"

        @listen("rejected")
        def on_rejected(self, prev: HumanFeedbackResult):
            execution_order.append("on_rejected")
            return "rejected"

        @listen("publish")
        def on_publish(self, prev: HumanFeedbackResult):
            execution_order.append("on_publish")
            return "published"

        @listen("revise")
        def on_revise(self, prev: HumanFeedbackResult):
            execution_order.append("on_revise")
            return "revised"

    flow = ChainedRouterFlow()
    # Stub the interactive prompt and the LLM outcome-collapse so the test
    # drives the "approved" -> "publish" path deterministically.
    with (
        patch.object(
            flow,
            "_request_human_feedback",
            side_effect=["looks good", "ship it"],
        ),
        patch.object(
            flow,
            "_collapse_to_outcome",
            side_effect=["approved", "publish"],
        ),
    ):
        result = flow.kickoff()

    assert execution_order == ["draft", "final_review", "on_publish"]
    assert result == "published"
    assert len(flow.human_feedback_history) == 2
    assert flow.human_feedback_history[0].outcome == "approved"
    assert flow.human_feedback_history[1].outcome == "publish"
def test_chained_router_rejected_path(self):
    """Test that a start-router outcome routes to a non-router listener."""
    execution_order: list[str] = []

    class ChainedRouterFlow(Flow):
        @start()
        @human_feedback(
            message="Review:",
            emit=["approved", "rejected"],
            llm="gpt-4o-mini",
        )
        def draft(self):
            execution_order.append("draft")
            return "draft"

        # Router listener on the path NOT taken; must stay unexecuted.
        @listen("approved")
        @human_feedback(
            message="Final:",
            emit=["publish", "revise"],
            llm="gpt-4o-mini",
        )
        def final_review(self, prev: HumanFeedbackResult):
            execution_order.append("final_review")
            return "final"

        @listen("rejected")
        def on_rejected(self, prev: HumanFeedbackResult):
            execution_order.append("on_rejected")
            return "rejected"

    flow = ChainedRouterFlow()
    # Force the "rejected" outcome for the single feedback round.
    with (
        patch.object(
            flow, "_request_human_feedback", return_value="bad"
        ),
        patch.object(
            flow, "_collapse_to_outcome", return_value="rejected"
        ),
    ):
        result = flow.kickoff()

    assert execution_order == ["draft", "on_rejected"]
    assert result == "rejected"
    assert len(flow.human_feedback_history) == 1
    assert flow.human_feedback_history[0].outcome == "rejected"
def test_router_and_non_router_listeners_for_same_outcome(self):
    """Test that both router and non-router listeners fire for the same outcome."""
    execution_order: list[str] = []

    class MixedListenerFlow(Flow):
        @start()
        @human_feedback(
            message="Review:",
            emit=["approved", "rejected"],
            llm="gpt-4o-mini",
        )
        def draft(self):
            execution_order.append("draft")
            return "draft"

        # Two listeners share the "approved" trigger: one is itself a
        # router (emits further outcomes), the other is plain.
        @listen("approved")
        @human_feedback(
            message="Final:",
            emit=["publish", "revise"],
            llm="gpt-4o-mini",
        )
        def router_listener(self, prev: HumanFeedbackResult):
            execution_order.append("router_listener")
            return "final"

        @listen("approved")
        def plain_listener(self, prev: HumanFeedbackResult):
            execution_order.append("plain_listener")
            return "logged"

        @listen("publish")
        def on_publish(self, prev: HumanFeedbackResult):
            execution_order.append("on_publish")
            return "published"

    flow = MixedListenerFlow()
    with (
        patch.object(
            flow,
            "_request_human_feedback",
            side_effect=["approve it", "publish it"],
        ),
        patch.object(
            flow,
            "_collapse_to_outcome",
            side_effect=["approved", "publish"],
        ),
    ):
        flow.kickoff()

    # Membership (not order) assertions: listener scheduling for a shared
    # trigger is concurrent, so relative order is not guaranteed.
    assert "draft" in execution_order
    assert "router_listener" in execution_order
    assert "plain_listener" in execution_order
    assert "on_publish" in execution_order
class TestStateManagement:
"""Tests for state management with human feedback."""