作者:一元,四品炼丹师
TabNet: Attentive Interpretable Tabular Learning(ArXiv2020)
01
背景
本文提出了一种高性能、可解释的规范深度表格数据学习结构TabNet。号称吊锤XGBoost和LightGBM等GBDT类模型。来吧,开学!
TabNet使用sequential的attention来选择在每个决策步骤中要推理的特征,使得学习被用于最显著的特征,从而实现可解释性和更有效的学习。我们证明了TabNet在广泛的非性能饱和表格数据集上优于其他变体,并产生了可解释的特征属性和对其全局行为的洞察。
最后,我们展示了表格数据的自监督学习,在未标记数据丰富的情况下显著提高了效果。
1. 决策树类模型在诸多的表格型问题中仍然具有非常大的优势:
2. DNN的优势:
3. TabNet:
02
TabNet
类似于DTs的DNN building blocks
TabNET的框架
我们使用所有的原始数值特征并且将类别特征转化为可以训练的embedding,我们并不考虑全局特征normalization。
在每一轮我们将D维度的特征传入,其中是batch size, TabNet的编码是基于序列化的多步处理, 有个决策过程。在第步我们输入第步的处理信息来决定使用哪些特征,并且输出处理过的特征表示来集成到整体的决策。
特征选择
我们使用可学习的mask, 用于显著特征的soft选择,通过最多的显著特征的稀疏选择,决策步的学习能力在不相关的上面不被浪费,从而使模型更具参数效率。masking是可乘的,,此处我们使用attentive transformer来获得使用在前面步骤中处理过的特征的masks,.
Sparsemax规范化通过将欧几里得投影映射到概率simplex上鼓励稀疏性,观察到概率simplex在性能上更优越,并与稀疏特征选择的目标一致,以便于解释。注意: , 是一个可以训练的函数。
是先验的scale项,表示一个特殊的特征之前被使用的多少,,其中是缩放参数。
为了控制选择特征的稀疏性,此处加入sparsity的正则来控制数值稳定性,
其中对于数值稳定性是一个很小的书,我们再最终的loss上加入稀疏的正则,对应的参数为.
特征处理
我们使用一个特征transformer来处理过滤的特征,然后拆分决策步骤输出和后续步骤信息,,其中, ,对于具有高容量的参数有效且鲁棒的学习,特征变换器应该包括在所有决策步骤之间共享的层(因为在不同的决策步骤之间输入相同的特征)以及决策步骤相关的层。上图展示了作为两个共享层和两个决策步骤相关层的级联的实现。
每个FC层后面是BN和gated线性单元(GLU)非线性,最终通过归一化连接到归一化残差连接。此处我们通过的正则来保证网络的方差以稳定学习。
为了快速的训练,此处我们使用带有BN的大的batch size,因此,除了应用到输入特征的,我们使用ghost BN形式,使用一个virtual batchsize 和momentum ,对于输入特征,我们观测到low-variance平均的好处,因此可以避免ghost BN,最终我们通过decision-tree形式的aggregation,我们构建整体的决策embedding, ,再使用线性mapping, 得到最终的输出。
解释性
此处我们可以使用特征选择的mask来捕捉在每一步的选择的特征,如果:
如果是一个线性函数,的稀疏应该对应的二者重要性,尽管每次决策步使用一个非线性处理,他们的输出是以线性的方式组合,我们的目的是量化一个总体特征的重要性,除了分析每一步。组合不同步骤的Mask需要一个系数来衡量决策中每个步骤的相对重要性,我们提出:
直觉上,如果,那么在第个决策步的所有特征就应当对整体的决策没有任何帮助。当它的值增长的时候,它在整体线性的组合上会更为重要,在每次决策步的时候对决策mask进行缩放,,我们对特征重要性mask进行特征的集成, .
表格自监督学习
我们提出了一个解码器架构来从TabNet编码的表示中重建表格特征。解码器由特征变换器组成,每个判决步骤后面是FC层。将输出相加得到重构特征。我们提出了一个从其他特征列中预测缺失特征列的任务。考虑一个二进制掩码,
我们在编码器中初始化, 这么做模型只重点关注已知的特征,解码器的最后一层FC层和进行相乘输出未知的特征,我们考虑在自监督阶段的重构损失,
使用真实值的标准偏差进行Normalization是有帮助的,因为特征可能有不同的ranges,我们在每次迭代时以概率从伯努利分布中独立采样;
03
实验
04
小结
本文我们提出了TabNet,一种新的用于表格学习的深度学习体系结构。TabNet使用一种顺序attention机制来选择语义上有意义的特征子集,以便在每个决策步骤中进行处理。基于实例的特征选择能够有效地进行学习,因为模型容量被充分地用于最显著的特征,并且通过选择模板的可视化产生更具解释性的决策。我们证明了TabNet在不同领域的表格数据集上的性能优于以前的工作。最后,我们展示了无监督预训练对于快速适应和提高模型的效果。
05
参考文献