mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-21 02:12:36 +00:00
fix: add async HITL support and chained-router tests
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Build uv cache / build-cache (3.10) (push) Has been cancelled
Build uv cache / build-cache (3.11) (push) Has been cancelled
Build uv cache / build-cache (3.12) (push) Has been cancelled
Build uv cache / build-cache (3.13) (push) Has been cancelled
Add asynchronous human-in-the-loop (HITL) handling and related fixes.
- Extend the human_input provider with async support: AsyncExecutorContext, handle_feedback_async, async prompt helpers (_prompt_input_async, _async_readline), and async training/regular feedback loops in SyncHumanInputProvider.
- Add async handler methods to CrewAgentExecutor and AgentExecutor (_ahandle_human_feedback, _ainvoke_loop) to integrate the async provider flows.
- Change PlusAPI.get_agent to an async httpx call and adapt the caller in agent_utils to run it via asyncio.run.
- Simplify listener execution in flow.Flow so that HumanFeedbackResult is passed to listeners correctly and router outcomes share a single execution path.
- Remove the deprecated types/hitl.py definitions.
- Add tests covering chained router feedback, rejected paths, and mixed router/non-router listeners to prevent regressions.
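For illustration, the following is a minimal, self-contained sketch of the async hand-off this commit introduces: an executor's _ahandle_human_feedback awaits the provider's handle_feedback_async, which prompts without blocking the event loop and may re-run the agent loop. ToyProvider and ToyExecutor are hypothetical stand-ins, not crewAI classes; only the method names mirror this diff.

# Hypothetical sketch of the async HITL hand-off (not repository code).
import asyncio


class ToyProvider:
    async def handle_feedback_async(self, formatted_answer, context):
        # Prompt without blocking the event loop, then optionally re-run the loop.
        feedback = await asyncio.to_thread(input, "Feedback (Enter to accept): ")
        if feedback.strip() == "":
            return formatted_answer
        context.messages.append({"role": "user", "content": feedback})
        return await context._ainvoke_loop()


class ToyExecutor:
    def __init__(self):
        self.messages = []

    async def _ainvoke_loop(self):
        # Stand-in for re-running the agent after feedback.
        return "revised answer"

    async def _ahandle_human_feedback(self, formatted_answer):
        provider = ToyProvider()
        return await provider.handle_feedback_async(formatted_answer, self)


if __name__ == "__main__":
    print(asyncio.run(ToyExecutor()._ahandle_human_feedback("draft answer")))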
@@ -1009,7 +1009,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
            raise

        if self.ask_for_human_input:
            formatted_answer = self._handle_human_feedback(formatted_answer)
            formatted_answer = await self._ahandle_human_feedback(formatted_answer)

        self._create_short_term_memory(formatted_answer)
        self._create_long_term_memory(formatted_answer)
@@ -1508,6 +1508,20 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
        provider = get_provider()
        return provider.handle_feedback(formatted_answer, self)

    async def _ahandle_human_feedback(
        self, formatted_answer: AgentFinish
    ) -> AgentFinish:
        """Process human feedback asynchronously via the configured provider.

        Args:
            formatted_answer: Initial agent result.

        Returns:
            Final answer after feedback.
        """
        provider = get_provider()
        return await provider.handle_feedback_async(formatted_answer, self)

    def _is_training_mode(self) -> bool:
        """Check if training mode is active.
@@ -1,6 +1,8 @@
import os
from typing import Any
from urllib.parse import urljoin
import os

import httpx
import requests

from crewai.cli.config import Settings
@@ -33,7 +35,11 @@ class PlusAPI:
        if settings.org_uuid:
            self.headers["X-Crewai-Organization-Id"] = settings.org_uuid

        self.base_url = os.getenv("CREWAI_PLUS_URL") or str(settings.enterprise_base_url) or DEFAULT_CREWAI_ENTERPRISE_URL
        self.base_url = (
            os.getenv("CREWAI_PLUS_URL")
            or str(settings.enterprise_base_url)
            or DEFAULT_CREWAI_ENTERPRISE_URL
        )

    def _make_request(
        self, method: str, endpoint: str, **kwargs: Any
@@ -49,8 +55,10 @@ class PlusAPI:
    def get_tool(self, handle: str) -> requests.Response:
        return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")

    def get_agent(self, handle: str) -> requests.Response:
        return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}")
    async def get_agent(self, handle: str) -> httpx.Response:
        url = urljoin(self.base_url, f"{self.AGENTS_RESOURCE}/{handle}")
        async with httpx.AsyncClient() as client:
            return await client.get(url, headers=self.headers)

    def publish_tool(
        self,
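Because get_agent is now a coroutine, synchronous call sites have to drive the event loop themselves. A minimal sketch of that pattern, assuming no loop is already running; the API key and agent handle below are placeholders, and it mirrors the asyncio.run adaptation made in agent_utils further down in this commit.

import asyncio

from crewai.cli.plus_api import PlusAPI

# Hypothetical synchronous caller driving the now-async get_agent.
client = PlusAPI(api_key="<your-api-key>")
response = asyncio.run(client.get_agent("my-agent-handle"))
if response.status_code == 200:
    print(response.json())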
@@ -2,7 +2,9 @@
from __future__ import annotations

import asyncio
from contextvars import ContextVar, Token
import sys
from typing import TYPE_CHECKING, Protocol, runtime_checkable
@@ -46,13 +48,21 @@ class ExecutorContext(Protocol):
        ...


class AsyncExecutorContext(ExecutorContext, Protocol):
    """Extended context for executors that support async invocation."""

    async def _ainvoke_loop(self) -> AgentFinish:
        """Invoke the agent loop asynchronously and return the result."""
        ...


@runtime_checkable
class HumanInputProvider(Protocol):
    """Protocol for human input handling.

    Implementations handle the full feedback flow:
    - Sync: prompt user, loop until satisfied
    - Async: raise exception for external handling
    - Async: use non-blocking I/O and async invoke loop
    """

    def setup_messages(self, context: ExecutorContext) -> bool:
@@ -86,7 +96,7 @@ class HumanInputProvider(Protocol):
        formatted_answer: AgentFinish,
        context: ExecutorContext,
    ) -> AgentFinish:
        """Handle the full human feedback flow.
        """Handle the full human feedback flow synchronously.

        Args:
            formatted_answer: The agent's current answer.
@@ -100,6 +110,25 @@ class HumanInputProvider(Protocol):
        """
        ...

    async def handle_feedback_async(
        self,
        formatted_answer: AgentFinish,
        context: AsyncExecutorContext,
    ) -> AgentFinish:
        """Handle the full human feedback flow asynchronously.

        Uses non-blocking I/O for user prompts and async invoke loop
        for agent re-execution.

        Args:
            formatted_answer: The agent's current answer.
            context: Async executor context for callbacks.

        Returns:
            The final answer after feedback processing.
        """
        ...

    @staticmethod
    def _get_output_string(answer: AgentFinish) -> str:
        """Extract output string from answer.
@@ -116,7 +145,7 @@ class HumanInputProvider(Protocol):
class SyncHumanInputProvider(HumanInputProvider):
    """Default synchronous human input via terminal."""
    """Default human input provider with sync and async support."""

    def setup_messages(self, context: ExecutorContext) -> bool:
        """Use standard message setup.
@@ -157,6 +186,33 @@ class SyncHumanInputProvider(HumanInputProvider):
        return self._handle_regular_feedback(formatted_answer, feedback, context)

    async def handle_feedback_async(
        self,
        formatted_answer: AgentFinish,
        context: AsyncExecutorContext,
    ) -> AgentFinish:
        """Handle feedback asynchronously without blocking the event loop.

        Args:
            formatted_answer: The agent's current answer.
            context: Async executor context for callbacks.

        Returns:
            The final answer after feedback processing.
        """
        feedback = await self._prompt_input_async(context.crew)

        if context._is_training_mode():
            return await self._handle_training_feedback_async(
                formatted_answer, feedback, context
            )

        return await self._handle_regular_feedback_async(
            formatted_answer, feedback, context
        )

    # ── Sync helpers ──────────────────────────────────────────────────

    @staticmethod
    def _handle_training_feedback(
        initial_answer: AgentFinish,
@@ -209,6 +265,62 @@ class SyncHumanInputProvider(HumanInputProvider):
        return answer

    # ── Async helpers ─────────────────────────────────────────────────

    @staticmethod
    async def _handle_training_feedback_async(
        initial_answer: AgentFinish,
        feedback: str,
        context: AsyncExecutorContext,
    ) -> AgentFinish:
        """Process training feedback asynchronously (single iteration).

        Args:
            initial_answer: The agent's initial answer.
            feedback: Human feedback string.
            context: Async executor context for callbacks.

        Returns:
            Improved answer after processing feedback.
        """
        context._handle_crew_training_output(initial_answer, feedback)
        context.messages.append(context._format_feedback_message(feedback))
        improved_answer = await context._ainvoke_loop()
        context._handle_crew_training_output(improved_answer)
        context.ask_for_human_input = False
        return improved_answer

    async def _handle_regular_feedback_async(
        self,
        current_answer: AgentFinish,
        initial_feedback: str,
        context: AsyncExecutorContext,
    ) -> AgentFinish:
        """Process regular feedback with async iteration loop.

        Args:
            current_answer: The agent's current answer.
            initial_feedback: Initial human feedback string.
            context: Async executor context for callbacks.

        Returns:
            Final answer after all feedback iterations.
        """
        feedback = initial_feedback
        answer = current_answer

        while context.ask_for_human_input:
            if feedback.strip() == "":
                context.ask_for_human_input = False
            else:
                context.messages.append(context._format_feedback_message(feedback))
                answer = await context._ainvoke_loop()
                feedback = await self._prompt_input_async(context.crew)

        return answer

    # ── I/O ───────────────────────────────────────────────────────────

    @staticmethod
    def _prompt_input(crew: Crew | None) -> str:
        """Show rich panel and prompt for input.
@@ -262,6 +374,79 @@ class SyncHumanInputProvider(HumanInputProvider):
        finally:
            formatter.resume_live_updates()

    @staticmethod
    async def _prompt_input_async(crew: Crew | None) -> str:
        """Show rich panel and prompt for input without blocking the event loop.

        Args:
            crew: The crew instance for context.

        Returns:
            User input string from terminal.
        """
        from rich.panel import Panel
        from rich.text import Text

        from crewai.events.event_listener import event_listener

        formatter = event_listener.formatter
        formatter.pause_live_updates()

        try:
            if crew and getattr(crew, "_train", False):
                prompt_text = (
                    "TRAINING MODE: Provide feedback to improve the agent's performance.\n\n"
                    "This will be used to train better versions of the agent.\n"
                    "Please provide detailed feedback about the result quality and reasoning process."
                )
                title = "🎓 Training Feedback Required"
            else:
                prompt_text = (
                    "Provide feedback on the Final Result above.\n\n"
                    "• If you are happy with the result, simply hit Enter without typing anything.\n"
                    "• Otherwise, provide specific improvement requests.\n"
                    "• You can provide multiple rounds of feedback until satisfied."
                )
                title = "💬 Human Feedback Required"

            content = Text()
            content.append(prompt_text, style="yellow")

            prompt_panel = Panel(
                content,
                title=title,
                border_style="yellow",
                padding=(1, 2),
            )
            formatter.console.print(prompt_panel)

            response = await _async_readline()
            if response.strip() != "":
                formatter.console.print("\n[cyan]Processing your feedback...[/cyan]")
            return response
        finally:
            formatter.resume_live_updates()


async def _async_readline() -> str:
    """Read a line from stdin using the event loop's native I/O.

    Falls back to asyncio.to_thread on platforms where piping stdin
    is unsupported.

    Returns:
        The line read from stdin, with trailing newline stripped.
    """
    loop = asyncio.get_running_loop()
    try:
        reader = asyncio.StreamReader()
        protocol = asyncio.StreamReaderProtocol(reader)
        await loop.connect_read_pipe(lambda: protocol, sys.stdin)
        raw = await reader.readline()
        return raw.decode().rstrip("\n")
    except (OSError, NotImplementedError, ValueError):
        return await asyncio.to_thread(input)


_provider: ContextVar[HumanInputProvider | None] = ContextVar(
    "human_input_provider",
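As a rough sketch of what the async side of this protocol enables, a hypothetical custom provider could take feedback from a queue instead of the terminal. QueueInputProvider below is invented for the example and is not part of crewAI; only the handle_feedback_async contract and the context calls (_format_feedback_message, _ainvoke_loop) mirror this diff, and installing it would go through whatever mechanism backs get_provider (a ContextVar per the code above).

import asyncio


class QueueInputProvider:
    """Hypothetical provider that reads feedback from an asyncio.Queue
    instead of stdin, following the handle_feedback_async contract."""

    def __init__(self, queue: "asyncio.Queue[str]"):
        self._queue = queue

    async def handle_feedback_async(self, formatted_answer, context):
        # Empty feedback means the answer is accepted as-is.
        feedback = await self._queue.get()
        if feedback.strip() == "":
            return formatted_answer
        context.messages.append(context._format_feedback_message(feedback))
        return await context._ainvoke_loop()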
@@ -258,6 +258,22 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
            raise RuntimeError("Agent loop did not produce a final answer")
        return answer

    async def _ainvoke_loop(self) -> AgentFinish:
        """Invoke the agent loop asynchronously and return the result.

        Required by AsyncExecutorContext protocol.
        """
        self._state.iterations = 0
        self._state.is_finished = False
        self._state.current_answer = None

        await self.akickoff()

        answer = self._state.current_answer
        if not isinstance(answer, AgentFinish):
            raise RuntimeError("Agent loop did not produce a final answer")
        return answer

    def _format_feedback_message(self, feedback: str) -> LLMMessage:
        """Format feedback as a message for the LLM.
@@ -1173,7 +1189,7 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
        )

        if self.state.ask_for_human_input:
            formatted_answer = self._handle_human_feedback(formatted_answer)
            formatted_answer = await self._ahandle_human_feedback(formatted_answer)

        self._create_short_term_memory(formatted_answer)
        self._create_long_term_memory(formatted_answer)
@@ -1390,6 +1406,20 @@ class AgentExecutor(Flow[AgentReActState], CrewAgentExecutorMixin):
        provider = get_provider()
        return provider.handle_feedback(formatted_answer, self)

    async def _ahandle_human_feedback(
        self, formatted_answer: AgentFinish
    ) -> AgentFinish:
        """Process human feedback asynchronously and refine answer.

        Args:
            formatted_answer: Initial agent result.

        Returns:
            Final answer after feedback.
        """
        provider = get_provider()
        return await provider.handle_feedback_async(formatted_answer, self)

    def _is_training_mode(self) -> bool:
        """Check if training mode is active.
@@ -1934,40 +1934,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
            await self._execute_listeners(start_method_name, result, finished_event_id)
            # Then execute listeners for the router result (e.g., "approved")
            router_result_trigger = FlowMethodName(str(result))
            listeners_for_result = self._find_triggered_methods(
                router_result_trigger, router_only=False
            listener_result = (
                self.last_human_feedback
                if self.last_human_feedback is not None
                else result
            )
            await self._execute_listeners(
                router_result_trigger, listener_result, finished_event_id
            )
            if listeners_for_result:
                # Pass the HumanFeedbackResult if available
                listener_result = (
                    self.last_human_feedback
                    if self.last_human_feedback is not None
                    else result
                )
                racing_group = self._get_racing_group_for_listeners(
                    listeners_for_result
                )
                if racing_group:
                    racing_members, _ = racing_group
                    other_listeners = [
                        name
                        for name in listeners_for_result
                        if name not in racing_members
                    ]
                    await self._execute_racing_listeners(
                        racing_members,
                        other_listeners,
                        listener_result,
                        finished_event_id,
                    )
                else:
                    tasks = [
                        self._execute_single_listener(
                            listener_name, listener_result, finished_event_id
                        )
                        for listener_name in listeners_for_result
                    ]
                    await asyncio.gather(*tasks)
        else:
            await self._execute_listeners(start_method_name, result, finished_event_id)
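A rough, self-contained illustration of the dispatch rule this Flow change targets (the names below are invented for the example and are not crewAI code): a listener keyed on a router outcome should receive the captured feedback result object when one exists, rather than the bare outcome string.

from dataclasses import dataclass


@dataclass
class FeedbackResult:
    outcome: str
    feedback: str


def dispatch(outcome, last_feedback, listeners):
    # Prefer the rich feedback object when one was captured for this step.
    payload = last_feedback if last_feedback is not None else outcome
    for trigger, fn in listeners.items():
        if trigger == outcome:
            fn(payload)


dispatch(
    "approved",
    FeedbackResult(outcome="approved", feedback="looks good"),
    {"approved": lambda prev: print(type(prev).__name__, prev.feedback)},
)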
@@ -1,37 +0,0 @@
"""Human-in-the-loop (HITL) type definitions.

This module provides type definitions for human-in-the-loop interactions
in crew executions.
"""

from typing import TypedDict


class HITLResumeInfo(TypedDict, total=False):
    """HITL resume information passed from flow to crew.

    Attributes:
        task_id: Unique identifier for the task.
        crew_execution_id: Unique identifier for the crew execution.
        task_key: Key identifying the specific task.
        task_output: Output from the task before human intervention.
        human_feedback: Feedback provided by the human.
        previous_messages: History of messages in the conversation.
    """

    task_id: str
    crew_execution_id: str
    task_key: str
    task_output: str
    human_feedback: str
    previous_messages: list[dict[str, str]]


class CrewInputsWithHITL(TypedDict, total=False):
    """Crew inputs that may contain HITL resume information.

    Attributes:
        _hitl_resume: Optional HITL resume information for continuing execution.
    """

    _hitl_resume: HITLResumeInfo
@@ -832,7 +832,7 @@ def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
    client = PlusAPI(api_key=get_auth_token())
    _print_current_organization()
    response = client.get_agent(from_repository)
    response = asyncio.run(client.get_agent(from_repository))
    if response.status_code == 404:
        raise AgentRepositoryError(
            f"Agent {from_repository} does not exist, make sure the name is correct or the agent is available on your organization."
@@ -1,6 +1,8 @@
import os
import unittest
from unittest.mock import ANY, MagicMock, patch
from unittest.mock import ANY, AsyncMock, MagicMock, patch

import pytest

from crewai.cli.plus_api import PlusAPI
@@ -68,37 +70,6 @@ class TestPlusAPI(unittest.TestCase):
        )
        self.assertEqual(response, mock_response)

    @patch("crewai.cli.plus_api.PlusAPI._make_request")
    def test_get_agent(self, mock_make_request):
        mock_response = MagicMock()
        mock_make_request.return_value = mock_response

        response = self.api.get_agent("test_agent_handle")
        mock_make_request.assert_called_once_with(
            "GET", "/crewai_plus/api/v1/agents/test_agent_handle"
        )
        self.assertEqual(response, mock_response)

    @patch("crewai.cli.plus_api.Settings")
    @patch("requests.Session.request")
    def test_get_agent_with_org_uuid(self, mock_make_request, mock_settings_class):
        mock_settings = MagicMock()
        mock_settings.org_uuid = self.org_uuid
        mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
        mock_settings_class.return_value = mock_settings
        # re-initialize Client
        self.api = PlusAPI(self.api_key)

        mock_response = MagicMock()
        mock_make_request.return_value = mock_response

        response = self.api.get_agent("test_agent_handle")

        self.assert_request_with_org_id(
            mock_make_request, "GET", "/crewai_plus/api/v1/agents/test_agent_handle"
        )
        self.assertEqual(response, mock_response)

    @patch("crewai.cli.plus_api.PlusAPI._make_request")
    def test_get_tool(self, mock_make_request):
        mock_response = MagicMock()
@@ -338,3 +309,49 @@ class TestPlusAPI(unittest.TestCase):
            custom_api.base_url,
            "https://custom-url-from-env.com",
        )


@pytest.mark.asyncio
@patch("httpx.AsyncClient")
async def test_get_agent(mock_async_client_class):
    api = PlusAPI("test_api_key")
    mock_response = MagicMock()
    mock_client_instance = AsyncMock()
    mock_client_instance.get.return_value = mock_response
    mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance

    response = await api.get_agent("test_agent_handle")

    mock_client_instance.get.assert_called_once_with(
        f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
        headers=api.headers,
    )
    assert response == mock_response


@pytest.mark.asyncio
@patch("httpx.AsyncClient")
@patch("crewai.cli.plus_api.Settings")
async def test_get_agent_with_org_uuid(mock_settings_class, mock_async_client_class):
    org_uuid = "test-org-uuid"
    mock_settings = MagicMock()
    mock_settings.org_uuid = org_uuid
    mock_settings.enterprise_base_url = os.getenv("CREWAI_PLUS_URL")
    mock_settings_class.return_value = mock_settings

    api = PlusAPI("test_api_key")

    mock_response = MagicMock()
    mock_client_instance = AsyncMock()
    mock_client_instance.get.return_value = mock_response
    mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance

    response = await api.get_agent("test_agent_handle")

    mock_client_instance.get.assert_called_once_with(
        f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
        headers=api.headers,
    )
    assert "X-Crewai-Organization-Id" in api.headers
    assert api.headers["X-Crewai-Organization-Id"] == org_uuid
    assert response == mock_response
@@ -157,6 +157,176 @@ class TestMultiStepFlows:
        assert execution_order == ["generate", "review", "finalize"]

    def test_chained_router_feedback_steps(self):
        """Test that a router outcome can trigger another router method.

        Regression test: @listen("outcome") combined with @human_feedback(emit=...)
        creates a method that is both a listener and a router. The flow must find
        and execute it when the upstream router emits the matching outcome.
        """
        execution_order: list[str] = []

        class ChainedRouterFlow(Flow):
            @start()
            @human_feedback(
                message="First review:",
                emit=["approved", "rejected"],
                llm="gpt-4o-mini",
            )
            def draft(self):
                execution_order.append("draft")
                return "draft content"

            @listen("approved")
            @human_feedback(
                message="Final review:",
                emit=["publish", "revise"],
                llm="gpt-4o-mini",
            )
            def final_review(self, prev: HumanFeedbackResult):
                execution_order.append("final_review")
                return "final content"

            @listen("rejected")
            def on_rejected(self, prev: HumanFeedbackResult):
                execution_order.append("on_rejected")
                return "rejected"

            @listen("publish")
            def on_publish(self, prev: HumanFeedbackResult):
                execution_order.append("on_publish")
                return "published"

            @listen("revise")
            def on_revise(self, prev: HumanFeedbackResult):
                execution_order.append("on_revise")
                return "revised"

        flow = ChainedRouterFlow()

        with (
            patch.object(
                flow,
                "_request_human_feedback",
                side_effect=["looks good", "ship it"],
            ),
            patch.object(
                flow,
                "_collapse_to_outcome",
                side_effect=["approved", "publish"],
            ),
        ):
            result = flow.kickoff()

        assert execution_order == ["draft", "final_review", "on_publish"]
        assert result == "published"
        assert len(flow.human_feedback_history) == 2
        assert flow.human_feedback_history[0].outcome == "approved"
        assert flow.human_feedback_history[1].outcome == "publish"

    def test_chained_router_rejected_path(self):
        """Test that a start-router outcome routes to a non-router listener."""
        execution_order: list[str] = []

        class ChainedRouterFlow(Flow):
            @start()
            @human_feedback(
                message="Review:",
                emit=["approved", "rejected"],
                llm="gpt-4o-mini",
            )
            def draft(self):
                execution_order.append("draft")
                return "draft"

            @listen("approved")
            @human_feedback(
                message="Final:",
                emit=["publish", "revise"],
                llm="gpt-4o-mini",
            )
            def final_review(self, prev: HumanFeedbackResult):
                execution_order.append("final_review")
                return "final"

            @listen("rejected")
            def on_rejected(self, prev: HumanFeedbackResult):
                execution_order.append("on_rejected")
                return "rejected"

        flow = ChainedRouterFlow()

        with (
            patch.object(
                flow, "_request_human_feedback", return_value="bad"
            ),
            patch.object(
                flow, "_collapse_to_outcome", return_value="rejected"
            ),
        ):
            result = flow.kickoff()

        assert execution_order == ["draft", "on_rejected"]
        assert result == "rejected"
        assert len(flow.human_feedback_history) == 1
        assert flow.human_feedback_history[0].outcome == "rejected"

    def test_router_and_non_router_listeners_for_same_outcome(self):
        """Test that both router and non-router listeners fire for the same outcome."""
        execution_order: list[str] = []

        class MixedListenerFlow(Flow):
            @start()
            @human_feedback(
                message="Review:",
                emit=["approved", "rejected"],
                llm="gpt-4o-mini",
            )
            def draft(self):
                execution_order.append("draft")
                return "draft"

            @listen("approved")
            @human_feedback(
                message="Final:",
                emit=["publish", "revise"],
                llm="gpt-4o-mini",
            )
            def router_listener(self, prev: HumanFeedbackResult):
                execution_order.append("router_listener")
                return "final"

            @listen("approved")
            def plain_listener(self, prev: HumanFeedbackResult):
                execution_order.append("plain_listener")
                return "logged"

            @listen("publish")
            def on_publish(self, prev: HumanFeedbackResult):
                execution_order.append("on_publish")
                return "published"

        flow = MixedListenerFlow()

        with (
            patch.object(
                flow,
                "_request_human_feedback",
                side_effect=["approve it", "publish it"],
            ),
            patch.object(
                flow,
                "_collapse_to_outcome",
                side_effect=["approved", "publish"],
            ),
        ):
            flow.kickoff()

        assert "draft" in execution_order
        assert "router_listener" in execution_order
        assert "plain_listener" in execution_order
        assert "on_publish" in execution_order


class TestStateManagement:
    """Tests for state management with human feedback."""