前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >如何实现自然语言处理的集束搜索解码器

如何实现自然语言处理的集束搜索解码器

作者头像
Hi胡瀚
发布2018-02-12 14:35:47
2.1K0
发布2018-02-12 14:35:47

自然语言处理任务(例如字幕生成和机器翻译)涉及生成单词序列。

针对这些问题开发的模型通常通过在输出词的词汇表中生成概率分布来运行,并且需要解码算法来对概率分布进行采样以生成最可能的词序列。

在本教程中,您将发现可用于文本生成问题的贪婪搜索和波束搜索解码算法。

完成本教程后,您将知道:

  • 文本生成问题的解码问题。
  • 贪婪的搜索解码器算法,以及如何在Python中实现它。
  • 集束搜索解码器算法,以及如何在Python中实现它。

让我们开始吧。

生成文本的解码器

在字幕生成,文本摘要和机器翻译等自然语言处理任务中,所需的预测是一系列单词。

为这些类型的问题开发的模型通常为输出一个每个单词在可能的词汇表中的概率分布。然后由解码器处理将概率转换为最终的单词序列。

在处理生成文本作为输出的自然语言处理任务的循环神经网络时,您可能会遇到这种情况。神经网络模型中的最后一层对于输出词汇表中的每个单词都有一个神经元,并且使用softmax激活函数来输出词汇表中每个单词作为序列中下一个单词的可能性。

解码最可能的输出序列包括根据它们的可能性搜索所有可能的输出序列。词汇的大小通常是数十万甚至数十万个单词,甚至数百万个单词。因此,搜索问题在输出序列的长度上是指数的,并且是难以处理的(NP-complete)来完全搜索。

实际上,启发式搜索方法被用于为给定预测返回一个或多个近似或“足够好”的解码输出序列。

由于搜索图的大小在源句子长度上是指数的,所以我们必须使用近似来有效地找到解。

- 引用出自《自然语言处理和机器翻译手册》第272页。

候选字词序列根据其可能性进行评分。通常使用贪婪搜索或集束搜索来定位文本的候选序列。我们将在这篇文章中看看这两种解码算法。

每个单独的预测都有一个相关的分数(或概率),我们对最大分数(或最大概率)的输出序列感兴趣。[...]一个流行的近似技术是使用贪婪预测,在每个阶段得到最高得分项。虽然这种方法通常是有效的,但显然是非最优的。事实上,使用集束搜索作为一个近似的搜索往往比贪婪的方法更好。

- 引用出自《自然语言处理中的神经网络方法》页227。

贪婪的搜索解码器

一个简单的近似方法是使用贪婪搜索,在输出序列的每一步选择最可能的单词。

这种方法的好处是速度非常快,但最终输出序列的质量可能不是最优的。

我们可以用Python中的一个小例子演示贪婪的搜索方法。

我们可以从一个包含10个单词序列的预测问题开始。每个单词被预测为5个可能单词的概率分布。

代码语言:js
复制
# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)

我们将假定这些单词已经被整数编码了,这样列索引可以用来查找词汇表中的关联单词。因此,解码任务成为从概率分布中选择一个整数序列的任务。

所述argmax()的数学函数可以被用于选择具有最大值的阵列的索引。我们可以使用这个函数来选择序列中每个步骤最有可能的单词索引。这个功能是直接在numpy中提供的。

下面的greedy_decoder()函数使用argmax函数实现这个解码器策略。

代码语言:js
复制
# greedy decode
def greedy_decoder(data):
    # index for largest probability each row
    return [argmax(s) for s in data]

综上所述,下面列出了演示贪婪解码器的完整示例。

代码语言:js
复制
from numpy import array
from numpy import argmax
 
# greedy decode
def greedy_decoder(data):
    # index for largest probability each row
    return [argmax(s) for s in data]
 
# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# decode sequence
result = greedy_decoder(data)
print(result)

运行示例输出一个整数序列,然后可以将其映射回词汇表中的单词。

代码语言:js
复制
[4, 0, 4, 0, 4, 0, 4, 0, 4, 0]

光束搜索解码器

另一种受欢迎的启发式方法是在贪婪搜索时扩展的集束搜索,并返回最可能的输出序列列表。

