feat: enhance HumanTool with validation, timeout, and async support

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-11 11:16:46 +00:00
parent ae4ca7748c
commit 097fac6c87
2 changed files with 116 additions and 12 deletions

View File

@@ -1,36 +1,99 @@
"""Tool for handling human input using LangGraph's interrupt mechanism."""
import typing
from typing import Any, Dict, Optional
from pydantic import Field
from pydantic import BaseModel, Field
from crewai.tools import BaseTool
class HumanToolSchema(BaseModel):
"""Schema for HumanTool input validation."""
query: str = Field(
...,
description="The question to ask the user. Must be a non-empty string."
)
timeout: Optional[float] = Field(
default=None,
description="Optional timeout in seconds for waiting for user response"
)
import logging
class HumanTool(BaseTool):
"""Tool for getting human input using LangGraph's interrupt mechanism."""
"""Tool for getting human input using LangGraph's interrupt mechanism.
This tool allows agents to request input from users through LangGraph's
interrupt mechanism. It supports timeout configuration and input validation.
"""
name: str = "human"
description: str = "Useful to ask user to enter input."
args_schema: type[BaseModel] = HumanToolSchema
result_as_answer: bool = False # Don't use the response as final answer
def _run(self, query: str) -> str:
def _run(self, query: str, timeout: Optional[float] = None) -> str:
"""Execute the human input tool.
Args:
query: The question to ask the user
timeout: Optional timeout in seconds
Returns:
The user's response
Raises:
ImportError: If LangGraph is not installed
TimeoutError: If response times out
ValueError: If query is invalid
"""
if not query or not isinstance(query, str):
raise ValueError("Query must be a non-empty string")
try:
from langgraph.prebuilt.state_graphs import interrupt
human_response = interrupt({"query": query})
logging.info(f"Requesting human input: {query}")
human_response = interrupt({"query": query, "timeout": timeout})
return human_response["data"]
except ImportError:
logging.error("LangGraph not installed")
raise ImportError(
"LangGraph is required for HumanTool. "
"Install with `pip install langgraph`"
)
except Exception as e:
logging.error(f"Error during human input: {str(e)}")
raise
async def _arun(self, query: str, timeout: Optional[float] = None) -> str:
"""Execute the human input tool asynchronously.
Args:
query: The question to ask the user
timeout: Optional timeout in seconds
Returns:
The user's response
Raises:
ImportError: If LangGraph is not installed
TimeoutError: If response times out
ValueError: If query is invalid
"""
if not query or not isinstance(query, str):
raise ValueError("Query must be a non-empty string")
try:
from langgraph.prebuilt.state_graphs import interrupt
logging.info(f"Requesting async human input: {query}")
human_response = interrupt({"query": query, "timeout": timeout})
return human_response["data"]
except ImportError:
logging.error("LangGraph not installed")
raise ImportError(
"LangGraph is required for HumanTool. "
"Install with `pip install langgraph`"
)
except Exception as e:
logging.error(f"Error during async human input: {str(e)}")
raise

View File

@@ -17,19 +17,60 @@ def test_human_tool_with_langgraph_interrupt():
"""Test HumanTool with LangGraph interrupt handling."""
tool = HumanTool()
# Test successful interrupt handling
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
mock_interrupt.return_value = {"data": "test response"}
result = tool._run("test query")
assert result == "test response"
mock_interrupt.assert_called_with({"query": "test query"})
mock_interrupt.assert_called_with({"query": "test query", "timeout": None})
# Test interrupt propagation
def test_human_tool_timeout():
"""Test HumanTool timeout handling."""
tool = HumanTool()
timeout = 30.0
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
mock_interrupt.side_effect = Exception("Interrupt")
with pytest.raises(Exception) as exc_info:
tool._run("test query")
assert "Interrupt" in str(exc_info.value)
mock_interrupt.return_value = {"data": "test response"}
result = tool._run("test query", timeout=timeout)
assert result == "test response"
mock_interrupt.assert_called_with({"query": "test query", "timeout": timeout})
def test_human_tool_invalid_input():
"""Test HumanTool input validation."""
tool = HumanTool()
with pytest.raises(ValueError, match="Query must be a non-empty string"):
tool._run("")
with pytest.raises(ValueError, match="Query must be a non-empty string"):
tool._run(None)
@pytest.mark.asyncio
async def test_human_tool_async():
"""Test async HumanTool functionality."""
tool = HumanTool()
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
mock_interrupt.return_value = {"data": "test response"}
result = await tool._arun("test query")
assert result == "test response"
mock_interrupt.assert_called_with({"query": "test query", "timeout": None})
@pytest.mark.asyncio
async def test_human_tool_async_timeout():
"""Test async HumanTool timeout handling."""
tool = HumanTool()
timeout = 30.0
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
mock_interrupt.return_value = {"data": "test response"}
result = await tool._arun("test query", timeout=timeout)
assert result == "test response"
mock_interrupt.assert_called_with({"query": "test query", "timeout": timeout})
def test_human_tool_without_langgraph():
"""Test HumanTool behavior when LangGraph is not installed."""