sampled softmax原论文:On Using Very Large Target Vocabulary for Neural Machine Translation 以及tensorflow关于candidate sampling的文档:candidate sampling
在神经机器翻译中,训练的复杂度以及解码的复杂度和词汇表的大小成正比。当输出的词汇表巨大时,传统的softmax由于要计算每一个类的logits就会有问题。在论文Neural Machine Translation by Jointly Learning to Align and Translate 中,带有attention的decoder中权重的公式如下:
因为我们输出的是一个概率值,所以(6)式的归一化银子ZZ的计算就需要将词汇表当中的logits都计算一遍,这个代价是很大的。 基于此,作者提出了一种采样的方法,使得我们在训练的时候,输出为原来输出的一个子集。(关于其它的解决方法,作者也有提,感兴趣的可以看原文,本篇博客只关注Sampled Softmax)
(感觉还是tensorflow文档说的清楚一点,最初看论文的时候还以为是相当于把一个单词划分到最近的一个类,那样的话,应该会有不同类别的关系啊不然也不make sense啊,但是看tensorflow源码就只有采样的过程啊,笑cry)
def sampled_softmax_loss(weights,
biases,
labels,
inputs,
num_sampled, # 每一个batch随机选择的类别
num_classes, # 所有可能的类别
num_true=1, #每一个sample的类别数量
sampled_values=None,
remove_accidental_hits=True,
partition_strategy="mod",
name="sampled_softmax_loss"):
tensorflow对于使用的建议:仅仅在训练阶段使用,在inference或者evaluation的时候还是需要使用full softmax。
原文: This operation is for training only. It is generally an underestimate of the full softmax loss. A common use case is to use this method for training, and calculate the full softmax loss for evaluation or inference.
这个函数的主体主要调用了另外一个函数:
logits, labels = _compute_sampled_logits(
weights=weights,
biases=biases,
labels=labels,
inputs=inputs,
num_sampled=num_sampled,
num_classes=num_classes,
num_true=num_true,
sampled_values=sampled_values,
subtract_log_q=True,
remove_accidental_hits=remove_accidental_hits,
partition_strategy=partition_strategy,
name=name)
上述函数的返回值shape为:[batch_size, num_true + num_sampled]即可能的class为: Si∪tiS_i \cup{t_i} 而这个函数采样集合的代码如下:
sampled_values=candidate_sampling_ops.log_uniform_candidate_sampler(
true_classes=labels,# 真实的label
num_true=num_true,
num_sampled=num_sampled, # 需要采样的子集大小
unique=True,
range_max=num_classes)
而这个函数主要是按照log-uniform distribution(Zipfian distribution)来采样出一个子集,Zipfian distribution 即Zipf法则,以下为Wikipedia关于Zipf’s law的解释:
Zipf’s law states that given some corpus of natural language utterances, the frequency of any word is inversely proportional to its rank in the frequency table.