首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >结合代码讲解VAE-GAN比较透彻的一篇文章

结合代码讲解VAE-GAN比较透彻的一篇文章

作者头像
CreateAMind
发布于 2018-07-25 02:53:00
发布于 2018-07-25 02:53:00
10.4K00
代码可运行
举报
文章被收录于专栏:CreateAMindCreateAMind
运行总次数:0
代码可运行

前面介绍了VAE-GAN 论文:Autoencoding beyond pixels usingALearnedSimilarityMmetric及视频

这篇文章通过代码介绍了VAE-GAN,特色如下:

1 多GPU

2 学习rate动态改变!

3 隐变量空间可视化

4 特征向量代数计算

5 神经元激活可视化

6 训练学习快

效果:

微信代码格式不好看,可以阅读原文访问原文章:https://github.com/timsainb/Tensorflow-MultiGPU-VAE-GAN

Tensorflow Multi-GPU VAE-GAN implementation

  • This is an implementation of the VAE-GAN based on the implementation described in Autoencoding beyond pixels using a learned similarity metric ref论文:Autoencoding beyond pixels usingALearnedSimilarityMmetric及视频
  • I implement a few useful things like
    • Visualizing Movement Through Z-Space 可视化
    • Latent Space Algebra 变量空间技术
    • Spike Triggered Average Style Receptive Fields 神经元激活区域

How does a VAE-GAN work?

  • We have three networks, an Encoder, a Generator, and a Discriminator.
    • The Encoder learns to map input x onto z space (latent space)
    • The Generator learns to generate x from z space
    • The Discriminator learns to discriminate whether the image being put in is real, or generated

Diagram of basic network input and output

l_x_tilde and l_x here become layers of high level features that the discriminator learns.

  • we train the network to minimize the difference between the high level features of x and x_tilde
  • This is basically an autoencoder that works on high level features rather than pixels
  • Adding this autoencoder to a GAN helps to stabilize the GAN

Training

ref

Train Encoder on minimization of:

  • kullback_leibler_loss(z_x, gaussian)
  • mean_squared_error(l_x_tilde_, l_x)

Train Generator on minimization of:

  • kullback_leibler_loss(z_x, gaussian)
  • mean_squared_error(l_x_tilde_, l_x)
  • -1*log(d_x_p)

Train Discriminator on minimization of:

  • -1*log(d_x) + log(1 - d_x_p)

Which GPUs are we using?

  • Set gpus to a list of the GPUs you're using. The network will then split up the work between those gpus
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
gpus = [2] # Here I set CUDA to only see one GPU
os.environ["CUDA_VISIBLE_DEVICES"]=','.join([str(i) for i in gpus])
num_gpus = len(gpus) # number of GPUs to use

