【新智元导读】Ian Goodfellow 提出令人惊叹的 GAN 用于无人监督的学习,是真正AI的“心头好”。而 PyTorch 虽然出世不久,但已俘获不少开发者。本文介绍如何在PyTorch中分5步、编写50行代码搞定GAN。下面一起来感受一下PyTorch的易用和强大吧。
2014年,Ian Goodfellow和他在蒙特利尔大学的同事们发表了一篇令人惊叹的论文,将GAN或称生成式对抗网络带到世界的面前。 通过计算图形和游戏理论的创新组合,他们指出,给定足够的建模能力,两个相互对抗的模型能够通过普通的旧的B-P网络进行共同训练。
模型扮演了两个不同的(确切地说,是对抗的)的角色。 给定一些真实数据集R,G是发生器(试图创建看起来像真正数据的假数据),而D是鉴别器,从真实数据集或G中获得数据并标记差异。 Goodfellow的比喻(一个很好的比喻)是,G像一伙努力用他们的输出匹配真实图景的骗子,而D是一帮努力鉴别差异的侦探。 (唯一的不同是,骗子G永远不会看到原始数据 –而只能看到D的判断。他们是一伙瞎了眼的骗子)。
理想状态下,D和G将随着时间的推移而变得更好,直到G真正变成了原始数据的“伪造大师”,而D则彻底迷失,“无法分辨真假”。
实际上,Goodfellow已经指出,G将能够对原始数据集执行一种无监督学习,找到某种(可能)维度低得多的方式来表示该数据的办法。正如Yann LeCun众所周知的表态,无人监督的学习是真正AI的“心头好”。
这种强大的技术似乎需要一吨的代码才可以开始,对吧?不。 使用PyTorch,我们实际上可以在50行代码下创建一个非常简单的GAN。 真的只有5个组件需要考虑:
R:原始的、真正的数据;
I:进入发生器作为熵源的随机噪声;
G:努力模仿原始数据的发生器;
D:努力将G从R中分辨出来的鉴别器;
训练循环,我们在其中教G来愚弄D,教D小心G。
1.)R:在我们的例子中,我们将从最简单的R- 一个钟形曲线开始。 此函数采用平均值和标准偏差,并返回一个函数,该函数从具有那些参数的正态分布中提供样本数据的正确形状。在我们的示例代码中,我们将使用平均值4.0和标准差1.25。
2.)I:进入生成器的输入也是随机的,但是为了使我们的工作更难一点,让我们使用一个均匀分布,而不是一个正常的分布。这意味着我们的模型G不能简单地移动/缩放输入以复制R,而是必须以非线性方式重塑数据。
3.)G:发生器是一个标准的前馈图 - 两个隐藏层,三个线性地图。我们使用ELU(exponential linear unit ),因为它们是the new black, yo。 G将从I获得均匀分布的数据样本,并以某种方式模仿来自R的正态分布样本。
4.)D:鉴别器代码与G的生成器代码非常相似;具有两个隐藏层和三个线性映射的前馈图。 它将从R或G获取样本,并将输出介于0和1之间的单个标量,解释为“假”与“真”。这就像一个神经网络可以得到的胆小鬼 。
5.) 最后,训练循环在两种模式之间交替:首先用准确的标签(把它当成是警察学院)训练在真实数据与假数据上训练D,; 然后用不准确的标签训练G来愚弄D。 这是善与恶之间的斗争。
即使你以前没有见过PyTorch,你也可以知道发生了什么。在第一(绿色)部分中,我们将两种类型的数据都推送到D,并对D的猜测和实际标签应用可区分的标准。这种推送是“向前”的步骤; 我们然后显式地调用'backward()',以便计算梯度,这会用于更新d_optimizer step()调用中的D参数。 我们在这里使用G,但不训练。
然后在最后一个(红色)部分,我们为G做同样的事情- 注意,我们还通过D运行G的输出(我们基本上是给了骗子一个侦探来让他练手),但在这一步我们不优化或改变D。 我们不想让侦探D学习错误的标签。 因此,我们只调用g_optimizer.step()。
这就是全部了。还有一些其他样板代码,但GAN特定的东西只是那5个组件,没有别的了。
在D和G之间几千次的禁忌之舞中,我们得到什么? 鉴别器D很快得到好处(而G缓慢进步着),但一旦达到一定的力量,G就有了一个配得上的对手,并开始改善。 真的改善。
20,000多个训练轮次之后,G输出平均值超过4.0,但随后回到一个相当稳定、正确的范围(下图左)。 同样,标准偏差最初错误的下降,但随后上升到我们希望的1.25的范围(下图右),匹配了R.
好,现在基本的统计和R匹配了。 那些highermoments怎么办? 分布的形状看上去正确吗? 毕竟,你当然可以有一个均值分布,平均值为4.0,标准差为1.25,但那并不会真正地和R匹配。让我们看看G最终发出的分布。
真不赖。 左尾比右边有点长,但我们应该说,偏斜和峭度是原始高斯的回归。
G几乎完全重现了原来的分布R,D则暗自神伤,因为他已无法分辨事实和虚幻。 这正是我们想要的结果(见Goodfellow中的图1)。 只用了不到50行的代码。
Goodfellow继续就GAN的问题发表了许多文章,其中包括一篇2016年的瑰宝,描述了一些实用的改进, 其中包括了此处适用的minibatchdiscrimination方法。 这里有一个2小时的教程,是他在2016年的NIPS提出的。对于TensorFlow的用户来说,这里有一个parallel post,来自GANs的Aylien。
好,说得够多了。去看看代码吧。
原文地址:
https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.cg0ofu1s5