Files
crewAI/lib/crewai/tests/tools/test_mcp_tool_wrapper.py
Devin AI 88c0950a6f fix: Fix async context manager in MCP test stubs
- Change fake_streamablehttp_client from async def to regular def
- async with expects an object with __aenter__/__aexit__, not a coroutine
- Add __path__ to make mcp modules look like packages
- Add last_kwargs tracking for header assertions
- Add proper assertions to verify headers are passed/not passed
- This fixes TypeError: 'coroutine' object does not support async context manager protocol

Co-Authored-By: João <joao@crewai.com>
2025-10-26 10:11:39 +00:00

393 lines
13 KiB
Python

"""Tests for MCPToolWrapper progress and headers support."""
import asyncio
import sys
import types
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.tool_usage_events import MCPToolProgressEvent
from crewai.tools.mcp_tool_wrapper import MCPToolWrapper
@pytest.fixture(autouse=True)
def stub_mcp_modules(monkeypatch):
    """Install fake ``mcp`` packages into ``sys.modules`` for every test.

    The stubs keep these tests from failing on imports of the ``mcp`` SDK in
    CI. Three modules are faked: ``mcp``, ``mcp.client`` and
    ``mcp.client.streamable_http``. They are registered through
    ``monkeypatch.setitem``, so pytest undoes the registration at teardown.

    Returns:
        The fake ``mcp.client.streamable_http`` module. Its
        ``streamablehttp_client`` function has a ``last_kwargs`` dict that
        holds the keyword arguments of the most recent call. Tests use it to
        check, for example, whether ``headers`` were forwarded.
    """

    class MockClientSession:
        """Stand-in for ``mcp.ClientSession`` usable with ``async with``."""

        def __init__(self, *args, **kwargs):
            self.initialize = AsyncMock()
            self.call_tool = AsyncMock()
            self.on_progress = None

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

    recorded_kwargs = {}

    def fake_streamablehttp_client(*args, **kwargs):
        """Record ``kwargs`` and return an async context manager.

        This is deliberately a plain ``def``. Callers write
        ``async with streamablehttp_client(...)``, which needs an object that
        defines ``__aenter__``/``__aexit__``; a coroutine would not work.
        Entering the context yields a tuple of three ``AsyncMock`` objects.
        """
        recorded_kwargs.clear()
        recorded_kwargs.update(kwargs)

        class MockContextManager:
            async def __aenter__(self):
                return (AsyncMock(), AsyncMock(), AsyncMock())

            async def __aexit__(self, exc_type, exc, tb):
                return False

        return MockContextManager()

    fake_streamablehttp_client.last_kwargs = recorded_kwargs

    mcp_pkg = types.ModuleType("mcp")
    client_pkg = types.ModuleType("mcp.client")
    streamable_http = types.ModuleType("mcp.client.streamable_http")
    # A __path__ attribute makes the parent modules look like packages to the
    # import system, so submodule imports resolve against the stubs.
    mcp_pkg.__path__ = []
    client_pkg.__path__ = []
    mcp_pkg.client = client_pkg
    mcp_pkg.ClientSession = MockClientSession
    client_pkg.streamable_http = streamable_http
    streamable_http.streamablehttp_client = fake_streamablehttp_client

    for module in (mcp_pkg, client_pkg, streamable_http):
        monkeypatch.setitem(sys.modules, module.__name__, module)
    return streamable_http
@pytest.fixture
def mock_mcp_session():
    """Provide an ``AsyncMock`` that stands in for an MCP ``ClientSession``.

    ``initialize`` and ``call_tool`` are explicitly replaced with fresh
    ``AsyncMock`` instances, so tests can await them or set return values.
    """
    session = AsyncMock()
    for method_name in ("initialize", "call_tool"):
        setattr(session, method_name, AsyncMock())
    return session
@pytest.fixture
def mock_streamable_client(mock_mcp_session):
    """Provide a fake ``streamablehttp_client`` factory.

    The returned callable accepts any arguments and returns an async context
    manager. Entering it yields a ``(read, write, close)`` tuple of
    ``AsyncMock`` objects.

    ``mock_mcp_session`` is requested only to keep this fixture's dependency
    list unchanged; it is not used.
    """

    # Must be a plain ``def``: ``async with mock_client(...)`` needs an object
    # with ``__aenter__``/``__aexit__``. An ``async def`` would hand back a
    # coroutine, and entering it would raise
    # "TypeError: 'coroutine' object does not support the asynchronous
    # context manager protocol".
    def mock_client(*args, **kwargs):
        read = AsyncMock()
        write = AsyncMock()
        close = AsyncMock()

        class MockContextManager:
            async def __aenter__(self):
                return (read, write, close)

            async def __aexit__(self, *exc_info):
                # Do not suppress exceptions raised inside the ``async with``.
                return False

        return MockContextManager()

    return mock_client
