Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-06 06:38:29 +00:00)

Compare commits: fix/issue- ... devin/1739 (10 Commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 18a38ba436 |  |
|  | 369ee46ff3 |  |
|  | 39a290b4d3 |  |
|  | d2cc61028f |  |
|  | edcd55d19f |  |
|  | 097fac6c87 |  |
|  | ae4ca7748c |  |
|  | 8b58feb5e0 |  |
|  | a4856a9805 |  |
|  | 364a31ca8b |  |
@@ -771,65 +771,6 @@ class Crew(BaseModel):
        return self._create_crew_output(task_outputs)

    def _get_context_based_output(
        self,
        task: ConditionalTask,
        task_outputs: List[TaskOutput],
        task_index: int,
    ) -> Optional[TaskOutput]:
        """Get the output from explicit context tasks."""
        context_task_outputs = []
        for context_task in task.context:
            context_task_index = self._find_task_index(context_task)
            if context_task_index != -1 and context_task_index < task_index:
                for output in task_outputs:
                    if output.description == context_task.description:
                        context_task_outputs.append(output)
                        break
        return context_task_outputs[-1] if context_task_outputs else None

    def _get_non_conditional_output(
        self,
        task_outputs: List[TaskOutput],
        task_index: int,
    ) -> Optional[TaskOutput]:
        """Get the output from the most recent non-conditional task."""
        non_conditional_outputs = []
        for i in range(task_index):
            if i < len(self.tasks) and not isinstance(self.tasks[i], ConditionalTask):
                for output in task_outputs:
                    if output.description == self.tasks[i].description:
                        non_conditional_outputs.append(output)
                        break
        return non_conditional_outputs[-1] if non_conditional_outputs else None

    def _get_previous_output(
        self,
        task: ConditionalTask,
        task_outputs: List[TaskOutput],
        task_index: int,
    ) -> Optional[TaskOutput]:
        """Get the previous output for a conditional task.

        The order of precedence is:
        1. Output from explicit context tasks
        2. Output from the most recent non-conditional task
        3. Output from the immediately preceding task
        """
        if task.context and len(task.context) > 0:
            previous_output = self._get_context_based_output(task, task_outputs, task_index)
            if previous_output:
                return previous_output

        previous_output = self._get_non_conditional_output(task_outputs, task_index)
        if previous_output:
            return previous_output

        if task_outputs and task_index > 0 and task_index <= len(task_outputs):
            return task_outputs[task_index - 1]

        return None

    def _handle_conditional_task(
        self,
        task: ConditionalTask,
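The removed helpers encode a three-step precedence for picking which earlier output a conditional task's condition is evaluated against: explicit context tasks first, then the most recent non-conditional task, then whatever ran immediately before. A minimal standalone sketch of that rule, using plain dicts instead of TaskOutput and a hypothetical helper name rather than crewAI's API:

from typing import List, Optional


def pick_previous_output(
    outputs: List[dict],
    context_descriptions: List[str],
    non_conditional_descriptions: List[str],
) -> Optional[dict]:
    """Hypothetical restatement of the precedence in _get_previous_output."""
    # 1. Latest output belonging to an explicit context task.
    matches = [o for o in outputs if o["description"] in context_descriptions]
    if matches:
        return matches[-1]
    # 2. Otherwise, latest output of an earlier non-conditional task.
    matches = [o for o in outputs if o["description"] in non_conditional_descriptions]
    if matches:
        return matches[-1]
    # 3. Otherwise, the immediately preceding output, if any.
    return outputs[-1] if outputs else None


outputs = [
    {"description": "Task 1", "raw": "Task 1 output"},
    {"description": "Conditional Task 2", "raw": "Task 2 output"},
]
# A third, conditional task with no explicit context still sees Task 1's output,
# because Task 1 is the most recent non-conditional task.
print(pick_previous_output(outputs, [], ["Task 1"]))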
@@ -838,17 +779,11 @@ class Crew(BaseModel):
        task_index: int,
        was_replayed: bool,
    ) -> Optional[TaskOutput]:
        """Handle a conditional task.

        Determines whether a conditional task should be executed based on the output
        of previous tasks. If the task should not be executed, returns a skipped task output.
        """
        if futures:
            task_outputs = self._process_async_tasks(futures, was_replayed)
            futures.clear()

        previous_output = self._get_previous_output(task, task_outputs, task_index)

        previous_output = task_outputs[task_index - 1] if task_outputs else None
        if previous_output is not None and not task.should_execute(previous_output):
            self._logger.log(
                "debug",
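Judging by the hunk counts (17 lines on the old side against 11 on the new), the compare drops the _get_previous_output call in favor of indexing task_outputs directly, which is why both assignments appear above. A rough sketch of the behavioral difference for a crew of [task1, conditional2, conditional3], with hypothetical names and assuming both earlier tasks produced outputs:

task_outputs = ["task1 output", "conditional2 output"]  # outputs collected so far
task_index = 2  # conditional3 is the third task

# Helper-based lookup: conditional3's condition is checked against the most
# recent non-conditional output, i.e. "task1 output".
# Direct indexing: it is checked against whatever ran immediately before it.
print(task_outputs[task_index - 1])  # -> "conditional2 output"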
@@ -1 +1,2 @@
from .base_tool import BaseTool, tool
from .human_tool import HumanTool
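With HumanTool re-exported from the package root, it can be imported the same way the new tests do; a minimal check, assuming this branch is installed:

from crewai.tools import BaseTool, HumanTool, tool

print(HumanTool().name)  # -> "human"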
src/crewai/tools/human_tool.py (new file, 98 lines)
@@ -0,0 +1,98 @@
"""Tool for handling human input using LangGraph's interrupt mechanism."""

import logging
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from crewai.tools import BaseTool


class HumanToolSchema(BaseModel):
    """Schema for HumanTool input validation."""
    query: str = Field(
        ...,
        description="The question to ask the user. Must be a non-empty string."
    )
    timeout: Optional[float] = Field(
        default=None,
        description="Optional timeout in seconds for waiting for user response"
    )


class HumanTool(BaseTool):
    """Tool for getting human input using LangGraph's interrupt mechanism.

    This tool allows agents to request input from users through LangGraph's
    interrupt mechanism. It supports timeout configuration and input validation.
    """

    name: str = "human"
    description: str = "Useful to ask user to enter input."
    args_schema: type[BaseModel] = HumanToolSchema
    result_as_answer: bool = False  # Don't use the response as final answer

    def _run(self, query: str, timeout: Optional[float] = None) -> str:
        """Execute the human input tool.

        Args:
            query: The question to ask the user
            timeout: Optional timeout in seconds

        Returns:
            The user's response

        Raises:
            ImportError: If LangGraph is not installed
            TimeoutError: If response times out
            ValueError: If query is invalid
        """
        if not query or not isinstance(query, str):
            raise ValueError("Query must be a non-empty string")

        try:
            from langgraph.prebuilt.state_graphs import interrupt
            logging.info(f"Requesting human input: {query}")
            human_response = interrupt({"query": query, "timeout": timeout})
            return human_response["data"]
        except ImportError:
            logging.error("LangGraph not installed")
            raise ImportError(
                "LangGraph is required for HumanTool. "
                "Install with `pip install langgraph`"
            )
        except Exception as e:
            logging.error(f"Error during human input: {str(e)}")
            raise

    async def _arun(self, query: str, timeout: Optional[float] = None) -> str:
        """Execute the human input tool asynchronously.

        Args:
            query: The question to ask the user
            timeout: Optional timeout in seconds

        Returns:
            The user's response

        Raises:
            ImportError: If LangGraph is not installed
            TimeoutError: If response times out
            ValueError: If query is invalid
        """
        if not query or not isinstance(query, str):
            raise ValueError("Query must be a non-empty string")

        try:
            from langgraph.prebuilt.state_graphs import interrupt
            logging.info(f"Requesting async human input: {query}")
            human_response = interrupt({"query": query, "timeout": timeout})
            return human_response["data"]
        except ImportError:
            logging.error("LangGraph not installed")
            raise ImportError(
                "LangGraph is required for HumanTool. "
                "Install with `pip install langgraph`"
            )
        except Exception as e:
            logging.error(f"Error during async human input: {str(e)}")
            raise
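A sketch of how the new tool might be handed to an agent. The role, goal, and task text below are illustrative only, not part of this change, and LangGraph is assumed to be installed so the interrupt import resolves:

from crewai import Agent, Crew, Task
from crewai.tools import HumanTool

human = HumanTool()

support_agent = Agent(
    role="Support Agent",  # illustrative values, not from the diff
    goal="Answer the user's question, asking for clarification when details are missing",
    backstory="You help users and ask them directly when information is missing",
    tools=[human],
)

clarify_task = Task(
    description="Resolve the user's request, using the human tool if details are missing",
    expected_output="A short answer for the user",
    agent=support_agent,
)

crew = Crew(agents=[support_agent], tasks=[clarify_task])
# crew.kickoff()  # only meaningful inside a LangGraph run that can service the interrupt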
@@ -182,6 +182,10 @@ class ToolUsage:
                else:
                    result = tool.invoke(input={})
            except Exception as e:
                # Check if this is a LangGraph interrupt that should be propagated
                if hasattr(e, '__class__') and e.__class__.__name__ == 'Interrupt':
                    raise e  # Propagate interrupt up

                self.on_tool_error(tool=tool, tool_calling=calling, e=e)
                self._run_attempts += 1
                if self._run_attempts > self._max_parsing_attempts:
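The check keys off the exception's class name rather than an imported type, so any exception whose class is literally named Interrupt is re-raised instead of being swallowed by the retry path. A tiny standalone illustration of that name-based test:

class Interrupt(Exception):
    """Stands in for LangGraph's interrupt exception in this sketch."""


def is_propagated(exc: Exception) -> bool:
    # Same comparison as the hunk above: class name, not isinstance.
    return exc.__class__.__name__ == 'Interrupt'


print(is_propagated(Interrupt("pause here")))  # True  -> re-raised
print(is_propagated(ValueError("bad input")))  # False -> handled as a tool error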
@@ -1,335 +0,0 @@
"""Test for multiple conditional tasks."""

from unittest.mock import MagicMock, patch

import pytest

from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput


class TestMultipleConditionalTasks:
    """Test class for multiple conditional tasks scenarios."""

    @pytest.fixture
    def setup_agents(self):
        """Set up agents for the tests."""
        agent1 = Agent(
            role="Research Analyst",
            goal="Find information",
            backstory="You're a researcher",
            verbose=True,
        )

        agent2 = Agent(
            role="Data Analyst",
            goal="Process information",
            backstory="You process data",
            verbose=True,
        )

        agent3 = Agent(
            role="Report Writer",
            goal="Write reports",
            backstory="You write reports",
            verbose=True,
        )

        return agent1, agent2, agent3

    @pytest.fixture
    def setup_tasks(self, setup_agents):
        """Set up tasks for the tests."""
        agent1, agent2, agent3 = setup_agents

        # Create tasks
        task1 = Task(
            description="Task 1",
            expected_output="Output 1",
            agent=agent1,
        )

        # First conditional task should check task1's output
        condition1_mock = MagicMock()
        task2 = ConditionalTask(
            description="Conditional Task 2",
            expected_output="Output 2",
            agent=agent2,
            condition=condition1_mock,
        )

        # Second conditional task should check task1's output, not task2's
        condition2_mock = MagicMock()
        task3 = ConditionalTask(
            description="Conditional Task 3",
            expected_output="Output 3",
            agent=agent3,
            condition=condition2_mock,
        )

        return task1, task2, task3, condition1_mock, condition2_mock

    @pytest.fixture
    def setup_crew(self, setup_agents, setup_tasks):
        """Set up crew for the tests."""
        agent1, agent2, agent3 = setup_agents
        task1, task2, task3, _, _ = setup_tasks

        crew = Crew(
            agents=[agent1, agent2, agent3],
            tasks=[task1, task2, task3],
            verbose=True,
        )

        return crew

    @pytest.fixture
    def setup_task_outputs(self, setup_agents):
        """Set up task outputs for the tests."""
        agent1, agent2, _ = setup_agents

        task1_output = TaskOutput(
            description="Task 1",
            raw="Task 1 output",
            agent=agent1.role,
            output_format=OutputFormat.RAW,
        )

        task2_output = TaskOutput(
            description="Conditional Task 2",
            raw="Task 2 output",
            agent=agent2.role,
            output_format=OutputFormat.RAW,
        )

        return task1_output, task2_output

    def test_first_conditional_task_execution(self, setup_crew, setup_tasks, setup_task_outputs):
        """Test that the first conditional task is evaluated correctly."""
        crew = setup_crew
        _, task2, _, condition1_mock, _ = setup_tasks
        task1_output, _ = setup_task_outputs

        condition1_mock.return_value = True  # Task should execute
        result = crew._handle_conditional_task(
            task=task2,
            task_outputs=[task1_output],
            futures=[],
            task_index=1,
            was_replayed=False,
        )

        # Verify the condition was called with task1's output
        condition1_mock.assert_called_once()
        args = condition1_mock.call_args[0][0]
        assert args.raw == "Task 1 output"
        assert result is None  # Task should execute, so no skipped output

    def test_second_conditional_task_execution(self, setup_crew, setup_tasks, setup_task_outputs):
        """Test that the second conditional task is evaluated correctly."""
        crew = setup_crew
        _, _, task3, _, condition2_mock = setup_tasks
        task1_output, task2_output = setup_task_outputs

        condition2_mock.return_value = True  # Task should execute
        result = crew._handle_conditional_task(
            task=task3,
            task_outputs=[task1_output, task2_output],
            futures=[],
            task_index=2,
            was_replayed=False,
        )

        # Verify the condition was called with task1's output, not task2's
        condition2_mock.assert_called_once()
        args = condition2_mock.call_args[0][0]
        assert args.raw == "Task 1 output"  # Should be task1's output
        assert args.raw != "Task 2 output"  # Should not be task2's output
        assert result is None  # Task should execute, so no skipped output

    def test_conditional_task_skipping(self, setup_crew, setup_tasks, setup_task_outputs):
        """Test that conditional tasks are skipped when the condition returns False."""
        crew = setup_crew
        _, task2, _, condition1_mock, _ = setup_tasks
        task1_output, _ = setup_task_outputs

        condition1_mock.return_value = False  # Task should be skipped
        result = crew._handle_conditional_task(
            task=task2,
            task_outputs=[task1_output],
            futures=[],
            task_index=1,
            was_replayed=False,
        )

        # Verify the condition was called with task1's output
        condition1_mock.assert_called_once()
        args = condition1_mock.call_args[0][0]
        assert args.raw == "Task 1 output"
        assert result is not None  # Task should be skipped, so there should be a skipped output
        assert result.description == task2.description

    def test_conditional_task_with_explicit_context(self, setup_crew, setup_agents, setup_task_outputs):
        """Test conditional task with explicit context tasks."""
        crew = setup_crew
        agent1, agent2, _ = setup_agents
        task1_output, _ = setup_task_outputs

        with patch.object(crew, '_find_task_index', return_value=0):
            context_task = Task(
                description="Task 1",
                expected_output="Output 1",
                agent=agent1,
            )

            condition_mock = MagicMock(return_value=True)
            task_with_context = ConditionalTask(
                description="Task with Context",
                expected_output="Output with Context",
                agent=agent2,
                condition=condition_mock,
                context=[context_task],
            )

            crew.tasks.append(task_with_context)

            result = crew._handle_conditional_task(
                task=task_with_context,
                task_outputs=[task1_output],
                futures=[],
                task_index=3,  # This would be the 4th task
                was_replayed=False,
            )

            # Verify the condition was called with task1's output
            condition_mock.assert_called_once()
            args = condition_mock.call_args[0][0]
            assert args.raw == "Task 1 output"
            assert result is None  # Task should execute, so no skipped output

    def test_conditional_task_with_empty_task_outputs(self, setup_crew, setup_tasks):
        """Test conditional task with empty task outputs."""
        crew = setup_crew
        _, task2, _, condition1_mock, _ = setup_tasks

        result = crew._handle_conditional_task(
            task=task2,
            task_outputs=[],
            futures=[],
            task_index=1,
            was_replayed=False,
        )

        condition1_mock.assert_not_called()
        assert result is None  # Task should execute, so no skipped output


def test_multiple_conditional_tasks():
    """Test that multiple conditional tasks are evaluated correctly.

    This is a legacy test that's kept for backward compatibility.
    The actual tests are now in the TestMultipleConditionalTasks class.
    """
    agent1 = Agent(
        role="Research Analyst",
        goal="Find information",
        backstory="You're a researcher",
        verbose=True,
    )

    agent2 = Agent(
        role="Data Analyst",
        goal="Process information",
        backstory="You process data",
        verbose=True,
    )

    agent3 = Agent(
        role="Report Writer",
        goal="Write reports",
        backstory="You write reports",
        verbose=True,
    )

    # Create tasks
    task1 = Task(
        description="Task 1",
        expected_output="Output 1",
        agent=agent1,
    )

    # First conditional task should check task1's output
    condition1_mock = MagicMock()
    task2 = ConditionalTask(
        description="Conditional Task 2",
        expected_output="Output 2",
        agent=agent2,
        condition=condition1_mock,
    )

    # Second conditional task should check task1's output, not task2's
    condition2_mock = MagicMock()
    task3 = ConditionalTask(
        description="Conditional Task 3",
        expected_output="Output 3",
        agent=agent3,
        condition=condition2_mock,
    )

    crew = Crew(
        agents=[agent1, agent2, agent3],
        tasks=[task1, task2, task3],
        verbose=True,
    )

    with patch.object(crew, '_find_task_index', return_value=0):
        task1_output = TaskOutput(
            description="Task 1",
            raw="Task 1 output",
            agent=agent1.role,
            output_format=OutputFormat.RAW,
        )

        condition1_mock.return_value = True  # Task should execute
        result1 = crew._handle_conditional_task(
            task=task2,
            task_outputs=[task1_output],
            futures=[],
            task_index=1,
            was_replayed=False,
        )

        # Verify the condition was called with task1's output
        condition1_mock.assert_called_once()
        args1 = condition1_mock.call_args[0][0]
        assert args1.raw == "Task 1 output"
        assert result1 is None  # Task should execute, so no skipped output

        condition1_mock.reset_mock()

        task2_output = TaskOutput(
            description="Conditional Task 2",
            raw="Task 2 output",
            agent=agent2.role,
            output_format=OutputFormat.RAW,
        )

        condition2_mock.return_value = True  # Task should execute
        result2 = crew._handle_conditional_task(
            task=task3,
            task_outputs=[task1_output, task2_output],
            futures=[],
            task_index=2,
            was_replayed=False,
        )

        # Verify the condition was called with task1's output, not task2's
        condition2_mock.assert_called_once()
        args2 = condition2_mock.call_args[0][0]
        assert args2.raw == "Task 1 output"  # Should be task1's output
        assert args2.raw != "Task 2 output"  # Should not be task2's output
        assert result2 is None  # Task should execute, so no skipped output
tests/tools/test_human_tool.py (new file, 83 lines)
@@ -0,0 +1,83 @@
"""Test HumanTool functionality."""

from unittest.mock import patch
import pytest

from crewai.tools import HumanTool


def test_human_tool_basic():
    """Test basic HumanTool creation and attributes."""
    tool = HumanTool()
    assert tool.name == "human"
    assert "ask user to enter input" in tool.description.lower()
    assert not tool.result_as_answer


@pytest.mark.vcr(filter_headers=["authorization"])
def test_human_tool_with_langgraph_interrupt():
    """Test HumanTool with LangGraph interrupt handling."""
    tool = HumanTool()

    with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
        mock_interrupt.return_value = {"data": "test response"}
        result = tool._run("test query")
        assert result == "test response"
        mock_interrupt.assert_called_with({"query": "test query", "timeout": None})


def test_human_tool_timeout():
    """Test HumanTool timeout handling."""
    tool = HumanTool()
    timeout = 30.0

    with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
        mock_interrupt.return_value = {"data": "test response"}
        result = tool._run("test query", timeout=timeout)
        assert result == "test response"
        mock_interrupt.assert_called_with({"query": "test query", "timeout": timeout})


def test_human_tool_invalid_input():
    """Test HumanTool input validation."""
    tool = HumanTool()

    with pytest.raises(ValueError, match="Query must be a non-empty string"):
        tool._run("")

    with pytest.raises(ValueError, match="Query must be a non-empty string"):
        tool._run(None)


@pytest.mark.asyncio
async def test_human_tool_async():
    """Test async HumanTool functionality."""
    tool = HumanTool()

    with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
        mock_interrupt.return_value = {"data": "test response"}
        result = await tool._arun("test query")
        assert result == "test response"
        mock_interrupt.assert_called_with({"query": "test query", "timeout": None})


@pytest.mark.asyncio
async def test_human_tool_async_timeout():
    """Test async HumanTool timeout handling."""
    tool = HumanTool()
    timeout = 30.0

    with patch('langgraph.prebuilt.state_graphs.interrupt') as mock_interrupt:
        mock_interrupt.return_value = {"data": "test response"}
        result = await tool._arun("test query", timeout=timeout)
        assert result == "test response"
        mock_interrupt.assert_called_with({"query": "test query", "timeout": timeout})


def test_human_tool_without_langgraph():
    """Test HumanTool behavior when LangGraph is not installed."""
    tool = HumanTool()

    with patch.dict('sys.modules', {'langgraph': None}):
        with pytest.raises(ImportError) as exc_info:
            tool._run("test query")
        assert "LangGraph is required" in str(exc_info.value)
        assert "pip install langgraph" in str(exc_info.value)
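The last test above leans on a standard-library behavior: when a module's entry in sys.modules is set to None, any later import of it raises ImportError (as ModuleNotFoundError), which is what forces HumanTool down its error path. A standalone demonstration with a placeholder module name:

import sys
from unittest.mock import patch

with patch.dict(sys.modules, {"placeholder_pkg": None}):  # "placeholder_pkg" is made up
    try:
        import placeholder_pkg  # noqa: F401
    except ImportError as exc:
        print(f"import blocked as expected: {exc}")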
@@ -1,12 +1,13 @@
import json
import random
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from pydantic import BaseModel, Field

from crewai import Agent, Task
from crewai.tools import BaseTool
from crewai.tools.tool_calling import ToolCalling
from crewai.tools.tool_usage import ToolUsage
@@ -85,6 +86,36 @@ def test_random_number_tool_schema():
    )


def test_tool_usage_interrupt_handling():
    """Test that tool usage properly propagates LangGraph interrupts."""
    class InterruptingTool(BaseTool):
        name: str = "interrupt_test"
        description: str = "A tool that raises LangGraph interrupts"

        def _run(self, query: str) -> str:
            raise type('Interrupt', (Exception,), {})("test interrupt")

    tool = InterruptingTool()
    tool_usage = ToolUsage(
        tools_handler=MagicMock(),
        tools=[tool],
        original_tools=[tool],
        tools_description="Sample tool for testing",
        tools_names="interrupt_test",
        task=MagicMock(),
        function_calling_llm=MagicMock(),
        agent=MagicMock(),
        action=MagicMock(),
    )

    # Test that interrupt is propagated
    with pytest.raises(Exception) as exc_info:
        tool_usage.use(
            ToolCalling(tool_name="interrupt_test", arguments={"query": "test"}, log="test"),
            "test"
        )
    assert "test interrupt" in str(exc_info.value)


def test_tool_usage_render():
    tool = RandomNumberTool()