本文分享 NeurIPS 2021 论文『Efficient Training of Visual Transformers with Small-Size Datasets』,由特伦托大学&腾讯联合提出新的损失函数,复现简单,可显著提高Transformer在小规模数据集上的性能,最高涨45%的精度!
详细信息如下:
导言:
视觉Transformer(VT) 正在成为卷积网络 (CNN) 结构范式的一种替代方案。与CNN不同,VT可以捕获图像元素之间的全局关系,并且它具有更强的表示能力。然而,由于缺乏卷积的归纳偏置,这些VT模型比普通CNN更需要数据,因为VT需要从大量的数据中学习这类信息。
在本文中,作者对不同的VT进行了实验分析,比较了它们在小训练集中的鲁棒性,结果表明,尽管在ImageNet上训练时具有相当的精度,但它们在较小数据集上的性能会有很大的不同。因此,作者提出了一种自监督的任务,该任务可以从图像中提取其他信息,而计算开销却可以忽略不计。
此任务鼓励VT学习图像中的空间关系,并在训练数据不足时使VT训练更加鲁棒。本文的自监督任务可以与监督任务联合使用,并且它不依赖于特定的网络结构,因此它可以很容易地插入现有的VT中。基于不同的VT结构和数据集进行广泛的评估,作者证明了本文的方法可以提高 VT的准确率。
01 Motivation
视觉Transformers(VTs)是计算机视觉中最近兴起的一种结构,可以替代标准的卷积神经网络(CNN),并且已经被应用于许多任务,如图像分类,目标检测,分割,跟踪和图像生成。视觉Transformer中的开创性工作是ViT,它使用非重叠的patch来分割图像,每个patch进行线性投影,从而获得“token”。之后,所有的token都由一系列的多头注意和前馈层处理,类似于NLP Transformers中token的处理方式。
VTs的明显优势是网络可以使用注意力层来模拟token之间的全局关系,这是相对于CNN的主要区别。然而,基于视觉信息的局部性、平移不变性和层次性,VTs表征能力的提高是以缺乏CNN中的归纳偏置为代价的。因此,VT需要大量的数据来进行训练。例如,ViT使用JFT-300M进行训练,JFT-300M是由3.03亿张高分辨率图像组成的巨型数据集。当只在ImageNet-1K(大约130万样本)上训练时,ViT的性能比具有类似参数量的Resnet差。这可能是由于ViT需要使用比CNN更多的样本来学习视觉数据的一些局部属性,而CNN则将这些性质嵌入到了它的结构设计中。
为了缓解这个问题,第二代的VT结构被提出,这些网络通常是将卷积层与注意层混合在一起,从而为VT提供了局部归纳偏置。这些混合结构同时具备两种范例的优势: 注意层对全局依赖关系进行建模,而卷积操作可以强调图像内容的局部特征。大多数工作中的实验结果表明,这种第二代的VTs可以在ImageNet上进行训练,其性能优于此数据集上类似大小的ResNet。然而,在中小型数据集上进行训练时,这些网络的结果仍不清楚。
在本文中,作者通过在中小型数据集上,从头开始训练它们或对它们进行微调来相互比较不同第二代VT结构的区别。实验结果表明,尽管它们的ImageNet性能基本上彼此相当,但它们在较小数据集上的分类精度却有很大的不同。此外,作者提出使用其他自监督的任务和相应的损失函数,以加快在小型训练集或更少epoch数约束下的训练。具体而言,该任务是学习输出token嵌入之间的空间关系。
由于这个任务是自我监督的,因此该任务的密集相对定位损失函数 () 不需要额外的标注,并且可以将其与标准交叉熵联合使用,作为VT训练的正则化。非常简单且容易复现,它可以在很大程度上提高了VTs的准确性,尤其是当VT在小数据集上从头开始训练,或者在相对于预训练ImageNet数据集具有较大域偏移的数据集上进行微调时。在实验中,基于不同的训练场景、不同的训练数据量和三种不同的VTs,对baseline的结果都有提升,有时会提高数十个点 (最多45个点) 的准确率。
02 方法
在本文中,作者重点关注第二代VT,它们是混合了自注意力层和卷积运算的混合结构。这些网络将图像作为输入,该图像首先被分割成K×K个patch,然后用线性投影将patch投到样嵌入空间中,得到一组K × K个的输入token。在VT中,自注意力层和卷积结构能够对这些token进行全局和局部信息的建模。其中使用步长大于1的卷积或池化操作,可以降低初始K × K的token特征的分辨率,从而模拟CNN的层次结构。
对于用于分类的特征,有的方法采用了额外的class token,另外也有一些方法采用了将所有grid的特征进行平均池化来获得整张图片的表示,从而来进行分类。最后,在这些用于分类的特征上进行MLP来获得目标类集合的后验分布,并使用交叉熵损失函数来进行训练VT。
本文提出的正则化损失的目标是鼓励VT在不使用额外的人工标注的情况下学习空间信息。作者通过对每个图像的多个嵌入对进行密集采样并要求网络预测它们的相对位置来实现空间信息的学习。具体实现上,给定图像,将VT最后输出的grid特征表示为,其中,是嵌入空间的维数。对于每个,作者随机采样多对嵌入,并且对于每个采样对,计算2D归一化平移偏移量 ,计算方式如下:
将选定的嵌入向量和进行concat,然后输入到一个MLP中,该MLP具有两个隐藏层和两个输出神经元(如上图所示),用来预测网格上位置 , 和位置,之间的相对距离,即
。给定一个Batch 的n个图像,本文提出的密集相对定位损失(dense relative localization loss)为:
被添加到每个原始VT 的标准交叉熵损失()中。最后总的损失为: 。在T2T和CvT的所有实验中使用 λ = 0.1,在Swin中使用 λ = 0.5。
除了最简单的dense relative localization loss,作者还提出该损失函数的一些变体。
变体1:
上面的损失基于横向和纵向相对距离的绝对值,变体1中还考虑正负的方向,如下所示:
用代替原始公式中的,其他部分保持不变。变体1 的损失函数记为。
变体2:
在变体2中,作者将回归任务转换成了分类任务,并将L1损失修改为交叉熵损失。计算上,目标的偏移计算如下:
然后,我们将,,中的个离散元素与相应的类相关联。此外还需要将MLP中两个输出神经元替换为两组输出神经元,每组神经元输出代表个类。Softmax分别应用于每组个神经元,的输出由上的两个后验分布组成:。该变体的损失函数为:
其中表示的第个元素
变体3:
上述公式中的交叉熵损失,将看做是一个无序的“类别”集合。这意味着 (和) 中的预测误差与相对于ground truth (和) 是 “距离” 无关的。为了缓解这个问题,作者提出的第三种变体,在和上施加了高斯先验,并最小化的期望值与ground truth (分别为和) 之间的归一化平方距离。在实现上,令均值,方差,损失函数可以表示为:
其中,和用于方差正则化。
变体4:
最后一个变体是基于 “非常密集” 的定位损失,即针对VT的每个Transformer块计算。具体地说,设是由VT的第l个块输出的的token嵌入,L为Transformer块的总数。然后,新的损失函数为:
其中,和分别是在第个块中随机采样对计算的目标偏移量和预测偏移量。不同的Transformer块,采用不同的MLP来预测。
03 实验
作者在不同的数据集上进行了实验,上表为本文进行实验数据集的具体信息。
作者在ImageNet-100上对不同损失函数变体进行了实验,可以看出,除了之外,其他损失函数都能提高性能。
上表为CIFAR-10上的结果,m为采样数量,可以看出,m为64是效果最好。
上表显示了不同VT模型,不同epoch数下的实验结果,可以看出本文的方法在不同模型和不同epoch数下,都能提升模型性能。
上表展示了不同模型在不同数据集上的结果,可以看出,加上本文方法之后,性能都有提升,最高提升了45个点。
上表展示了ImageNet上预训练的模型,在不同数据集上进行微调的结果,可以看出,本文的方法依旧能够提升模型性能。
04 总结
在本文中,作者对不同的VT进行了实验分析,结果表明,当用中小型数据集从头开始训练时,这些模型的性能差别很大。为了缓解这一问题,作者提出了一项自监督任务,用来进行VT的训练。该定位任务是对最后层的token嵌入对进行密集地随机采样,并且它鼓励VT学习空间信息。
在实验中,作者使用了11个数据集、不同的训练设置和3个VT模型,本文的密集定位损失都能够提高相应的baseline精度。这表明本文提出的任务和损失函数,可以提高VT的性能,特别是在数据/训练时间有限的情况中。此外,它还为研究其他形式的自监督/多任务学习铺平了道路,可以帮助VT更好的训练,而不需要使用大量标注数据集。