From 6bfc98e960fd6bd78978d9131005d2eb2b5cb7e5 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Wed, 4 Feb 2026 15:40:22 -0500 Subject: [PATCH] refactor: extract hitl to provider pattern * refactor: extract hitl to provider pattern - add humaninputprovider protocol with setup_messages and handle_feedback - move sync hitl logic from executor to synchuman inputprovider - add _passthrough_exceptions extension point in agent/core.py - create crewai.core.providers module for extensible components - remove _ask_human_input from base_agent_executor_mixin --- lib/crewai/src/crewai/agent/core.py | 6 + .../base_agent_executor_mixin.py | 50 --- .../src/crewai/agents/crew_agent_executor.py | 121 ++----- lib/crewai/src/crewai/core/__init__.py | 1 + .../src/crewai/core/providers/__init__.py | 1 + .../core/providers/content_processor.py | 78 +++++ .../src/crewai/core/providers/human_input.py | 304 ++++++++++++++++++ lib/crewai/src/crewai/task.py | 20 +- lib/crewai/tests/agents/test_agent.py | 18 +- .../test_flow_human_input_integration.py | 57 ++-- 10 files changed, 465 insertions(+), 191 deletions(-) create mode 100644 lib/crewai/src/crewai/core/__init__.py create mode 100644 lib/crewai/src/crewai/core/providers/__init__.py create mode 100644 lib/crewai/src/crewai/core/providers/content_processor.py create mode 100644 lib/crewai/src/crewai/core/providers/human_input.py diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 6c2626a28..47eb841b4 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -118,6 +118,8 @@ MCP_TOOL_EXECUTION_TIMEOUT: Final[int] = 30 MCP_DISCOVERY_TIMEOUT: Final[int] = 15 MCP_MAX_RETRIES: Final[int] = 3 +_passthrough_exceptions: tuple[type[Exception], ...] = () + # Simple in-memory cache for MCP tool schemas (duration: 5 minutes) _mcp_schema_cache: dict[str, Any] = {} _cache_ttl: Final[int] = 300 # 5 minutes @@ -479,6 +481,8 @@ class Agent(BaseAgent): ), ) raise e + if isinstance(e, _passthrough_exceptions): + raise self._times_executed += 1 if self._times_executed > self.max_retry_limit: crewai_event_bus.emit( @@ -711,6 +715,8 @@ class Agent(BaseAgent): ), ) raise e + if isinstance(e, _passthrough_exceptions): + raise self._times_executed += 1 if self._times_executed > self.max_retry_limit: crewai_event_bus.emit( diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor_mixin.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor_mixin.py index c9dceaa84..03787c802 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor_mixin.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor_mixin.py @@ -4,7 +4,6 @@ import time from typing import TYPE_CHECKING from crewai.agents.parser import AgentFinish -from crewai.events.event_listener import event_listener from crewai.memory.entity.entity_memory_item import EntityMemoryItem from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem from crewai.utilities.converter import ConverterError @@ -138,52 +137,3 @@ class CrewAgentExecutorMixin: content="Long term memory is enabled, but entity memory is not enabled. Please configure entity memory or set memory=True to automatically enable it.", color="bold_yellow", ) - - def _ask_human_input(self, final_answer: str) -> str: - """Prompt human input with mode-appropriate messaging. - - Note: The final answer is already displayed via the AgentLogsExecutionEvent - panel, so we only show the feedback prompt here. - """ - from rich.panel import Panel - from rich.text import Text - - formatter = event_listener.formatter - formatter.pause_live_updates() - - try: - # Training mode prompt (single iteration) - if self.crew and getattr(self.crew, "_train", False): - prompt_text = ( - "TRAINING MODE: Provide feedback to improve the agent's performance.\n\n" - "This will be used to train better versions of the agent.\n" - "Please provide detailed feedback about the result quality and reasoning process." - ) - title = "🎓 Training Feedback Required" - # Regular human-in-the-loop prompt (multiple iterations) - else: - prompt_text = ( - "Provide feedback on the Final Result above.\n\n" - "• If you are happy with the result, simply hit Enter without typing anything.\n" - "• Otherwise, provide specific improvement requests.\n" - "• You can provide multiple rounds of feedback until satisfied." - ) - title = "💬 Human Feedback Required" - - content = Text() - content.append(prompt_text, style="yellow") - - prompt_panel = Panel( - content, - title=title, - border_style="yellow", - padding=(1, 2), - ) - formatter.console.print(prompt_panel) - - response = input() - if response.strip() != "": - formatter.console.print("\n[cyan]Processing your feedback...[/cyan]") - return response - finally: - formatter.resume_live_updates() diff --git a/lib/crewai/src/crewai/agents/crew_agent_executor.py b/lib/crewai/src/crewai/agents/crew_agent_executor.py index 1218ceae8..c0f24516a 100644 --- a/lib/crewai/src/crewai/agents/crew_agent_executor.py +++ b/lib/crewai/src/crewai/agents/crew_agent_executor.py @@ -19,6 +19,7 @@ from crewai.agents.parser import ( AgentFinish, OutputParserError, ) +from crewai.core.providers.human_input import ExecutorContext, get_provider from crewai.events.event_bus import crewai_event_bus from crewai.events.types.logging_events import ( AgentLogsExecutionEvent, @@ -175,15 +176,16 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): """ return self.llm.supports_stop_words() if self.llm else False - def invoke(self, inputs: dict[str, Any]) -> dict[str, Any]: - """Execute the agent with given inputs. + def _setup_messages(self, inputs: dict[str, Any]) -> None: + """Set up messages for the agent execution. Args: inputs: Input dictionary containing prompt variables. - - Returns: - Dictionary with agent output. """ + provider = get_provider() + if provider.setup_messages(cast(ExecutorContext, cast(object, self))): + return + if "system" in self.prompt: system_prompt = self._format_prompt( cast(str, self.prompt.get("system", "")), inputs @@ -197,6 +199,19 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs) self.messages.append(format_message_for_llm(user_prompt)) + provider.post_setup_messages(cast(ExecutorContext, cast(object, self))) + + def invoke(self, inputs: dict[str, Any]) -> dict[str, Any]: + """Execute the agent with given inputs. + + Args: + inputs: Input dictionary containing prompt variables. + + Returns: + Dictionary with agent output. + """ + self._setup_messages(inputs) + self._inject_multimodal_files(inputs) self._show_start_logs() @@ -970,18 +985,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): Returns: Dictionary with agent output. """ - if "system" in self.prompt: - system_prompt = self._format_prompt( - cast(str, self.prompt.get("system", "")), inputs - ) - user_prompt = self._format_prompt( - cast(str, self.prompt.get("user", "")), inputs - ) - self.messages.append(format_message_for_llm(system_prompt, role="system")) - self.messages.append(format_message_for_llm(user_prompt)) - else: - user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs) - self.messages.append(format_message_for_llm(user_prompt)) + self._setup_messages(inputs) await self._ainject_multimodal_files(inputs) @@ -1491,7 +1495,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): return prompt.replace("{tools}", inputs["tools"]) def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish: - """Process human feedback. + """Process human feedback via the configured provider. Args: formatted_answer: Initial agent result. @@ -1499,17 +1503,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): Returns: Final answer after feedback. """ - output_str = ( - formatted_answer.output - if isinstance(formatted_answer.output, str) - else formatted_answer.output.model_dump_json() - ) - human_feedback = self._ask_human_input(output_str) - - if self._is_training_mode(): - return self._handle_training_feedback(formatted_answer, human_feedback) - - return self._handle_regular_feedback(formatted_answer, human_feedback) + provider = get_provider() + return provider.handle_feedback(formatted_answer, self) def _is_training_mode(self) -> bool: """Check if training mode is active. @@ -1519,74 +1514,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): """ return bool(self.crew and self.crew._train) - def _handle_training_feedback( - self, initial_answer: AgentFinish, feedback: str - ) -> AgentFinish: - """Process training feedback. + def _format_feedback_message(self, feedback: str) -> LLMMessage: + """Format feedback as a message for the LLM. Args: - initial_answer: Initial agent output. - feedback: Training feedback. + feedback: User feedback string. Returns: - Improved answer. + Formatted message dict. """ - self._handle_crew_training_output(initial_answer, feedback) - self.messages.append( - format_message_for_llm( - self._i18n.slice("feedback_instructions").format(feedback=feedback) - ) + return format_message_for_llm( + self._i18n.slice("feedback_instructions").format(feedback=feedback) ) - improved_answer = self._invoke_loop() - self._handle_crew_training_output(improved_answer) - self.ask_for_human_input = False - return improved_answer - - def _handle_regular_feedback( - self, current_answer: AgentFinish, initial_feedback: str - ) -> AgentFinish: - """Process regular feedback iteratively. - - Args: - current_answer: Current agent output. - initial_feedback: Initial user feedback. - - Returns: - Final answer after iterations. - """ - feedback = initial_feedback - answer = current_answer - - while self.ask_for_human_input: - # If the user provides a blank response, assume they are happy with the result - if feedback.strip() == "": - self.ask_for_human_input = False - else: - answer = self._process_feedback_iteration(feedback) - output_str = ( - answer.output - if isinstance(answer.output, str) - else answer.output.model_dump_json() - ) - feedback = self._ask_human_input(output_str) - - return answer - - def _process_feedback_iteration(self, feedback: str) -> AgentFinish: - """Process single feedback iteration. - - Args: - feedback: User feedback. - - Returns: - Updated agent response. - """ - self.messages.append( - format_message_for_llm( - self._i18n.slice("feedback_instructions").format(feedback=feedback) - ) - ) - return self._invoke_loop() @classmethod def __get_pydantic_core_schema__( diff --git a/lib/crewai/src/crewai/core/__init__.py b/lib/crewai/src/crewai/core/__init__.py new file mode 100644 index 000000000..714ed9161 --- /dev/null +++ b/lib/crewai/src/crewai/core/__init__.py @@ -0,0 +1 @@ +"""Core crewAI components and interfaces.""" diff --git a/lib/crewai/src/crewai/core/providers/__init__.py b/lib/crewai/src/crewai/core/providers/__init__.py new file mode 100644 index 000000000..fc0a9ed4b --- /dev/null +++ b/lib/crewai/src/crewai/core/providers/__init__.py @@ -0,0 +1 @@ +"""Provider interfaces for extensible crewAI components.""" diff --git a/lib/crewai/src/crewai/core/providers/content_processor.py b/lib/crewai/src/crewai/core/providers/content_processor.py new file mode 100644 index 000000000..828a7e311 --- /dev/null +++ b/lib/crewai/src/crewai/core/providers/content_processor.py @@ -0,0 +1,78 @@ +"""Content processor provider for extensible content processing.""" + +from contextvars import ContextVar +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class ContentProcessorProvider(Protocol): + """Protocol for content processing during task execution.""" + + def process(self, content: str, context: dict[str, Any] | None = None) -> str: + """Process content before use. + + Args: + content: The content to process. + context: Optional context information. + + Returns: + The processed content. + """ + ... + + +class NoOpContentProcessor: + """Default processor that returns content unchanged.""" + + def process(self, content: str, context: dict[str, Any] | None = None) -> str: + """Return content unchanged. + + Args: + content: The content to process. + context: Optional context information (unused). + + Returns: + The original content unchanged. + """ + return content + + +_content_processor: ContextVar[ContentProcessorProvider | None] = ContextVar( + "_content_processor", default=None +) + +_default_processor = NoOpContentProcessor() + + +def get_processor() -> ContentProcessorProvider: + """Get the current content processor. + + Returns: + The registered content processor or the default no-op processor. + """ + processor = _content_processor.get() + if processor is not None: + return processor + return _default_processor + + +def set_processor(processor: ContentProcessorProvider) -> None: + """Set the content processor for the current context. + + Args: + processor: The content processor to use. + """ + _content_processor.set(processor) + + +def process_content(content: str, context: dict[str, Any] | None = None) -> str: + """Process content using the registered processor. + + Args: + content: The content to process. + context: Optional context information. + + Returns: + The processed content. + """ + return get_processor().process(content, context) diff --git a/lib/crewai/src/crewai/core/providers/human_input.py b/lib/crewai/src/crewai/core/providers/human_input.py new file mode 100644 index 000000000..4062e6bb9 --- /dev/null +++ b/lib/crewai/src/crewai/core/providers/human_input.py @@ -0,0 +1,304 @@ +"""Human input provider for HITL (Human-in-the-Loop) flows.""" + +from __future__ import annotations + +from contextvars import ContextVar, Token +from typing import TYPE_CHECKING, Protocol, runtime_checkable + + +if TYPE_CHECKING: + from crewai.agent.core import Agent + from crewai.agents.parser import AgentFinish + from crewai.crew import Crew + from crewai.llms.base_llm import BaseLLM + from crewai.task import Task + from crewai.utilities.types import LLMMessage + + +class ExecutorContext(Protocol): + """Context interface for human input providers to interact with executor.""" + + task: Task | None + crew: Crew | None + messages: list[LLMMessage] + ask_for_human_input: bool + llm: BaseLLM + agent: Agent + + def _invoke_loop(self) -> AgentFinish: + """Invoke the agent loop and return the result.""" + ... + + def _is_training_mode(self) -> bool: + """Check if training mode is active.""" + ... + + def _handle_crew_training_output( + self, + result: AgentFinish, + human_feedback: str | None = None, + ) -> None: + """Handle training output.""" + ... + + def _format_feedback_message(self, feedback: str) -> LLMMessage: + """Format feedback as a message.""" + ... + + +@runtime_checkable +class HumanInputProvider(Protocol): + """Protocol for human input handling. + + Implementations handle the full feedback flow: + - Sync: prompt user, loop until satisfied + - Async: raise exception for external handling + """ + + def setup_messages(self, context: ExecutorContext) -> bool: + """Set up messages for execution. + + Called before standard message setup. Allows providers to handle + conversation resumption or other custom message initialization. + + Args: + context: Executor context with messages list to modify. + + Returns: + True if messages were set up (skip standard setup), + False to use standard setup. + """ + ... + + def post_setup_messages(self, context: ExecutorContext) -> None: + """Called after standard message setup. + + Allows providers to modify messages after standard setup completes. + Only called when setup_messages returned False. + + Args: + context: Executor context with messages list to modify. + """ + ... + + def handle_feedback( + self, + formatted_answer: AgentFinish, + context: ExecutorContext, + ) -> AgentFinish: + """Handle the full human feedback flow. + + Args: + formatted_answer: The agent's current answer. + context: Executor context for callbacks. + + Returns: + The final answer after feedback processing. + + Raises: + Exception: Async implementations may raise to signal external handling. + """ + ... + + @staticmethod + def _get_output_string(answer: AgentFinish) -> str: + """Extract output string from answer. + + Args: + answer: The agent's finished answer. + + Returns: + String representation of the output. + """ + if isinstance(answer.output, str): + return answer.output + return answer.output.model_dump_json() + + +class SyncHumanInputProvider(HumanInputProvider): + """Default synchronous human input via terminal.""" + + def setup_messages(self, context: ExecutorContext) -> bool: + """Use standard message setup. + + Args: + context: Executor context (unused). + + Returns: + False to use standard setup. + """ + return False + + def post_setup_messages(self, context: ExecutorContext) -> None: + """No-op for sync provider. + + Args: + context: Executor context (unused). + """ + + def handle_feedback( + self, + formatted_answer: AgentFinish, + context: ExecutorContext, + ) -> AgentFinish: + """Handle feedback synchronously with terminal prompts. + + Args: + formatted_answer: The agent's current answer. + context: Executor context for callbacks. + + Returns: + The final answer after feedback processing. + """ + feedback = self._prompt_input(context.crew) + + if context._is_training_mode(): + return self._handle_training_feedback(formatted_answer, feedback, context) + + return self._handle_regular_feedback(formatted_answer, feedback, context) + + @staticmethod + def _handle_training_feedback( + initial_answer: AgentFinish, + feedback: str, + context: ExecutorContext, + ) -> AgentFinish: + """Process training feedback (single iteration). + + Args: + initial_answer: The agent's initial answer. + feedback: Human feedback string. + context: Executor context for callbacks. + + Returns: + Improved answer after processing feedback. + """ + context._handle_crew_training_output(initial_answer, feedback) + context.messages.append(context._format_feedback_message(feedback)) + improved_answer = context._invoke_loop() + context._handle_crew_training_output(improved_answer) + context.ask_for_human_input = False + return improved_answer + + def _handle_regular_feedback( + self, + current_answer: AgentFinish, + initial_feedback: str, + context: ExecutorContext, + ) -> AgentFinish: + """Process regular feedback with iteration loop. + + Args: + current_answer: The agent's current answer. + initial_feedback: Initial human feedback string. + context: Executor context for callbacks. + + Returns: + Final answer after all feedback iterations. + """ + feedback = initial_feedback + answer = current_answer + + while context.ask_for_human_input: + if feedback.strip() == "": + context.ask_for_human_input = False + else: + context.messages.append(context._format_feedback_message(feedback)) + answer = context._invoke_loop() + feedback = self._prompt_input(context.crew) + + return answer + + @staticmethod + def _prompt_input(crew: Crew | None) -> str: + """Show rich panel and prompt for input. + + Args: + crew: The crew instance for context. + + Returns: + User input string from terminal. + """ + from rich.panel import Panel + from rich.text import Text + + from crewai.events.event_listener import event_listener + + formatter = event_listener.formatter + formatter.pause_live_updates() + + try: + if crew and getattr(crew, "_train", False): + prompt_text = ( + "TRAINING MODE: Provide feedback to improve the agent's performance.\n\n" + "This will be used to train better versions of the agent.\n" + "Please provide detailed feedback about the result quality and reasoning process." + ) + title = "🎓 Training Feedback Required" + else: + prompt_text = ( + "Provide feedback on the Final Result above.\n\n" + "• If you are happy with the result, simply hit Enter without typing anything.\n" + "• Otherwise, provide specific improvement requests.\n" + "• You can provide multiple rounds of feedback until satisfied." + ) + title = "💬 Human Feedback Required" + + content = Text() + content.append(prompt_text, style="yellow") + + prompt_panel = Panel( + content, + title=title, + border_style="yellow", + padding=(1, 2), + ) + formatter.console.print(prompt_panel) + + response = input() + if response.strip() != "": + formatter.console.print("\n[cyan]Processing your feedback...[/cyan]") + return response + finally: + formatter.resume_live_updates() + + +_provider: ContextVar[HumanInputProvider | None] = ContextVar( + "human_input_provider", + default=None, +) + + +def get_provider() -> HumanInputProvider: + """Get the current human input provider. + + Returns: + The current provider, or a new SyncHumanInputProvider if none set. + """ + provider = _provider.get() + if provider is None: + initialized_provider = SyncHumanInputProvider() + set_provider(initialized_provider) + return initialized_provider + return provider + + +def set_provider(provider: HumanInputProvider) -> Token[HumanInputProvider | None]: + """Set the human input provider for the current context. + + Args: + provider: The provider to use. + + Returns: + Token that can be used to reset to previous value. + """ + return _provider.set(provider) + + +def reset_provider(token: Token[HumanInputProvider | None]) -> None: + """Reset the provider to its previous value. + + Args: + token: Token returned from set_provider. + """ + _provider.reset(token) diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index 77056f0ca..d73c3d919 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -31,6 +31,7 @@ from pydantic_core import PydanticCustomError from typing_extensions import Self from crewai.agents.agent_builder.base_agent import BaseAgent +from crewai.core.providers.content_processor import process_content from crewai.events.event_bus import crewai_event_bus from crewai.events.types.task_events import ( TaskCompletedEvent, @@ -496,6 +497,7 @@ class Task(BaseModel): tools: list[BaseTool] | None = None, ) -> TaskOutput: """Execute the task synchronously.""" + self.start_time = datetime.datetime.now() return self._execute_core(agent, context, tools) @property @@ -536,6 +538,7 @@ class Task(BaseModel): ) -> None: """Execute the task asynchronously with context handling.""" try: + self.start_time = datetime.datetime.now() result = self._execute_core(agent, context, tools) future.set_result(result) except Exception as e: @@ -548,6 +551,7 @@ class Task(BaseModel): tools: list[BaseTool] | None = None, ) -> TaskOutput: """Execute the task asynchronously using native async/await.""" + self.start_time = datetime.datetime.now() return await self._aexecute_core(agent, context, tools) async def _aexecute_core( @@ -566,8 +570,6 @@ class Task(BaseModel): f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical." ) - self.start_time = datetime.datetime.now() - self.prompt_context = context tools = tools or self.tools or [] @@ -579,6 +581,8 @@ class Task(BaseModel): tools=tools, ) + self._post_agent_execution(agent) + if not self._guardrails and not self._guardrail: pydantic_output, json_output = self._export_output(result) else: @@ -661,8 +665,6 @@ class Task(BaseModel): f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical." ) - self.start_time = datetime.datetime.now() - self.prompt_context = context tools = tools or self.tools or [] @@ -674,6 +676,8 @@ class Task(BaseModel): tools=tools, ) + self._post_agent_execution(agent) + if not self._guardrails and not self._guardrail: pydantic_output, json_output = self._export_output(result) else: @@ -741,6 +745,9 @@ class Task(BaseModel): finally: clear_task_files(self.id) + def _post_agent_execution(self, agent: BaseAgent) -> None: + pass + def prompt(self) -> str: """Generates the task prompt with optional markdown formatting. @@ -863,6 +870,11 @@ Follow these guidelines: except ValueError as e: raise ValueError(f"Error interpolating description: {e!s}") from e + self.description = process_content(self.description, {"task": self}) + self._original_expected_output = process_content( + self._original_expected_output, {"task": self} + ) + try: self.expected_output = interpolate_only( input_string=self._original_expected_output, inputs=inputs diff --git a/lib/crewai/tests/agents/test_agent.py b/lib/crewai/tests/agents/test_agent.py index 32130f900..025bfd334 100644 --- a/lib/crewai/tests/agents/test_agent.py +++ b/lib/crewai/tests/agents/test_agent.py @@ -703,6 +703,8 @@ def test_agent_definition_based_on_dict(): # test for human input @pytest.mark.vcr() def test_agent_human_input(): + from crewai.core.providers.human_input import SyncHumanInputProvider + # Agent configuration config = { "role": "test role", @@ -720,7 +722,7 @@ def test_agent_human_input(): human_input=True, ) - # Side effect function for _ask_human_input to simulate multiple feedback iterations + # Side effect function for _prompt_input to simulate multiple feedback iterations feedback_responses = iter( [ "Don't say hi, say Hello instead!", # First feedback: instruct change @@ -728,16 +730,16 @@ def test_agent_human_input(): ] ) - def ask_human_input_side_effect(*args, **kwargs): + def prompt_input_side_effect(*args, **kwargs): return next(feedback_responses) - # Patch both _ask_human_input and _invoke_loop to avoid real API/network calls. + # Patch both _prompt_input on provider and _invoke_loop to avoid real API/network calls. with ( patch.object( - CrewAgentExecutor, - "_ask_human_input", - side_effect=ask_human_input_side_effect, - ) as mock_human_input, + SyncHumanInputProvider, + "_prompt_input", + side_effect=prompt_input_side_effect, + ) as mock_prompt_input, patch.object( CrewAgentExecutor, "_invoke_loop", @@ -749,7 +751,7 @@ def test_agent_human_input(): # Assertions to ensure the agent behaves correctly. # It should have requested feedback twice. - assert mock_human_input.call_count == 2 + assert mock_prompt_input.call_count == 2 # The final result should be processed to "Hello" assert output.strip().lower() == "hello" diff --git a/lib/crewai/tests/test_flow_human_input_integration.py b/lib/crewai/tests/test_flow_human_input_integration.py index e60cfe514..3ce4ebbd7 100644 --- a/lib/crewai/tests/test_flow_human_input_integration.py +++ b/lib/crewai/tests/test_flow_human_input_integration.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from crewai.events.event_listener import event_listener +from crewai.core.providers.human_input import SyncHumanInputProvider class TestFlowHumanInputIntegration: @@ -24,14 +25,9 @@ class TestFlowHumanInputIntegration: @patch("builtins.input", return_value="") def test_human_input_pauses_flow_updates(self, mock_input): """Test that human input pauses Flow status updates.""" - from crewai.agents.agent_builder.base_agent_executor_mixin import ( - CrewAgentExecutorMixin, - ) - - executor = CrewAgentExecutorMixin() - executor.crew = MagicMock() - executor.crew._train = False - executor._printer = MagicMock() + provider = SyncHumanInputProvider() + crew = MagicMock() + crew._train = False formatter = event_listener.formatter @@ -39,7 +35,7 @@ class TestFlowHumanInputIntegration: patch.object(formatter, "pause_live_updates") as mock_pause, patch.object(formatter, "resume_live_updates") as mock_resume, ): - result = executor._ask_human_input("Test result") + result = provider._prompt_input(crew) mock_pause.assert_called_once() mock_resume.assert_called_once() @@ -49,14 +45,9 @@ class TestFlowHumanInputIntegration: @patch("builtins.input", side_effect=["feedback", ""]) def test_multiple_human_input_rounds(self, mock_input): """Test multiple rounds of human input with Flow status management.""" - from crewai.agents.agent_builder.base_agent_executor_mixin import ( - CrewAgentExecutorMixin, - ) - - executor = CrewAgentExecutorMixin() - executor.crew = MagicMock() - executor.crew._train = False - executor._printer = MagicMock() + provider = SyncHumanInputProvider() + crew = MagicMock() + crew._train = False formatter = event_listener.formatter @@ -75,10 +66,10 @@ class TestFlowHumanInputIntegration: formatter, "resume_live_updates", side_effect=track_resume ), ): - result1 = executor._ask_human_input("Test result 1") + result1 = provider._prompt_input(crew) assert result1 == "feedback" - result2 = executor._ask_human_input("Test result 2") + result2 = provider._prompt_input(crew) assert result2 == "" assert len(pause_calls) == 2 @@ -103,14 +94,9 @@ class TestFlowHumanInputIntegration: def test_pause_resume_exception_handling(self): """Test that resume is called even if exception occurs during human input.""" - from crewai.agents.agent_builder.base_agent_executor_mixin import ( - CrewAgentExecutorMixin, - ) - - executor = CrewAgentExecutorMixin() - executor.crew = MagicMock() - executor.crew._train = False - executor._printer = MagicMock() + provider = SyncHumanInputProvider() + crew = MagicMock() + crew._train = False formatter = event_listener.formatter @@ -122,21 +108,16 @@ class TestFlowHumanInputIntegration: ), ): with pytest.raises(KeyboardInterrupt): - executor._ask_human_input("Test result") + provider._prompt_input(crew) mock_pause.assert_called_once() mock_resume.assert_called_once() def test_training_mode_human_input(self): """Test human input in training mode.""" - from crewai.agents.agent_builder.base_agent_executor_mixin import ( - CrewAgentExecutorMixin, - ) - - executor = CrewAgentExecutorMixin() - executor.crew = MagicMock() - executor.crew._train = True - executor._printer = MagicMock() + provider = SyncHumanInputProvider() + crew = MagicMock() + crew._train = True formatter = event_listener.formatter @@ -146,7 +127,7 @@ class TestFlowHumanInputIntegration: patch.object(formatter.console, "print") as mock_console_print, patch("builtins.input", return_value="training feedback"), ): - result = executor._ask_human_input("Test result") + result = provider._prompt_input(crew) mock_pause.assert_called_once() mock_resume.assert_called_once() @@ -161,4 +142,4 @@ class TestFlowHumanInputIntegration: for call in call_args if call[0] ) - assert training_panel_found + assert training_panel_found \ No newline at end of file