diff --git a/src/crewai/tools/__init__.py b/src/crewai/tools/__init__.py index 41819ccbc..51f5a51cd 100644 --- a/src/crewai/tools/__init__.py +++ b/src/crewai/tools/__init__.py @@ -1 +1,2 @@ from .base_tool import BaseTool, tool +from .human_tool import HumanTool diff --git a/src/crewai/tools/human_tool.py b/src/crewai/tools/human_tool.py new file mode 100644 index 000000000..8e4cb00f5 --- /dev/null +++ b/src/crewai/tools/human_tool.py @@ -0,0 +1,35 @@ +"""Tool for handling human input using LangGraph's interrupt mechanism.""" + +from typing import Any, Dict +from pydantic import Field + +from crewai.tools import BaseTool + +class HumanTool(BaseTool): + """Tool for getting human input using LangGraph's interrupt mechanism.""" + + name: str = "human" + description: str = "Useful to ask user to enter input." + result_as_answer: bool = False # Don't use the response as final answer + + def _run(self, query: str) -> str: + """Execute the human input tool. + + Args: + query: The question to ask the user + + Returns: + The user's response + + Raises: + ImportError: If LangGraph is not installed + """ + try: + from langgraph.prebuilt.state_graphs import interrupt + human_response = interrupt({"query": query}) + return human_response["data"] + except ImportError: + raise ImportError( + "LangGraph is required for HumanTool. " + "Install with `pip install langgraph`" + ) diff --git a/src/crewai/tools/tool_usage.py b/src/crewai/tools/tool_usage.py index 532587ced..f1d8900f8 100644 --- a/src/crewai/tools/tool_usage.py +++ b/src/crewai/tools/tool_usage.py @@ -182,6 +182,10 @@ class ToolUsage: else: result = tool.invoke(input={}) except Exception as e: + # Check if this is a LangGraph interrupt that should be propagated + if hasattr(e, '__class__') and e.__class__.__name__ == 'Interrupt': + raise e # Propagate interrupt up + self.on_tool_error(tool=tool, tool_calling=calling, e=e) self._run_attempts += 1 if self._run_attempts > self._max_parsing_attempts: diff --git a/tests/tools/test_human_tool.py b/tests/tools/test_human_tool.py new file mode 100644 index 000000000..dbdd7529e --- /dev/null +++ b/tests/tools/test_human_tool.py @@ -0,0 +1,42 @@ +"""Test HumanTool functionality.""" + +import pytest +from unittest.mock import patch + +from crewai.tools import HumanTool + +def test_human_tool_basic(): + """Test basic HumanTool creation and attributes.""" + tool = HumanTool() + assert tool.name == "human" + assert "ask user to enter input" in tool.description.lower() + assert not tool.result_as_answer + +@pytest.mark.vcr(filter_headers=["authorization"]) +def test_human_tool_with_langgraph_interrupt(): + """Test HumanTool with LangGraph interrupt handling.""" + tool = HumanTool() + + # Test successful interrupt handling + with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt: + mock_interrupt.return_value = {"data": "test response"} + result = tool._run("test query") + assert result == "test response" + mock_interrupt.assert_called_with({"query": "test query"}) + + # Test interrupt propagation + with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt: + mock_interrupt.side_effect = Exception("Interrupt") + with pytest.raises(Exception) as exc_info: + tool._run("test query") + assert "Interrupt" in str(exc_info.value) + +def test_human_tool_without_langgraph(): + """Test HumanTool behavior when LangGraph is not installed.""" + tool = HumanTool() + + with patch.dict('sys.modules', {'langgraph': None}): + with pytest.raises(ImportError) as exc_info: + tool._run("test query") + assert "LangGraph is required" in str(exc_info.value) + assert "pip install langgraph" in str(exc_info.value) diff --git a/tests/tools/test_tool_usage.py b/tests/tools/test_tool_usage.py index 05b9b23af..a0f364fac 100644 --- a/tests/tools/test_tool_usage.py +++ b/tests/tools/test_tool_usage.py @@ -85,6 +85,38 @@ def test_random_number_tool_schema(): ) +def test_tool_usage_interrupt_handling(): + """Test that tool usage properly propagates LangGraph interrupts.""" + from unittest.mock import patch, MagicMock + + class InterruptingTool(BaseTool): + name: str = "interrupt_test" + description: str = "A tool that raises LangGraph interrupts" + + def _run(self, query: str) -> str: + raise type('Interrupt', (Exception,), {})("test interrupt") + + tool = InterruptingTool() + tool_usage = ToolUsage( + tools_handler=MagicMock(), + tools=[tool], + original_tools=[tool], + tools_description="Sample tool for testing", + tools_names="interrupt_test", + task=MagicMock(), + function_calling_llm=MagicMock(), + agent=MagicMock(), + action=MagicMock(), + ) + + # Test that interrupt is propagated + with pytest.raises(Exception) as exc_info: + tool_usage.use( + ToolCalling(tool_name="interrupt_test", arguments={"query": "test"}, log="test"), + "test" + ) + assert "test interrupt" in str(exc_info.value) + def test_tool_usage_render(): tool = RandomNumberTool()