对于目前基于神经网络的序列模型,很重要的一个任务就是从序列模型中采样。比如解码时我们希望能产生多个不一样的结果,而传统的解码算法只能产生相似的结果。又比如训练时使用基于强化学习或者最小风险训练的方法需要从模型中随机采集多个不一样的样本来计算句子级的损失,而一般的确定性方法不能提供所需要的随机性。本文回顾了一系列常用的序列模型采样方法,包括基于蒙特卡洛的随机采样和随机束搜索,以及最近提出的基于Gumbel-Top-K的随机束搜索。表1展示了这三种方法各自的优缺点。
表1 不同采样方法对比
序列模型中的束搜索
在此之前,我们首先回顾一下束搜索。在序列模型中,束搜索通常被用来提升模型解码时的性能。默认的贪婪解码总是在每一步挑选一个当前分数最高的词来组成序列。相比起贪婪解码,束搜索每一步都挑选多个词来组成多个候选序列,最后挑选分数最高的序列作为最终输出。束搜索虽然增加了计算量,但是也显著提升了模型性能。图1是一个束大小为2的束搜索的例子:
图1 束搜索第一步
在解码第一步的时候,束搜索从句子开始符开始,根据模型的打分(是在给定前缀的情况下模型输出的下一词分布)来挑选词表中得分最高的前两个词he和I,并用he和I的得分和分别作为候选序列 he和 I的得分。
图2 计算束搜索第二步打分
在解码第二步的时候,根据模型的打分为已经生成部分内容的句子 he和 I各自挑选得分最高的前两个词,如 he会挑选hit和struck, I会挑选was和got,然后组成一共四个候选序列 he hit, he struck, I was和 I got,并分别计算他们的得分,比如 he hit的得分等于 he这个序列的得分加上hit的得分,如图2所示。最后保留这四个候选序列中得分最高的前两个序列,即 he hit和 I was,如图3所示。
图3 挑选束搜索第二步候选
以此类推,束搜索一直迭代到固定次数或者所有的候选序列都结束才停止。在这个例子中束搜索在第六步停止,产生了两个候选序列 he hit me with a pie和 he hit me with a tart,并挑选得分最高的 he hit me with a pie作为最终的结果,如图4所示。
图4 束搜索最终结果
序列模型中的随机采样
从序列模型中采集多个样本有两种经典的方法:基于蒙特卡洛的随机采样和基于蒙特卡洛的束搜索。
基于蒙特卡洛的随机采样
在序列模型中采样的最简单方法就是在贪婪搜索的基础上,在每一步挑选下一个词的时候不是根据它们相应的得分而是根据模型输出的下一个词分布来随机选取一个,这样重复到固定长度或者挑选到句子结束符时停止。这样我们获得了一个样本。如果需要采集多个样本,那么重复这个过程若干次便可得到多个样本。
基于蒙特卡洛的随机采样虽然简单,但是它面临着严重的效率问题。如果模型输出的下一个词分布熵很低,即对于个别词输出概率特别高,那么采集到的样本将有很大一部分重复,比如接近收敛时候的模型。因此为了采集到固定数目的不同样本,基于蒙特卡洛的随机采样可能需要远远大于所需样本数的采样次数,使得采样过程十分低效。
基于蒙特卡洛的随机束搜索
基于蒙特卡洛的随机束搜索在采集多个不同样本远比基于蒙特卡洛的随机采样高效。假设现在束大小为K,基于蒙特卡洛的随机束搜索在束搜索的基础上,把根据下一词的得分挑选前K个得分最高的词的操作替换成根据下一个词分布随机挑选K个不同词。因为每一步都挑选了不同的词,因此最终产生的K个候选序列都不会相同,从而达到了高效采集K个样本的目的。
但是基于蒙特卡洛的随机束搜索也面临着方差的问题。在每一步中它都是根据随机挑选K个不同词,它无法控制随机采样时的噪声,也就是样本分布的方差跟每一步的的方差相关,而的方差是无法控制的,它可能非常大也可能非常小。因此在基于蒙特卡洛的随机束搜索采集到的样本上估计的统计量会非常不稳定,比如在使用句子级损失的任务中采用样本估计损失的时候会计算出不稳定的值,使模型训练受到影响。
基于Gumbel-Top-K的随机束搜索
解决基于蒙特卡洛的随机束搜索的问题关键在于怎么控制每一步随机采样时的噪声。最近的论文提出使用了Gumbel-Top-K技巧来达到这个目的。
Gumbel-Top-K技巧
对于一个个类别的类别分布I
其中是第个类别的logit,如果我们对的每个类别的logit加入服从Gumbel分布的噪声G
如果从这个受到微小扰动的类别分布中取前K个概率最高,也就是logit最大的类别
那么我们可以保证这K个类别都服从于同时各不相同,同时噪声由Gumbel分布控制,即
自底向上的采样方法
如果我们把每个可能的句子当成一个单独的类别来构造一个类别数非常庞大(假设所有句子长度相等,那么有个类别,其中是词表大小,是句子长度)的类别分布,那么便可以使用Gumbel-Top-K技巧来从这一个庞大的类别分布中采集K个不同样本,同时每个样本都服从于原始的分布。这也是论文提出的自底向上的采样方法。
图5 自底向上的采样方法
图5展示了一个词表大小(hello,world,!),句子长度和样本数K=2的例子。我们需要先从第一个词开始枚举所有的9个可能的句子,同时使用模型计算这9个句子的概率。因为模型通常只能计算整个句子的概率,而Gumbel噪声需要加到整个logit上,我们可以使用整个句子的对数概率
作为整个句子的logit,然后把Gumbel噪声加到logit上
是句子受到Gumbel噪声扰动的对数概率(对数扰动概率),最后我们取其中最高的两个句子 worldhello和 world world,我们就完成了采样。
但是自顶向上的方法需要先枚举所有句子和计算其对数概率才能开始使用噪声扰动每个句子的对数概率,那么我们能不能从句子开始一边枚举一边计算和扰动生成的不同句子的对数概率?在此之前,我们必须先定义在枚举过程中中间生成的只有部分内容的句子的对数扰动概率。只有部分内容的句子(部分生成的句子)的对数扰动概率,比如例子中的 world,定义为以该部分生成的句子为前缀的所有完整句子中对数扰动概率最大的一个
其中是部分生成的句子的对数扰动概率,是以为前缀的一个完整句子的对数扰动概率,比如 world的对数扰动概率为 worldhello (-2.5), world ! (-3.2)和 worldworld (-1.2)各自的对数扰动概率中最大的一个(-1.2)。更进一步地,我们可以根据其孩子节点的对数扰动概率来递归地计算部分生成的句子的对数扰动概率:
自顶向下的采样方法
有了关于部分生成的句子的对数扰动概率还有它与其孩子节点之间的关系,那么我们可以想象,对于第一个词,因为它没有父亲节点同时对数概率为0,我们可以直接使用作为部分生成的句子的对数扰动概率。而对于中间的节点,因为受到和其孩子节点之间的关系的约束,因此从生成孩子节点的时候,所有孩子节点的对数扰动概率中的最大值必须等于,即我们往孩子节点的对数概率添加的噪声必须满足一定条件。
直接寻找这样的噪声是困难的,但是我们可以先直接在孩子节点的对数概率上添加噪声,得到,然后根据孩子节点与父亲节点之间的关系,做一个类似正规化的纠正操作来满足条件。最终孩子节点的对数扰动概率为
这样,我们可以一边枚举所有句子的同时计算句子的对数扰动概率。
更进一步地,我们可以看到,因为我们定义部分生成的句子的对数扰动概率为其对应的所有完整句子的最大的对数扰动概率,因此如果我们在枚举的时候只保留分数最高的K个候选,那么我们可以保证最终的K个候选一定是所有句子中分数最高的前K个,因为部分生成的句子的对数扰动概率的定义已经说明一个内部节点的所有叶子节点的对数扰动概率不可能比它的对数扰动概率大,因此在当前一层中不是分数最高的前K个的话以后它任何一个后代节点也不可能是分数最高的前K个。这样一个自顶向下的方法可以非常高效的采集K个不同样本而不需要枚举所有句子。
图6 自顶向下的采样方法
图6展示了一个K=2的自顶向下的采样例子。我们先对的对数概率进行扰动,得到-1.2,然后我们对所有候选序列 hello, !和 world的对数概率进行扰动并进行纠正,得到-4.3,-3.2,-1.2,最后我们只保留对数扰动概率最高的 !和 world继续进行拓展,最终得到 worldhello和 world world两个样本。
展望
最新提出的基于Gumbel-Top-K的随机束搜索提供了一种高效的采样手段。利用这种方法,我们可以:
对于需要采样来计算句子级损失的任务,可以更高效地训练模型;
类似于使用Gumbel-Softmax的梯度作为Gumbel-Max梯度的有偏估计,为Gumbel-Top-K寻找类似的梯度有偏估计,使得模型可以直接优化其搜索过程;
概率化束搜索,为束搜索可能导致的一系列问题如过翻译,漏译等提供概率解释。
参考文献
Kool, W., Hoof, H.V., & Welling, M.(2019). Stochastic Beams and Where To Find Them: The Gumbel-Top-k Trick for SamplingSequences Without Replacement. ICML.
Shen, S., Cheng, Y., He, Z., He, W., Wu,H., Sun, M., & Liu, Y. (2015). Minimum Risk Training for Neural MachineTranslation. ArXiv, abs/1512.02433.
作者介绍
李炎洋,东北大学自然语言处理实验室研究助理,研究方向:神经机器翻译。
单位介绍
东北大学自然语言处理实验室:东北大学自然语言处理实验室由姚天顺教授创建于 1980 年,现由朱靖波教授、肖桐博士领导,长期从事计算语言学的相关研究工作,主要包括机器翻译、语言分析、文本挖掘等。团队研发的支持119种语言互译的小牛翻译系统已经得到广泛应用。研究方向:神经机器翻译
领取专属 10元无门槛券
私享最新 技术干货