mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
fix: handle LangGraph interrupts in human tool
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -1 +1,2 @@
|
|||||||
from .base_tool import BaseTool, tool
|
from .base_tool import BaseTool, tool
|
||||||
|
from .human_tool import HumanTool
|
||||||
|
|||||||
35
src/crewai/tools/human_tool.py
Normal file
35
src/crewai/tools/human_tool.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""Tool for handling human input using LangGraph's interrupt mechanism."""
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
|
class HumanTool(BaseTool):
|
||||||
|
"""Tool for getting human input using LangGraph's interrupt mechanism."""
|
||||||
|
|
||||||
|
name: str = "human"
|
||||||
|
description: str = "Useful to ask user to enter input."
|
||||||
|
result_as_answer: bool = False # Don't use the response as final answer
|
||||||
|
|
||||||
|
def _run(self, query: str) -> str:
|
||||||
|
"""Execute the human input tool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The question to ask the user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The user's response
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If LangGraph is not installed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from langgraph.prebuilt.state_graphs import interrupt
|
||||||
|
human_response = interrupt({"query": query})
|
||||||
|
return human_response["data"]
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"LangGraph is required for HumanTool. "
|
||||||
|
"Install with `pip install langgraph`"
|
||||||
|
)
|
||||||
@@ -182,6 +182,10 @@ class ToolUsage:
|
|||||||
else:
|
else:
|
||||||
result = tool.invoke(input={})
|
result = tool.invoke(input={})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# Check if this is a LangGraph interrupt that should be propagated
|
||||||
|
if hasattr(e, '__class__') and e.__class__.__name__ == 'Interrupt':
|
||||||
|
raise e # Propagate interrupt up
|
||||||
|
|
||||||
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
|
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
|
||||||
self._run_attempts += 1
|
self._run_attempts += 1
|
||||||
if self._run_attempts > self._max_parsing_attempts:
|
if self._run_attempts > self._max_parsing_attempts:
|
||||||
|
|||||||
42
tests/tools/test_human_tool.py
Normal file
42
tests/tools/test_human_tool.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""Test HumanTool functionality."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from crewai.tools import HumanTool
|
||||||
|
|
||||||
|
def test_human_tool_basic():
|
||||||
|
"""Test basic HumanTool creation and attributes."""
|
||||||
|
tool = HumanTool()
|
||||||
|
assert tool.name == "human"
|
||||||
|
assert "ask user to enter input" in tool.description.lower()
|
||||||
|
assert not tool.result_as_answer
|
||||||
|
|
||||||
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
|
def test_human_tool_with_langgraph_interrupt():
|
||||||
|
"""Test HumanTool with LangGraph interrupt handling."""
|
||||||
|
tool = HumanTool()
|
||||||
|
|
||||||
|
# Test successful interrupt handling
|
||||||
|
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
|
||||||
|
mock_interrupt.return_value = {"data": "test response"}
|
||||||
|
result = tool._run("test query")
|
||||||
|
assert result == "test response"
|
||||||
|
mock_interrupt.assert_called_with({"query": "test query"})
|
||||||
|
|
||||||
|
# Test interrupt propagation
|
||||||
|
with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
|
||||||
|
mock_interrupt.side_effect = Exception("Interrupt")
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
tool._run("test query")
|
||||||
|
assert "Interrupt" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_human_tool_without_langgraph():
|
||||||
|
"""Test HumanTool behavior when LangGraph is not installed."""
|
||||||
|
tool = HumanTool()
|
||||||
|
|
||||||
|
with patch.dict('sys.modules', {'langgraph': None}):
|
||||||
|
with pytest.raises(ImportError) as exc_info:
|
||||||
|
tool._run("test query")
|
||||||
|
assert "LangGraph is required" in str(exc_info.value)
|
||||||
|
assert "pip install langgraph" in str(exc_info.value)
|
||||||
@@ -85,6 +85,38 @@ def test_random_number_tool_schema():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_usage_interrupt_handling():
|
||||||
|
"""Test that tool usage properly propagates LangGraph interrupts."""
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
class InterruptingTool(BaseTool):
|
||||||
|
name: str = "interrupt_test"
|
||||||
|
description: str = "A tool that raises LangGraph interrupts"
|
||||||
|
|
||||||
|
def _run(self, query: str) -> str:
|
||||||
|
raise type('Interrupt', (Exception,), {})("test interrupt")
|
||||||
|
|
||||||
|
tool = InterruptingTool()
|
||||||
|
tool_usage = ToolUsage(
|
||||||
|
tools_handler=MagicMock(),
|
||||||
|
tools=[tool],
|
||||||
|
original_tools=[tool],
|
||||||
|
tools_description="Sample tool for testing",
|
||||||
|
tools_names="interrupt_test",
|
||||||
|
task=MagicMock(),
|
||||||
|
function_calling_llm=MagicMock(),
|
||||||
|
agent=MagicMock(),
|
||||||
|
action=MagicMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test that interrupt is propagated
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
tool_usage.use(
|
||||||
|
ToolCalling(tool_name="interrupt_test", arguments={"query": "test"}, log="test"),
|
||||||
|
"test"
|
||||||
|
)
|
||||||
|
assert "test interrupt" in str(exc_info.value)
|
||||||
|
|
||||||
def test_tool_usage_render():
|
def test_tool_usage_render():
|
||||||
tool = RandomNumberTool()
|
tool = RandomNumberTool()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user