mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
fix: Fix async context manager in MCP test stubs
- Change fake_streamablehttp_client from async def to regular def - async with expects an object with __aenter__/__aexit__, not a coroutine - Add __path__ to make mcp modules look like packages - Add last_kwargs tracking for header assertions - Add proper assertions to verify headers are passed/not passed - This fixes TypeError: 'coroutine' object does not support async context manager protocol Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -19,6 +19,11 @@ def stub_mcp_modules(monkeypatch):
|
||||
mcp_client = types.ModuleType("mcp.client")
|
||||
mcp_streamable_http = types.ModuleType("mcp.client.streamable_http")
|
||||
|
||||
mcp.__path__ = []
|
||||
mcp_client.__path__ = []
|
||||
mcp.client = mcp_client
|
||||
mcp_client.streamable_http = mcp_streamable_http
|
||||
|
||||
class MockClientSession:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.initialize = AsyncMock()
|
||||
@@ -28,27 +33,36 @@ def stub_mcp_modules(monkeypatch):
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc):
|
||||
pass
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
mcp.ClientSession = MockClientSession
|
||||
|
||||
async def fake_streamablehttp_client(*args, **kwargs):
|
||||
"""Mock streamablehttp_client context manager."""
|
||||
last_kwargs = {}
|
||||
|
||||
def fake_streamablehttp_client(*args, **kwargs):
|
||||
"""Mock streamablehttp_client context manager (NOT async def)."""
|
||||
last_kwargs.clear()
|
||||
last_kwargs.update(kwargs)
|
||||
|
||||
class MockContextManager:
|
||||
async def __aenter__(self):
|
||||
return (AsyncMock(), AsyncMock(), AsyncMock())
|
||||
|
||||
async def __aexit__(self, *exc):
|
||||
pass
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
return MockContextManager()
|
||||
|
||||
fake_streamablehttp_client.last_kwargs = last_kwargs
|
||||
|
||||
mcp_streamable_http.streamablehttp_client = fake_streamablehttp_client
|
||||
|
||||
monkeypatch.setitem(sys.modules, "mcp", mcp)
|
||||
monkeypatch.setitem(sys.modules, "mcp.client", mcp_client)
|
||||
monkeypatch.setitem(sys.modules, "mcp.client.streamable_http", mcp_streamable_http)
|
||||
|
||||
return mcp_streamable_http
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -265,7 +279,7 @@ class TestMCPToolWrapperHeaders:
|
||||
assert wrapper.mcp_server_params["headers"] == headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_headers_passed_to_transport(self):
|
||||
async def test_headers_passed_to_transport(self, stub_mcp_modules):
|
||||
"""Test that headers are passed to streamablehttp_client."""
|
||||
from mcp import ClientSession
|
||||
|
||||
@@ -294,11 +308,15 @@ class TestMCPToolWrapperHeaders:
|
||||
try:
|
||||
result = await wrapper._execute_tool(test_arg="test_value")
|
||||
assert result == "Test result"
|
||||
|
||||
# Verify headers were passed to streamablehttp_client
|
||||
assert "headers" in stub_mcp_modules.streamablehttp_client.last_kwargs
|
||||
assert stub_mcp_modules.streamablehttp_client.last_kwargs["headers"] == headers
|
||||
finally:
|
||||
ClientSession.__init__ = original_init
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_headers_when_not_configured(self):
|
||||
async def test_no_headers_when_not_configured(self, stub_mcp_modules):
|
||||
"""Test that headers are not passed when not configured."""
|
||||
from mcp import ClientSession
|
||||
|
||||
@@ -322,6 +340,9 @@ class TestMCPToolWrapperHeaders:
|
||||
try:
|
||||
result = await wrapper._execute_tool(test_arg="test_value")
|
||||
assert result == "Test result"
|
||||
|
||||
# Verify headers were NOT passed to streamablehttp_client
|
||||
assert "headers" not in stub_mcp_modules.streamablehttp_client.last_kwargs
|
||||
finally:
|
||||
ClientSession.__init__ = original_init
|
||||
|
||||
|
||||
Reference in New Issue
Block a user