Reading the dataset from HDF5 format

  • open `makedataset.ipynb' for instructions on how to build the dataset
代码语言:javascript
代码运行次数:0
运行
复制

A data iterator for batching (drawn up by Luke Metz)

  • https://indico.io/blog/tensorflow-data-inputs-part1-placeholders-protobufs-queues/
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
iter_ = data_iterator()
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
iter_ = data_iterator()
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
#face_batch, label_batch

Draw out the architecture of our network

  • Each of these functions represent the Encoder, Generator, and Discriminator described above.
  • It would be interesting to try and implement the inception architecture to do the same thing, next time around:
  • They describe how to implement inception, in prettytensor, here: https://github.com/google/prettytensor
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
各个神经网络配置:

Defining the forward pass through the network 前向计算

  • This function is based upon the inference function from tensorflows cifar tutorials
    • https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/image/cifar10/cifar10.py
  • Notice I use with tf.variable_scope("enc"). This way, we can reuse these variables using reuse=True. We can also specify which variables to train using which error functions based upon the label enc
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
def inference(x):
    """
    Run the models. Called inference because it does the same thing as tensorflow's cifar tutorial
    """
    z_p =  tf.random_normal((batch_size, hidden_size), 0, 1) # normal dist for GAN
    eps = tf.random_normal((batch_size, hidden_size), 0, 1) # normal dist for VAE

    with pt.defaults_scope(activation_fn=tf.nn.elu,
                               batch_normalize=True,
                               learned_moments_update_rate=0.0003,
                               variance_epsilon=0.001,
                               scale_after_normalization=True):

        with tf.variable_scope("enc"):         
                z_x_mean, z_x_log_sigma_sq = encoder(x) # get z from the input      
        with tf.variable_scope("gen"):
            z_x = tf.add(z_x_mean, 
                tf.mul(tf.sqrt(tf.exp(z_x_log_sigma_sq)), eps)) # grab our actual z
            x_tilde = generator(z_x)  
        with tf.variable_scope("dis"):   
            _, l_x_tilde = discriminator(x_tilde)
        with tf.variable_scope("gen", reuse=True):         
            x_p = generator(z_p)    
        with tf.variable_scope("dis", reuse=True):
            d_x, l_x = discriminator(x)  # positive examples              
        with tf.variable_scope("dis", reuse=True):
            d_x_p, _ = discriminator(x_p)  
        return z_x_mean, z_x_log_sigma_sq, z_x, x_tilde, l_x_tilde, x_p, d_x, l_x, d_x_p, z_p

ref:上面的计算变量和下图对应。

Loss - define our various loss functions

  • SSE - we don't actually use this loss (also its the MSE), its just to see how close x is to x_tilde
  • KL Loss - our VAE gaussian distribution loss.
    • See https://arxiv.org/abs/1312.6114
  • D_loss - Our descriminator loss, how good the discriminator is at telling if something is real
  • G_loss - essentially the opposite of the D_loss, how good the generator is a tricking the discriminator
  • notice we clip our values to make sure learning rates don't explode
代码语言:javascript
代码运行次数:0
运行
复制

Average the gradients between towers

  • This function is taken directly from
    • https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
  • Basically we're taking a list of gradients from each tower, and averaging them together

Plot network output

  • This is just my ugly function to regularly plot the output of my network - tensorboard would probably be a better option for this

.........................略

With your graph, define what a step is (needed for multi-gpu), and what your optimizers are for each of your networks

动态学习率;学习率后面动态生成

Run all of the functions we defined above

  • tower_grads_e defines the list of gradients for the encoder for each tower
  • For each GPU we grab parameters corresponding to each network, we then calculate the gradients, and add them to the twoers to be averaged

Now lets actually run our session

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
with graph.as_default():

    # Start the Session
    init = tf.initialize_all_variables()
    saver = tf.train.Saver() # initialize network saver
    sess = tf.InteractiveSession(graph=graph,config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))
    sess.run(init)

Get some example data to do visualizations with

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
example_data, _ = iter_.next()
np.shape(example_data)
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
(32, 12288)

Initialize our epoch number, and restore a saved network by uncommening #tf.train...

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
epoch = 0
tf.train.Saver.restore(saver, sess, 'models/faces_multiGPU_64_0000.tfmod')

Now we actually run the network

  • Importantly, notice how we define the learning rates
    • we calculate the sigmoid of how the network has been performing, and squash the learning rate using a sigmoid based on that. So if the discriminator has been winning, it's learning rate will be low, and if the generator is winning, it's learning rate will be lower on the next batch. 学习率动态处理
    • e_current_lr = e_learning_rate*sigmoid(np.mean(d_real),-.5,10)

训练代码:

生成的图片

查看变量空间对应的显示图片内容

'Spike Triggered Average' style receptive fields.

代码略.......................

Now lets try some latent space algebra

.................代码略

加黄头发

............................

This implementation is based on a few other things:

  • Autoencoding beyond pixels (Github)
  • VAE and GAN implementations in prettytensor/tensorflow (Github)
  • Tensorflow VAE tutorial
  • DCGAN (Github)
  • Torch GAN tutorial (Github)
  • Open AI improving GANS (Github)
  • Other things which I am probably forgetting...
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2016-11-29,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 CreateAMind 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
Generative Adversarial Network
这里我们将建立 一个对抗生成网络 (GAN)训练MNIST,并在最后生成新的手写数字。
小飞侠xp
2018/08/29
4080
tensorflow 实现wgan-gp mnist图片生成
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_25737169/article/details/76695935
DoubleV
2018/09/12
1.6K0
tensorflow 实现wgan-gp  mnist图片生成
一个很牛的GAN工具项目:HyperGAN
A versatile GAN(generative adversarial network) implementation focused on scalability and ease-of-use.
CreateAMind
2018/07/24
9830
一个很牛的GAN工具项目:HyperGAN
在TensorFlow中对比两大生成模型:VAE与GAN(附测试代码)
来源:机器之心 本文长度为3071字,建议阅读6分钟 本文在 MNIST 上对VAE和GAN这两类生成模型的性能进行了对比测试。 项目链接:https://github.com/kvmanohar22/ Generative-Models 变分自编码器(VAE)与生成对抗网络(GAN)是复杂分布上无监督学习最具前景的两类方法。 本项目总结了使用变分自编码器(Variational Autoencode,VAE)和生成对抗网络(GAN)对给定数据分布进行建模,并且对比了这些模型的性能。你可能会问:我们已经
数据派THU
2018/01/30
2.8K0
在TensorFlow中对比两大生成模型:VAE与GAN(附测试代码)
【深度学习】--GAN从入门到初始
GAN,生成对抗网络,在2016年基本火爆深度学习,所有有必要学习一下。生成对抗网络直观的应用可以帮我们生成数据,图片。
LhWorld哥陪你聊算法
2018/09/13
6110
【深度学习】--GAN从入门到初始
在TensorFlow中对比两大生成模型:VAE与GAN
选自GitHub 机器之心编译 参与:路雪、李泽南 变分自编码器(VAE)与生成对抗网络(GAN)是复杂分布上无监督学习最具前景的两类方法。本文中,作者在 MNIST 上对这两类生成模型的性能进行了对比测试。 项目链接:https://github.com/kvmanohar22/Generative-Models 本项目总结了使用变分自编码器(Variational Autoencode,VAE)和生成对抗网络(GAN)对给定数据分布进行建模,并且对比了这些模型的性能。你可能会问:我们已经有了数百万张图像
机器之心
2018/05/10
8640
tf24: GANs—生成明星脸
本文介绍了如何使用TensorFlow实现生成对抗网络(GANs),用于生成明星脸。首先,介绍了TensorFlow的基本概念,然后详细阐述了如何搭建一个GANs模型。接着,展示了如何训练模型以及使用GANs进行图像生成。最后,总结了本文的主要内容和实现步骤。
MachineLP
2018/01/09
1.2K0
tf24: GANs—生成明星脸
Python让你成为AI 绘画大师,简直太惊艳了!(附代码))
引言:基于前段时间我在CSDN上创作的文章“CylcleGAN人脸转卡通图”的不足,今天给大家分享一个更加完美的绘制卡通的项目“Learning to Cartoonize Using White-box Cartoon Representations”。
AI科技大本营
2020/09/22
2.6K0
Python让你成为AI 绘画大师,简直太惊艳了!(附代码))
通过 VAE、GAN 和 Transformer 释放生成式 AI
生成式人工智能是人工智能和创造力交叉的一个令人兴奋的领域,它通过使机器能够生成新的原创内容,正在给各个行业带来革命性的变化。从生成逼真的图像和音乐作品到创建逼真的文本和沉浸式虚拟环境,生成式人工智能正在突破机器所能实现的界限。在这篇博客中,我们将探索使用 VAE、GAN 和 Transformer 的生成式人工智能的前景,深入研究其应用、进步及其对未来的深远影响。
磐创AI
2023/11/08
9090
通过 VAE、GAN 和 Transformer 释放生成式 AI
TensorFlow 1.x 深度学习秘籍:11~14
在本章中,我们将讨论如何将生成对抗网络(GAN)用于深度学习领域,其中关键方法是训练图像生成器来挑战鉴别器,并同时训练鉴别器来改进生成器。 可以将相同的方法应用于不同于图像领域。 另外,我们将讨论变分自编码器。
ApacheCN_飞龙
2023/04/23
1.2K0
TensorFlow 1.x 深度学习秘籍:11~14
手把手教你用GAN实现半监督学习
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_25737169/article/details/78532719
DoubleV
2018/09/12
1.7K0
手把手教你用GAN实现半监督学习
教程 | 用AI生成猫的图片,撸猫人士必备
编译 | 小梁 【AI科技大本营导读】我们身边总是不乏各种各样的撸猫人士,面对朋友圈一波又一波晒猫的浪潮,作为学生狗和工作狗的我们只有羡慕的份,更流传有“吸猫穷三代,撸猫毁一生?”的名言,今天营长就为
AI科技大本营
2018/04/26
2.3K0
教程 | 用AI生成猫的图片,撸猫人士必备
Tensorflow 2.0 的这些新设计,你适应好了吗?
如果说两代 Tensorflow 有什么根本不同,那应该就是 Tensorflow 2.0 更注重使用的低门槛,旨在让每个人都能应用机器学习技术。考虑到它可能会成为机器学习框架的又一个重要里程碑,本文会介绍 1.x 和 2.x 版本之间的所有(已知)差异,重点关注它们之间的思维模式变化和利弊关系。
崔庆才
2019/09/04
1K0
Tensorflow 2.0 的这些新设计,你适应好了吗?
TensorFlow 卷积神经网络实用指南:6~10
本章将介绍一种与到目前为止所看到的模型稍有不同的模型。 到目前为止提供的所有模型都属于一种称为判别模型的模型。 判别模型旨在找到不同类别之间的界限。 他们对找到P(Y|X)-给定某些输入X的输出Y的概率感兴趣。 这是用于分类的自然概率分布,因为您通常要在给定一些输入X的情况下找到标签Y。
ApacheCN_飞龙
2023/04/23
7520
图像生成:GAN
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
chaibubble
2019/09/18
1.1K0
图像生成:GAN
谷歌开源的 GAN 库--TFGAN
本文大约 8000 字,阅读大约需要 12 分钟 第一次翻译,限于英语水平,可能不少地方翻译不准确,请见谅!
kbsc13
2019/08/16
9190
一看就懂的Tensorflow实战(DCGAN)
DCGAN在GAN的基础上优化了网络结构,加入了 conv,batch_norm 等层,使得网络更容易训练,网络结构如下:
AI异构
2020/07/29
8490
一看就懂的Tensorflow实战(DCGAN)
【深度学习】生成对抗网络(GAN)
生成对抗网络(Generative Adversarial Networks)是一种无监督深度学习模型,用来通过计算机生成数据,由Ian J. Goodfellow等人于2014年提出。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。生成对抗网络被认为是当前最具前景、最具活跃度的模型之一,目前主要应用于样本数据生成、图像生成、图像修复、图像转换、文本生成等方向。
杨丝儿
2022/03/20
2.8K0
【深度学习】生成对抗网络(GAN)
利用tensorflow训练简单的生成对抗网络GAN
对抗网络是14年Goodfellow Ian在论文Generative Adversarial Nets中提出来的。 原理方面,对抗网络可以简单归纳为一个生成器(generator)和一个判断器(discriminator)之间博弈的过程。整个网络训练的过程中,
狼啸风云
2020/09/27
1.3K0
TensorFlow-CIFAR10 CNN代码分析
想了解更多信息请参考CIFAR-10 page,以及Alex Krizhevsky写的技术报告
百川AI
2021/10/19
7140
相关推荐
Generative Adversarial Network
更多 >
LV.2
这个人很懒,什么都没有留下~
交个朋友
加入前端学习入门群
前端基础系统教学 经验分享避坑指南
加入腾讯云技术交流站
前端技术前沿探索 云开发实战案例分享
加入云开发企业交流群
企业云开发实战交流 探讨技术架构优化
换一批
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验