集束方法作为代替的贪婪选择最可能的下一步骤的序列并扩展所有的可能,并保持ķ的值,其中ķ是一个用户指定的参数,并通过序列控制光束或并行搜索的次数概率。

本地波束搜索算法跟踪k个状态,而不仅仅是一个。它从k个随机生成的状态开始。在每一步中,所有k个状态的所有后继都被生成。如果任何一个是目标,算法就会停止。否则,从完整列表中选择k个最佳继任者并重复。

- 第125-126页,人工智能:现代方法(第3版),2009年。

我们不需要从随机状态开始; 相反,我们从k个最可能的单词开始,作为序列中的第一步。

对于贪婪搜索,常见波束宽度值为1,对于机器翻译中的常见基准测试问题,值为5或10。由于多个候选序列增加了更好地匹配目标序列的可能性,较大的波束宽度导致模型的更好的性能。性能的提高会导致解码速度的降低。

在NMT中,通过简单的波束搜索解码器翻译新的句子,该解码器发现近似最大化训练的NMT模型的条件概率的翻译。波束搜索策略在每个时间步骤保持固定数目(波束)的活动候选者,从左到右逐字地生成翻译单词。通过增加光束尺寸,翻译性能可以增加,但代价是显着降低解码器的速度。

- 2017年神经机器翻译的束搜索策略

搜索过程可以通过达到最大长度,通过达到序列结束标记或者达到阈值可能性来分别停止每个候选者。

我们来举个具体的例子。

我们可以定义一个函数来执行给定的概率序列和波束宽度参数k的波束搜索。在每个步骤中,每个候选序列都被扩展为所有可能的后续步骤。每个候选步骤通过将概率相乘在一起进行评分。选择具有最可能概率的k个序列,并且修剪所有其他候选者。该过程然后重复,直到序列结束。

概率是小数,小数乘以小数。为了避免浮点数的下溢,概率的自然对数相乘在一起,使数字更大,更易于计算。此外,通过最小化分数来执行搜索也是常见的做法,因此,我们乘以概率的负对数。这个最后的调整意味着我们可以按照他们的分数从小到大的顺序对所有的候选序列进行排序,并选择第一个k作为最可能的候选序列。

下面的beam_search_decoder()函数实现了波束搜索解码器。

代码语言:js
复制
# beam search
def beam_search_decoder(data, k):
    sequences = [[list(), 1.0]]
    # walk over each step in sequence
    for row in data:
        all_candidates = list()
        # expand each current candidate
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score * -log(row[j])]
                all_candidates.append(candidate)
        # order all candidates by score
        ordered = sorted(all_candidates, key=lambda tup:tup[1])
        # select k best
        sequences = ordered[:k]
    return sequences

我们可以将它与上一节的样本数据结合在一起,这次返回最可能的3个序列。

代码语言:js
复制
from math import log
from numpy import array
from numpy import argmax
 
# beam search
def beam_search_decoder(data, k):
    sequences = [[list(), 1.0]]
    # walk over each step in sequence
    for row in data:
        all_candidates = list()
        # expand each current candidate
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score * -log(row[j])]
                all_candidates.append(candidate)
        # order all candidates by score
        ordered = sorted(all_candidates, key=lambda tup:tup[1])
        # select k best
        sequences = ordered[:k]
    return sequences
 
# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# decode sequence
result = beam_search_decoder(data, 3)
# print result
for seq in result:
    print(seq)

序列及其日志的可能性。

试用不同的k值。

代码语言:js
复制
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108]
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397]
[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]

进一步阅读

如果您想深入了解,此处将提供更多有关该主题的资源。

概要

在本教程中,您发现了可用于文本生成问题的贪婪搜索和波束搜索解码算法。

具体来说,你了解到:

  • 文本生成问题的解码问题。
  • 贪婪的搜索解码器算法,以及如何在Python中实现它。
  • 集束搜索解码器算法,以及如何在Python中实现它。
评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 生成文本的解码器
  • 贪婪的搜索解码器
  • 光束搜索解码器
  • 概要
相关产品与服务
机器翻译
机器翻译(Tencent Machine Translation,TMT)结合了神经机器翻译和统计机器翻译的优点,从大规模双语语料库自动学习翻译知识,实现从源语言文本到目标语言文本的自动翻译,目前可支持十余种语言的互译。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档