Jaderberg M, Simonyan K, Zisserman A, et al. Spatial transformer networks[J]. 2015:2017-2025.
虽然CNN的效果很好,但是仍然缺乏对数据的空间不变能力,从而限制了计算和参数的效率。因此,论文提出Spatial Transformer Network (STN)。
STN
在网络中对数据显式地进行空间操作(平移、旋转、缩放、裁剪、扭曲)。由于该操作可微,因此模型能够end to end训练。
根据输入数据,动态生成空间操作参数Θ。
网络参数直接通过loss回传进行学习。可直接添加到神经网络模型中,整个训练不需额外的监督信息加入。
空间操作后的数据是与后续特定任务高度相关的。另一方面,变换后的低分辨率数据比原始数据的计算效率更高。
通过对数据进行操作实现不变性,而不是对特征提取器(卷积核)。
适用的任务
classification
co-localization
spatial attention
1.Spatial Transformers
STN包含3部分(Figure 2)
localization network.
grid generator.
sampler.
Localization Network
输入U(h, w, c)
输出空间变换参数Θ
网络可以是任何形式,如FCN、CNN等。仿射变换Θ的参数为6,投影变换参数为8,以及thin plate spline (TPS). 模型对最后一层的weight矩阵初始化为0,bias初始化为[[1, 0, 0], [0, 1, 0]](仿射变换),即全等变换。
Parameterised Sampling Grid
首先根据采样网格大小(超参数)生成标准网格(t; x,y∈(-1, 1); (h, w, 2)).
利用空间变换参数Θ对其进行变换操作,生成采样网格(s; x,y∈(-1, 1); (h, w, 2)).
Differentiable Image Sampling
通用的采样公式可写为
k为通用采样kernel; x, m, y, n为坐标点。Φ为kernel的参数。
对于整数采样kernel,公式简化为
取x+0.5下界整数,δ函数为Kronecker delta函数
对于双线性采样kernel,公式简化为
该公式可导
Spatial Transformer Networks
由于Θ显式地编码了变换,因此也可将Θ传入后续的网络,而非变换后的特征图(或图片)。
可用STN对特征图进行上采样或下采样。但是,用固定的、小空间支持的采样kernel(双线性kernel)进行下采样会造成影响。
STN可级联或并行在网络中。
2.Experiments
Distorted MNIST
数据集distorted方式分为
R 旋转,±90°之间。
RTS 旋转+缩放+平移
P 投影
E 弹性形变(破坏性,不可逆)
所有模型都具有相同数量参数,分别使用3类变换操作:仿射变换(Aff)、投影变换(Proj)、薄板样条变换(TPS)。实验发现TPS最有效。
MNIST Addition
输入两张数字图片(h,w,2),输出数字的和。
Street View House Numbers
每张图片有1~5个数字。因此,模型采用级联STN,并使用5个独立的softmax分类器,每个分类器包含一个空字符。
Fine-Grained Classification
CUB-200-2011数据集,模型采用并行STN结构。
Co-localization
使用半监督学习来定位图像中的物体。基于正确定位对象A与正确定位对象B之间的距离,比A与随机定位crop小的假设,构造hinge loss
T表示crop,e为编码函数,α为margin,实验设置为1。数据集的构建操作为:将28*28的数字图片放在84*84背景中,并将从训练集中采样得到的16个随机6*6 crop放入背景中。当预测定位与ground-truth的交集大于0.5时,定义为预测正确。
Higher Dimensionnal Transformer
模型使用3D仿射变换和3D双线性插值操作。
另一种处理方法是:将3D空间投影到2D空间,例如
领取专属 10元无门槛券
私享最新 技术干货