首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

Pointer-network理论及tensorflow实战

数据下载地址:链接:https://pan.baidu.com/s/1nwJiu4T 密码:6joq

本文代码地址:https://github.com/princewen/tensorflow_practice/tree/master/myPtrNetwork

1、什么是pointer-network

Pointer Networks 是发表在机器学习顶级会议NIPS 2015上的一篇文章,其作者分别来自Google Brain和UC Berkeley。

Pointer Networks 也是一种seq2seq模型。他在attention mechanism的基础上做了改进,克服了seq2seq模型中“输出严重依赖输入”的问题。

什么是“输出严重依赖输入”呢?

论文里举了个例子,给定一些二维空间中[0,1]*[1,0]范围内的点,求这些点的凸包(convex hull)。凸包是凸优化里的重要概念,含义如下图所示,通俗来讲,即找到几个点能把所有点“包”起来。比如,模型的输入是序列,输出序列是凸包。到这里,“输出严重依赖输入”的意思也就明了了,即输出是从输入序列中提取出来的。换个输入,如,那么输出序列就是从里面选出来。用论文中的语言来描述,即和的凸包,输出分别依赖于输入的长度,两个问题求解的target class不一样,一个是7,另一个是1000。

Pointer Network在求凸包上的效果如何呢?

从Accuracy一栏可以看到,Ptr-net明显优于LSTM和LSTM+Attention。

为啥叫pointer network呢?

前面说到,对于凸包的求解,就是从输入序列中选点的过程。选点的方法就叫pointer,他不像attetion mechanism将输入信息通过encoder整合成context vector,而是将attention转化为一个pointer,来选择原来输入序列中的元素。

与attention的区别:如果你也了解attention的原理,可以看看pointer是如何修改attention的?如果不了解,这一部分就可以跳过了。

首先搬出attention mechanism的公式,前两个公式是整合encoder和decoder的隐式状态,学出来encoder、decoder隐式状态与当前输出的权重关系a,然后根据权重关系a和隐式状态e得到context vector用来预测下一个输出。

Pointer Net没有最后一个公式,即将权重关系a和隐式状态整合为context vector,而是直接进行通过softmax,指向输入序列选择中最有可能是输出的元素。

如果你对上面的理论还没有理解的很到位,那么我们通过代码来进一步讲解,相信你通过这段代码,可以对Ptr的理论有一个更深入的认识。

2、pointer-network实现

这段代码源自:https://github.com/devsisters/pointer-network-tensorflow

上面的代码 实现比较复杂,连下载数据的过程都有,真的是十分费劲,我直接把数据下载好了,上传到百度云上了,大家可以自行下载(地址见文章开头)。

代码目录如下:

config.py 定义了模型的配置

data_util.py 定义了数据处理过程

main.py 模型的主入口,定义了模型的训练过程

model.py 定义了我们的pointer-network模型

我们这里主要讲解我们的数据处理和模型定义两个文件

2.1 数据处理

好了,我们来看看我们的数据吧:

每行是一条数据,由于一条太长,所以分了三行显示。输入和target由output隔开,每个输入的点由两个坐标构成。

我们用下面的代码读入数据,这里,我们把最后一个target的最后一个去掉了,我们认为我们正常的target的输出序列不包含最后一个1,最后一个1作为结束标记在后面的代码里会加入。

由于每条记录的长度可能不同,因此,我们需要把所有数据的长度补成一样的:

2.2 模型建立

在model.py文件中,我们定义了Model类以及两个辅助的函数:

trainable_initial_state :建立可训练的lstm初始状态

index_matrix_to_pairs:这个主要是帮助我们使用gather_nd函数来选择输入的内容,该函数的一个简单处理效果如下:

我们这里重点讲解Model类的_build_model函数,该函数用来建立一个pointer-network模型。

定义输入

我们定义了四部分的输入,分别是encoder的输入及长度,decoder的预测序列及长度

输入处理

我们要对输入进行处理,将输入转换为embedding,embedding的长度和lstm的隐藏神经元个数相同。

在对输入进行处理之后,输入的形状就变为[batch , max_enc_seq_length, hidden_dim]

Encoder

根据配置中的lstm层数,我们建立encoder,同时将我们处理好的输入输入到模型中,得到encoder的输出以及encoder的最终状态。:

在得到输出之后,我们要给最前面的输出添加一个开始的输出,同时这个添加的开始的输出还将作为encoder的最开始的输入。看下面的图片:

training decoder

与seq2seq不同的是,pointer-network的输入并不是target序列的embedding,而是根据target序列的值选择相应位置的encoder的输出。

我们知道encoder的输出长度在添加了开始输出之后形状为[batch ,max_enc_seq_length + 1]。现在假设我们拿第一条记录进行训练,第一条记录的预测序列是[1,2,4],那么decoder依次的输入是

self.enc_outputs[0][0], self.enc_outputs[0][1],self.enc_outputs[0][2],self.enc_outputs[0][4],那么如何根据target序列来选择encoder的输出呢,这里就要用到我们刚刚定义的index_matrix_to_pairs函数以及gather_nd函数:

由于decoder的输出变成了原先的target序列的长度+1,因此我们要在每个target后面补充一个结束标记,我们补充1作为结束标记:

同样,我们建立一个多层的lstm网络:

对于decoder来说,这里我们每次每个batch只输入一个值,然后使用循环来实现整个decoder的过程:

可以看到,我们定义了两个数组来保存输出的序列,以及每次输出的softmax值。这里定义了一个choose_index函数,这个函数的作用即我们的pointer机制,即得到每个decoder输出与encoder输出按如下公式相互作用的softmax数组:

在论文中还提到一个词叫做glimpse function,他首先将上面式子中的q进行了处理,公式如下:

glimpse function可以实现多层,当然我们代码里只有一层:

可以看到,我们的attention函数高度还原了上面的式子,哈哈!

decoder predicting

对预测来说,我们不能实现选择好用哪个encoder的输出,必须根据上一轮的输出来决定,所以与training的代码不同的是,我们在每层循环里都是用index_matrix_to_pairs函数以及gather_nd函数来选择下一时刻的输出。

** 定义loss**

这里定义的loss与seq2seq的loss相似:

定义训练、验证、预测函数

我们还定义了训练、验证、预测函数:

2.3 训练及验证

在main.py函数中,我们获取数据,并进行训练。这里代码还有待完善,因为没有进行预测,嘻嘻!

实验效果如下:

3、 参考文献

1、神经网络之Pointer Net (Ptr-net) :https://zhuanlan.zhihu.com/p/30860157

2、https://github.com/devsisters/pointer-network-tensorflow

微信公众号

  • 发表于:
  • 原文链接http://kuaibao.qq.com/s/20180302G0203X00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券