mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
feat: enhance HumanTool with validation, timeout, and async support
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -1,36 +1,99 @@
|
|||||||
"""Tool for handling human input using LangGraph's interrupt mechanism."""
|
"""Tool for handling human input using LangGraph's interrupt mechanism."""
|
||||||
|
|
||||||
import typing
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class HumanToolSchema(BaseModel):
|
||||||
|
"""Schema for HumanTool input validation."""
|
||||||
|
query: str = Field(
|
||||||
|
...,
|
||||||
|
description="The question to ask the user. Must be a non-empty string."
|
||||||
|
)
|
||||||
|
timeout: Optional[float] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional timeout in seconds for waiting for user response"
|
||||||
|
)
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
class HumanTool(BaseTool):
|
class HumanTool(BaseTool):
|
||||||
"""Tool for getting human input using LangGraph's interrupt mechanism."""
|
"""Tool for getting human input using LangGraph's interrupt mechanism.
|
||||||
|
|
||||||
|
This tool allows agents to request input from users through LangGraph's
|
||||||
|
interrupt mechanism. It supports timeout configuration and input validation.
|
||||||
|
"""
|
||||||
|
|
||||||
name: str = "human"
|
name: str = "human"
|
||||||
description: str = "Useful to ask user to enter input."
|
description: str = "Useful to ask user to enter input."
|
||||||
|
args_schema: type[BaseModel] = HumanToolSchema
|
||||||
result_as_answer: bool = False # Don't use the response as final answer
|
result_as_answer: bool = False # Don't use the response as final answer
|
||||||
|
|
||||||
def _run(self, query: str) -> str:
|
def _run(self, query: str, timeout: Optional[float] = None) -> str:
|
||||||
"""Execute the human input tool.
|
"""Execute the human input tool.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: The question to ask the user
|
query: The question to ask the user
|
||||||
|
timeout: Optional timeout in seconds
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The user's response
|
The user's response
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ImportError: If LangGraph is not installed
|
ImportError: If LangGraph is not installed
|
||||||
|
TimeoutError: If response times out
|
||||||
|
ValueError: If query is invalid
|
||||||
"""
|
"""
|
||||||
|
if not query or not isinstance(query, str):
|
||||||
|
raise ValueError("Query must be a non-empty string")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from langgraph.prebuilt.state_graphs import interrupt
|
from langgraph.prebuilt.state_graphs import interrupt
|
||||||
human_response = interrupt({"query": query})
|
logging.info(f"Requesting human input: {query}")
|
||||||
|
human_response = interrupt({"query": query, "timeout": timeout})
|
||||||
return human_response["data"]
|
return human_response["data"]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
logging.error("LangGraph not installed")
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"LangGraph is required for HumanTool. "
|
"LangGraph is required for HumanTool. "
|
||||||
"Install with `pip install langgraph`"
|
"Install with `pip install langgraph`"
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error during human input: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _arun(self, query: str, timeout: Optional[float] = None) -> str:
|
||||||
|
"""Execute the human input tool asynchronously.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The question to ask the user
|
||||||
|
timeout: Optional timeout in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The user's response
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If LangGraph is not installed
|
||||||
|
TimeoutError: If response times out
|
||||||
|
ValueError: If query is invalid
|
||||||
|
"""
|
||||||
|
if not query or not isinstance(query, str):
|
||||||
|
raise ValueError("Query must be a non-empty string")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langgraph.prebuilt.state_graphs import interrupt
|
||||||
|
logging.info(f"Requesting async human input: {query}")
|
||||||
|
human_response = interrupt({"query": query, "timeout": timeout})
|
||||||
|
return human_response["data"]
|
||||||
|
except ImportError:
|
||||||
|
logging.error("LangGraph not installed")
|
||||||
|
raise ImportError(
|
||||||
|
"LangGraph is required for HumanTool. "
|
||||||
|
"Install with `pip install langgraph`"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error during async human input: {str(e)}")
|
||||||
|
raise
|
||||||
|
|||||||
@@ -17,19 +17,60 @@ def test_human_tool_with_langgraph_interrupt():
|
|||||||
"""Test HumanTool with LangGraph interrupt handling."""
|
"""Test HumanTool with LangGraph interrupt handling."""
|
||||||
tool = HumanTool()
|
tool = HumanTool()
|
||||||
|
|
||||||
# Test successful interrupt handling
|
|
||||||
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
|
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
|
||||||
mock_interrupt.return_value = {"data": "test response"}
|
mock_interrupt.return_value = {"data": "test response"}
|
||||||
result = tool._run("test query")
|
result = tool._run("test query")
|
||||||
assert result == "test response"
|
assert result == "test response"
|
||||||
mock_interrupt.assert_called_with({"query": "test query"})
|
mock_interrupt.assert_called_with({"query": "test query", "timeout": None})
|
||||||
|
|
||||||
# Test interrupt propagation
|
|
||||||
|
def test_human_tool_timeout():
|
||||||
|
"""Test HumanTool timeout handling."""
|
||||||
|
tool = HumanTool()
|
||||||
|
timeout = 30.0
|
||||||
|
|
||||||
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
|
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
|
||||||
mock_interrupt.side_effect = Exception("Interrupt")
|
mock_interrupt.return_value = {"data": "test response"}
|
||||||
with pytest.raises(Exception) as exc_info:
|
result = tool._run("test query", timeout=timeout)
|
||||||
tool._run("test query")
|
assert result == "test response"
|
||||||
assert "Interrupt" in str(exc_info.value)
|
mock_interrupt.assert_called_with({"query": "test query", "timeout": timeout})
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_tool_invalid_input():
|
||||||
|
"""Test HumanTool input validation."""
|
||||||
|
tool = HumanTool()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Query must be a non-empty string"):
|
||||||
|
tool._run("")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Query must be a non-empty string"):
|
||||||
|
tool._run(None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_human_tool_async():
|
||||||
|
"""Test async HumanTool functionality."""
|
||||||
|
tool = HumanTool()
|
||||||
|
|
||||||
|
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
|
||||||
|
mock_interrupt.return_value = {"data": "test response"}
|
||||||
|
result = await tool._arun("test query")
|
||||||
|
assert result == "test response"
|
||||||
|
mock_interrupt.assert_called_with({"query": "test query", "timeout": None})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_human_tool_async_timeout():
|
||||||
|
"""Test async HumanTool timeout handling."""
|
||||||
|
tool = HumanTool()
|
||||||
|
timeout = 30.0
|
||||||
|
|
||||||
|
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
|
||||||
|
mock_interrupt.return_value = {"data": "test response"}
|
||||||
|
result = await tool._arun("test query", timeout=timeout)
|
||||||
|
assert result == "test response"
|
||||||
|
mock_interrupt.assert_called_with({"query": "test query", "timeout": timeout})
|
||||||
|
|
||||||
|
|
||||||
def test_human_tool_without_langgraph():
|
def test_human_tool_without_langgraph():
|
||||||
"""Test HumanTool behavior when LangGraph is not installed."""
|
"""Test HumanTool behavior when LangGraph is not installed."""
|
||||||
|
|||||||
Reference in New Issue
Block a user