@pytest.fixture
def mock_agent():
    """Provide a ``Mock`` agent exposing a fixed ``id`` and ``role``."""
    return Mock(id="test-agent-id", role="Test Agent")
@pytest.fixture
def mock_task():
    """Provide a ``Mock`` task with a fixed ``id``, a ``description`` and ``name=None``."""
    task = Mock(id="test-task-id", description="Test Task Description")
    # ``name`` cannot be passed to the Mock constructor (it names the mock
    # itself), so the attribute is assigned afterwards.
    task.name = None
    return task
class TestMCPToolWrapperProgress:
    """Test suite for MCP tool wrapper progress notifications."""

    def test_wrapper_initialization_with_progress_callback(self):
        """Test that MCPToolWrapper can be initialized with progress callback."""
        callback = Mock()
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://example.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server",
            progress_callback=callback,
        )
        assert wrapper._progress_callback is callback
        assert wrapper.name == "test_server_test_tool"

    def test_wrapper_initialization_without_progress_callback(self):
        """Test that MCPToolWrapper works without progress callback."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://example.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server",
        )
        assert wrapper._progress_callback is None

    def test_wrapper_initialization_with_agent_and_task(self, mock_agent, mock_task):
        """Test that MCPToolWrapper can be initialized with agent and task context."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://example.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server",
            agent=mock_agent,
            task=mock_task,
        )
        assert wrapper._agent is mock_agent
        assert wrapper._task is mock_task

    @pytest.mark.asyncio
    async def test_progress_handler_called_during_execution(
        self, mock_agent, mock_task, stub_mcp_modules, monkeypatch
    ):
        """Test that tool execution succeeds with a progress callback configured.

        NOTE(review): the stubbed ClientSession never sends progress
        notifications. This test therefore checks that wiring a callback does
        not break execution. It does not check that the callback is invoked.
        """
        from mcp import ClientSession

        progress_callback = Mock()
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://example.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server",
            progress_callback=progress_callback,
            agent=mock_agent,
            task=mock_task,
        )
        mock_result = Mock()
        mock_result.content = [Mock(text="Test result")]
        original_init = ClientSession.__init__

        def patched_init(session, *args, **kwargs):
            original_init(session, *args, **kwargs)
            session.call_tool = AsyncMock(return_value=mock_result)

        # monkeypatch restores the original __init__ at teardown.
        monkeypatch.setattr(ClientSession, "__init__", patched_init)

        result = await wrapper._execute_tool(test_arg="test_value")

        assert result == "Test result"

    @pytest.mark.asyncio
    async def test_progress_event_emission(self, mock_agent, mock_task):
        """Test that MCPToolProgressEvent is emitted when progress is reported."""
        events_received = []

        def event_handler(source, event):
            if isinstance(event, MCPToolProgressEvent):
                events_received.append(event)

        # Register inside scoped_handlers() instead of popping
        # _sync_handlers[MCPToolProgressEvent] afterwards. Popping would also
        # delete handlers registered elsewhere for this event type.
        # NOTE(review): scoped_handlers() is expected to restore the bus's
        # previous handlers on exit; confirm against crewai.events.event_bus.
        with crewai_event_bus.scoped_handlers():
            crewai_event_bus.register_handler(MCPToolProgressEvent, event_handler)
            wrapper = MCPToolWrapper(
                mcp_server_params={"url": "https://example.com/mcp"},
                tool_name="test_tool",
                tool_schema={"description": "Test tool"},
                server_name="test_server",
                progress_callback=Mock(),
                agent=mock_agent,
                task=mock_task,
            )
            wrapper._emit_progress_event(50.0, 100.0, "Processing...")
            # Short grace period in case the bus dispatches handlers off-thread.
            await asyncio.sleep(0.1)

        assert len(events_received) == 1
        event = events_received[0]
        assert event.tool_name == "test_tool"
        assert event.server_name == "test_server"
        assert event.progress == 50.0
        assert event.total == 100.0
        assert event.message == "Processing..."
        assert event.agent_id == "test-agent-id"
        assert event.agent_role == "Test Agent"
        assert event.task_id == "test-task-id"
        assert event.task_name == "Test Task Description"

    def test_progress_event_without_agent_context(self):
        """Test that emitting a progress event without agent context does not raise."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://example.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server",
            progress_callback=Mock(),
        )
        # Smoke test: passes as long as emission with total=None and no
        # agent/task does not raise.
        wrapper._emit_progress_event(25.0, None, "Starting...")

    def test_progress_event_without_task_context(self, mock_agent):
        """Test that emitting a progress event without task context does not raise."""
        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://example.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server",
            progress_callback=Mock(),
            agent=mock_agent,
        )
        # Smoke test: passes as long as emission with message=None and no task
        # does not raise.
        wrapper._emit_progress_event(75.0, 100.0, None)
class TestMCPToolWrapperHeaders:
    """Test suite for MCP tool wrapper headers support."""

    def test_wrapper_initialization_with_headers(self):
        """Test that MCPToolWrapper accepts headers in server params."""
        headers = {"Authorization": "Bearer token123", "X-Client-ID": "test-client"}
        wrapper = MCPToolWrapper(
            mcp_server_params={
                "url": "https://example.com/mcp",
                "headers": headers,
            },
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server",
        )
        assert wrapper.mcp_server_params["headers"] == headers

    @pytest.mark.asyncio
    async def test_headers_passed_to_transport(self, stub_mcp_modules, monkeypatch):
        """Test that headers are passed to streamablehttp_client."""
        from mcp import ClientSession

        headers = {"Authorization": "Bearer token123"}
        wrapper = MCPToolWrapper(
            mcp_server_params={
                "url": "https://example.com/mcp",
                "headers": headers,
            },
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server",
        )
        mock_result = Mock()
        mock_result.content = [Mock(text="Test result")]
        original_init = ClientSession.__init__

        def patched_init(session, *args, **kwargs):
            original_init(session, *args, **kwargs)
            session.call_tool = AsyncMock(return_value=mock_result)

        # monkeypatch restores the original __init__ at teardown.
        monkeypatch.setattr(ClientSession, "__init__", patched_init)

        result = await wrapper._execute_tool(test_arg="test_value")

        assert result == "Test result"
        # The stubbed transport records the kwargs of its most recent call.
        sent_kwargs = stub_mcp_modules.streamablehttp_client.last_kwargs
        assert "headers" in sent_kwargs
        assert sent_kwargs["headers"] == headers

    @pytest.mark.asyncio
    async def test_no_headers_when_not_configured(self, stub_mcp_modules, monkeypatch):
        """Test that headers are not passed when not configured."""
        from mcp import ClientSession

        wrapper = MCPToolWrapper(
            mcp_server_params={"url": "https://example.com/mcp"},
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server",
        )
        mock_result = Mock()
        mock_result.content = [Mock(text="Test result")]
        original_init = ClientSession.__init__

        def patched_init(session, *args, **kwargs):
            original_init(session, *args, **kwargs)
            session.call_tool = AsyncMock(return_value=mock_result)

        # monkeypatch restores the original __init__ at teardown.
        monkeypatch.setattr(ClientSession, "__init__", patched_init)

        result = await wrapper._execute_tool(test_arg="test_value")

        assert result == "Test result"
        # No "headers" kwarg at all, not merely an empty/None value.
        assert "headers" not in stub_mcp_modules.streamablehttp_client.last_kwargs
class TestMCPToolWrapperIntegration:
    """Integration tests for MCP tool wrapper with progress and headers."""

    @pytest.mark.asyncio
    async def test_full_execution_with_progress_and_headers(
        self, mock_agent, mock_task, stub_mcp_modules, monkeypatch
    ):
        """Test complete execution flow with both progress and headers.

        Checks two things when a progress callback and headers are both
        configured: execution returns the tool's text, and the headers reach
        the transport.
        """
        from mcp import ClientSession

        # NOTE(review): the stubbed session never sends progress
        # notifications, so nothing is asserted about progress_calls here.
        progress_calls = []

        def progress_callback(progress, total, message):
            progress_calls.append((progress, total, message))

        headers = {"Authorization": "Bearer test-token"}
        wrapper = MCPToolWrapper(
            mcp_server_params={
                "url": "https://example.com/mcp",
                "headers": headers,
            },
            tool_name="test_tool",
            tool_schema={"description": "Test tool"},
            server_name="test_server",
            progress_callback=progress_callback,
            agent=mock_agent,
            task=mock_task,
        )
        mock_result = Mock()
        mock_result.content = [Mock(text="Test result")]
        original_init = ClientSession.__init__

        def patched_init(session, *args, **kwargs):
            original_init(session, *args, **kwargs)
            session.call_tool = AsyncMock(return_value=mock_result)

        # monkeypatch restores the original __init__ at teardown.
        monkeypatch.setattr(ClientSession, "__init__", patched_init)

        result = await wrapper._execute_tool(test_arg="test_value")

        assert result == "Test result"
        # Headers must still be forwarded when a progress callback is also set.
        sent_kwargs = stub_mcp_modules.streamablehttp_client.last_kwargs
        assert sent_kwargs.get("headers") == headers