diff --git a/lib/crewai/tests/tools/test_mcp_tool_wrapper.py b/lib/crewai/tests/tools/test_mcp_tool_wrapper.py index d71c11d92..440f3ed37 100644 --- a/lib/crewai/tests/tools/test_mcp_tool_wrapper.py +++ b/lib/crewai/tests/tools/test_mcp_tool_wrapper.py @@ -19,6 +19,11 @@ def stub_mcp_modules(monkeypatch): mcp_client = types.ModuleType("mcp.client") mcp_streamable_http = types.ModuleType("mcp.client.streamable_http") + mcp.__path__ = [] + mcp_client.__path__ = [] + mcp.client = mcp_client + mcp_client.streamable_http = mcp_streamable_http + class MockClientSession: def __init__(self, *args, **kwargs): self.initialize = AsyncMock() @@ -28,27 +33,36 @@ def stub_mcp_modules(monkeypatch): async def __aenter__(self): return self - async def __aexit__(self, *exc): - pass + async def __aexit__(self, exc_type, exc, tb): + return False mcp.ClientSession = MockClientSession - async def fake_streamablehttp_client(*args, **kwargs): - """Mock streamablehttp_client context manager.""" + last_kwargs = {} + + def fake_streamablehttp_client(*args, **kwargs): + """Mock streamablehttp_client context manager (NOT async def).""" + last_kwargs.clear() + last_kwargs.update(kwargs) + class MockContextManager: async def __aenter__(self): return (AsyncMock(), AsyncMock(), AsyncMock()) - async def __aexit__(self, *exc): - pass + async def __aexit__(self, exc_type, exc, tb): + return False return MockContextManager() + fake_streamablehttp_client.last_kwargs = last_kwargs + mcp_streamable_http.streamablehttp_client = fake_streamablehttp_client monkeypatch.setitem(sys.modules, "mcp", mcp) monkeypatch.setitem(sys.modules, "mcp.client", mcp_client) monkeypatch.setitem(sys.modules, "mcp.client.streamable_http", mcp_streamable_http) + + return mcp_streamable_http @pytest.fixture @@ -265,7 +279,7 @@ class TestMCPToolWrapperHeaders: assert wrapper.mcp_server_params["headers"] == headers @pytest.mark.asyncio - async def test_headers_passed_to_transport(self): + async def test_headers_passed_to_transport(self, stub_mcp_modules): """Test that headers are passed to streamablehttp_client.""" from mcp import ClientSession @@ -294,11 +308,15 @@ class TestMCPToolWrapperHeaders: try: result = await wrapper._execute_tool(test_arg="test_value") assert result == "Test result" + + # Verify headers were passed to streamablehttp_client + assert "headers" in stub_mcp_modules.streamablehttp_client.last_kwargs + assert stub_mcp_modules.streamablehttp_client.last_kwargs["headers"] == headers finally: ClientSession.__init__ = original_init @pytest.mark.asyncio - async def test_no_headers_when_not_configured(self): + async def test_no_headers_when_not_configured(self, stub_mcp_modules): """Test that headers are not passed when not configured.""" from mcp import ClientSession @@ -322,6 +340,9 @@ class TestMCPToolWrapperHeaders: try: result = await wrapper._execute_tool(test_arg="test_value") assert result == "Test result" + + # Verify headers were NOT passed to streamablehttp_client + assert "headers" not in stub_mcp_modules.streamablehttp_client.last_kwargs finally: ClientSession.__init__ = original_init