Pointer Network是seq2seq模型的一种变型。seq2seq模型是一种编码-解码框架的端到端生成模型,已经在机器翻译、对话生成、语法改错等领域有了成功的进展。本文不再赘述。此处主要介绍Pointer Network的基本原理和作用。
Pointer Network的主要作用
Pointer Network主要用于解决组合优化问题,传统的优化问题寻优一般使用启发式的搜索算法,基于Pointer Network主要是对源数据进行组合,达到目标函数最优。常见的应用包括凸包问题、旅行商问题等。
Pointer Network的模型框架
对于凸包问题,可以简述为:可定图中若干点,选取其中几个连接成凸多边形使得该多边形能包含图中所有的点。
(1)如果该问题使用普通的seq2seq建模,即encoder输入序列为P1,..,P4点的坐标,decoder输出为点的label(即1-4),此时decoder的输出范围与encoder的输入不相同,只需要输出encoder序列的位置。目标函数如下所示:
(2)若使用seq2seq+ Attention模型,即解码过程需要对encoder端进行Attention计算,具体计算如下:
(3)区别于seq2seq +Attention模型,Pointer Network直接使用Attention的权重信息作为位置重要性的概率分布输出
简化了seq2seq+Attention的计算,无需将encoder端的编码求和后输入到LSTM cell再求输出并取softmax得到概率分布,简化了计算。
Pointer Network的成功应用
(1)组合优化问题: TSP问题等;
(2)阅读理解问题: 将原文进行编码,从原文中找到问题答案的起始位置和结束位置;
(3)摘要生成问题:从长文本中找到摘要句子的起始结束位置。
参考文献
Vinyals O, Fortunato M, Jaitly N. Pointer networks[C]// International Conference on Neural Information Processing Systems. MIT Press, 2015.