fix: Fix async context manager in MCP test stubs

- Change fake_streamablehttp_client from async def to regular def
- async with expects an object with __aenter__/__aexit__, not a coroutine
- Add __path__ to make mcp modules look like packages
- Add last_kwargs tracking for header assertions
- Add proper assertions to verify headers are passed/not passed
- This fixes TypeError: 'coroutine' object does not support async context manager protocol

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-10-26 10:11:39 +00:00
parent 9dfad32efe
commit 88c0950a6f

View File

@@ -19,6 +19,11 @@ def stub_mcp_modules(monkeypatch):
mcp_client = types.ModuleType("mcp.client") mcp_client = types.ModuleType("mcp.client")
mcp_streamable_http = types.ModuleType("mcp.client.streamable_http") mcp_streamable_http = types.ModuleType("mcp.client.streamable_http")
mcp.__path__ = []
mcp_client.__path__ = []
mcp.client = mcp_client
mcp_client.streamable_http = mcp_streamable_http
class MockClientSession: class MockClientSession:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.initialize = AsyncMock() self.initialize = AsyncMock()
@@ -28,27 +33,36 @@ def stub_mcp_modules(monkeypatch):
async def __aenter__(self): async def __aenter__(self):
return self return self
async def __aexit__(self, *exc): async def __aexit__(self, exc_type, exc, tb):
pass return False
mcp.ClientSession = MockClientSession mcp.ClientSession = MockClientSession
async def fake_streamablehttp_client(*args, **kwargs): last_kwargs = {}
"""Mock streamablehttp_client context manager."""
def fake_streamablehttp_client(*args, **kwargs):
"""Mock streamablehttp_client context manager (NOT async def)."""
last_kwargs.clear()
last_kwargs.update(kwargs)
class MockContextManager: class MockContextManager:
async def __aenter__(self): async def __aenter__(self):
return (AsyncMock(), AsyncMock(), AsyncMock()) return (AsyncMock(), AsyncMock(), AsyncMock())
async def __aexit__(self, *exc): async def __aexit__(self, exc_type, exc, tb):
pass return False
return MockContextManager() return MockContextManager()
fake_streamablehttp_client.last_kwargs = last_kwargs
mcp_streamable_http.streamablehttp_client = fake_streamablehttp_client mcp_streamable_http.streamablehttp_client = fake_streamablehttp_client
monkeypatch.setitem(sys.modules, "mcp", mcp) monkeypatch.setitem(sys.modules, "mcp", mcp)
monkeypatch.setitem(sys.modules, "mcp.client", mcp_client) monkeypatch.setitem(sys.modules, "mcp.client", mcp_client)
monkeypatch.setitem(sys.modules, "mcp.client.streamable_http", mcp_streamable_http) monkeypatch.setitem(sys.modules, "mcp.client.streamable_http", mcp_streamable_http)
return mcp_streamable_http
@pytest.fixture @pytest.fixture
@@ -265,7 +279,7 @@ class TestMCPToolWrapperHeaders:
assert wrapper.mcp_server_params["headers"] == headers assert wrapper.mcp_server_params["headers"] == headers
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_headers_passed_to_transport(self): async def test_headers_passed_to_transport(self, stub_mcp_modules):
"""Test that headers are passed to streamablehttp_client.""" """Test that headers are passed to streamablehttp_client."""
from mcp import ClientSession from mcp import ClientSession
@@ -294,11 +308,15 @@ class TestMCPToolWrapperHeaders:
try: try:
result = await wrapper._execute_tool(test_arg="test_value") result = await wrapper._execute_tool(test_arg="test_value")
assert result == "Test result" assert result == "Test result"
# Verify headers were passed to streamablehttp_client
assert "headers" in stub_mcp_modules.streamablehttp_client.last_kwargs
assert stub_mcp_modules.streamablehttp_client.last_kwargs["headers"] == headers
finally: finally:
ClientSession.__init__ = original_init ClientSession.__init__ = original_init
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_headers_when_not_configured(self): async def test_no_headers_when_not_configured(self, stub_mcp_modules):
"""Test that headers are not passed when not configured.""" """Test that headers are not passed when not configured."""
from mcp import ClientSession from mcp import ClientSession
@@ -322,6 +340,9 @@ class TestMCPToolWrapperHeaders:
try: try:
result = await wrapper._execute_tool(test_arg="test_value") result = await wrapper._execute_tool(test_arg="test_value")
assert result == "Test result" assert result == "Test result"
# Verify headers were NOT passed to streamablehttp_client
assert "headers" not in stub_mcp_modules.streamablehttp_client.last_kwargs
finally: finally:
ClientSession.__init__ = original_init ClientSession.__init__ = original_init