作者:Manjunath Bhat
编译:ronghuaiyang
对空间变换网络STN做了一个简单的原理性的介绍。
作为谷歌Summer of Code项目的一部分,我要实现的第一个模型是空间变压器网络。空间变压器网络(STN)是一个可学习的模块,可以放置在卷积神经网络(CNN)中,有效地增加空间不变性。空间不变性是指模型对图像的空间变换如旋转、平移和缩放不变性。不变性是指即使输入被变换或轻微修改,模型也能识别和识别特征的能力。空间变压器可以放置到CNN中,以完成各种任务。图像分类就是一个例子。假设任务是对手写数字进行分类,每个样本中数字的位置、大小和方向变化显著。一个空间转换器将提取、变换和缩放样本中感兴趣的区域。现在CNN可以完成分类的任务。
空间变压器网络由3个主要组成部分组成:
(i) 定位网络:该网络以一个batch的图像的四维张量表示(宽度x高度x通道x Batch_Size)作为输入。它是一个简单的神经网络,有几个卷积层和几个dense层。将变换参数预测为输出。这些参数决定了输入必须旋转的角度、要完成的平移量以及聚焦于输入特征图中感兴趣的区域所需的比例因子。
(ii) 采样网格生成器:对batch中每幅图像使用定位网络预测的变换参数,其形式为大小为2×3的仿射变换矩阵。仿射变换是一种保留点、直线和平面的变换。经过仿射变换后,平行线保持平行。旋转、缩放和平移都是仿射变换。
这里,T是这个仿射变换,A是表示仿射变换的矩阵。θ11, θ12, θ21, θ22被用来确定图像旋转的角度。θ13, θ23分别确定了图像沿宽度和高度的平移量。因此,我们得到了一个转换索引的采样网格。
(iii) 变换后索引上的双线性插值:现在图像的索引和坐标轴已经进行了仿射变换。它的像素移动了。例如,一个点(1,1)在轴逆时针旋转45度后变成(√2,0),因此要找到变换点处的像素值,我们需要使用四个最接近的像素值进行双线性插值。
为了找到点(x, y)上的像素值,我们取4个最近的点,如上图所示。其中,floor(x)表示最大整数函数,ceil(x)表示ceiling函数。线性插值必须在x和y两个方向上完成。因此,这个函数返回完全转换后的图像,并在转换索引处使用适当的像素值。
纯Julia实现空间变压器网络的代码可以在这里找到:https://github.com/thebhatman/Spatial-Transformer-Network/blob/master/src/stn.jl。我在一些图像上测试了我的空间转换器模块的功能。下面是转换函数输出的一些示例图像。左边的图像是转换器模块的输入,右边的图像是输出。
从上面的例子可以清楚地看出,空间转换器模块能够执行任何类型的仿射变换。在实现过程中,我花了很多时间来理解数组的reshape、permutedims和concatenation是如何工作的,因为当我使用这些函数时,很难调试像素和索引是如何移动的。在STN实现过程中,调试插值和图像索引是最耗费时间和最令人沮丧的部分。
现在,我计划使用一个CNN来训练这个空间转换器模块,以便对一个杂乱和扭曲的MNIST数据集进行手写数字分类。空间变压器将能够增加CNN的空间不变性,因此期望即使在数字被平移、旋转或缩放时也能给出良好的分类结果。
英文原文:https://medium.com/@manjunathbhat9920/spatial-transformer-network-82666f184299