
A thorough, code-driven walkthrough of VAE-GAN

Author: CreateAMind · Published 2018-07-25

Previously we introduced the VAE-GAN paper, Autoencoding beyond pixels using a learned similarity metric, along with its video.

This article walks through a VAE-GAN implementation in code. Its highlights:

1 Multi-GPU training

2 Dynamically adjusted learning rates

3 Latent-space visualization

4 Latent-vector algebra

5 Visualization of neuron activations

6 Fast training

Results:

Code formatting in WeChat is hard to read; see the original article at: https://github.com/timsainb/Tensorflow-MultiGPU-VAE-GAN

Tensorflow Multi-GPU VAE-GAN implementation

How does a VAE-GAN work?

  • We have three networks, an Encoder, a Generator, and a Discriminator.
    • The Encoder learns to map input x onto z space (latent space)
    • The Generator learns to generate x from z space
    • The Discriminator learns to discriminate whether the image being put in is real, or generated

Diagram of basic network input and output

l_x_tilde and l_x here become layers of high-level features that the discriminator learns.

  • We train the network to minimize the difference between the high-level features of x and x_tilde
  • This is basically an autoencoder that works on high-level features rather than pixels
  • Adding this autoencoder to a GAN helps to stabilize the GAN

Training


Train Encoder on minimization of:

  • kullback_leibler_loss(z_x, gaussian)
  • mean_squared_error(l_x_tilde_, l_x)

Train Generator on minimization of:

  • kullback_leibler_loss(z_x, gaussian)
  • mean_squared_error(l_x_tilde_, l_x)
  • -1*log(d_x_p)

Train Discriminator on minimization of:

  • -1*(log(d_x) + log(1 - d_x_p))
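Putting these together, a hedged sketch of the single objective each optimizer minimizes (the equal weighting and the clipping bounds are my assumptions; LL_loss denotes mean_squared_error(l_x_tilde, l_x)):

# names follow the bullets above: KL_loss, LL_loss (= MSE(l_x_tilde, l_x)),
# G_loss (= -log(d_x_p)), D_loss (= -(log(d_x) + log(1 - d_x_p)))
L_e = tf.clip_by_value(KL_loss + LL_loss, -100., 100.)  # encoder objective
L_g = tf.clip_by_value(LL_loss + G_loss, -100., 100.)   # generator objective
L_d = tf.clip_by_value(D_loss, -100., 100.)             # discriminator objective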

Which GPUs are we using?

  • Set gpus to a list of the GPUs you're using. The network will then split up the work between those gpus
import os

gpus = [2]  # Here I set CUDA to only see one GPU
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(i) for i in gpus])
num_gpus = len(gpus)  # number of GPUs to use

Reading the dataset from HDF5 format

  • open `makedataset.ipynb` for instructions on how to build the dataset (a sketch of the loading step follows below)
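The loading code itself is omitted here. As a minimal sketch of reading an HDF5 dataset with h5py (the file and dataset names are assumptions; the real ones are set up in makedataset.ipynb):

import h5py
import numpy as np

# hypothetical file/dataset names; the real ones come from makedataset.ipynb
with h5py.File('faces.hdf5', 'r') as f:
    images = np.asarray(f['images'])  # e.g. (N, 12288) flattened 64x64x3 faces
    labels = np.asarray(f['labels'])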

A data iterator for batching (drawn up by Luke Metz)
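The iterator's definition is omitted here. A minimal sketch in the same spirit, batching by shuffled indices over the `images` and `labels` arrays from the loading sketch above (`batch_size` is assumed to be defined):

def data_iterator():
    """Yield shuffled (image, label) minibatches forever."""
    while True:
        idxs = np.arange(0, len(images))
        np.random.shuffle(idxs)
        for start in range(0, len(images) - batch_size + 1, batch_size):
            cur = idxs[start:start + batch_size]
            yield images[cur].astype('float32'), labels[cur]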

iter_ = data_iterator()

Draw out the architecture of our network

  • Each of these functions represent the Encoder, Generator, and Discriminator described above.
  • It would be interesting to try implementing the Inception architecture for the same task next time around.
The configuration of each network (the layer definitions are omitted here):
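As a rough sketch of what the encoder might look like in prettytensor (the layer sizes here are my assumptions, not the repo's exact architecture):

def encoder(x):
    """Map images to the mean and log-variance of q(z|x). Layer sizes are assumptions."""
    lay_end = (pt.wrap(x).
               reshape([batch_size, 64, 64, 3]).
               conv2d(5, 32, stride=2).
               conv2d(5, 64, stride=2).
               conv2d(5, 128, stride=2).
               flatten())
    z_mean = lay_end.fully_connected(hidden_size, activation_fn=None)
    z_log_sigma_sq = lay_end.fully_connected(hidden_size, activation_fn=None)
    return z_mean, z_log_sigma_sq

The generator would mirror this with deconvolutions from z back up to 64x64x3 images, and the discriminator is another conv stack that outputs both a real/fake probability and an intermediate feature layer (the l_x used above).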

Defining the forward pass through the network

def inference(x):
    """
    Run the models. Called inference because it does the same thing as
    tensorflow's cifar tutorial.
    """
    z_p = tf.random_normal((batch_size, hidden_size), 0, 1)  # normal dist for GAN
    eps = tf.random_normal((batch_size, hidden_size), 0, 1)  # normal dist for VAE

    with pt.defaults_scope(activation_fn=tf.nn.elu,
                           batch_normalize=True,
                           learned_moments_update_rate=0.0003,
                           variance_epsilon=0.001,
                           scale_after_normalization=True):
        with tf.variable_scope("enc"):
            z_x_mean, z_x_log_sigma_sq = encoder(x)  # get z from the input
        with tf.variable_scope("gen"):
            # reparameterization trick: z = mean + sigma * eps
            z_x = tf.add(z_x_mean,
                         tf.mul(tf.sqrt(tf.exp(z_x_log_sigma_sq)), eps))
            x_tilde = generator(z_x)  # reconstruction of x
        with tf.variable_scope("dis"):
            _, l_x_tilde = discriminator(x_tilde)  # features of the reconstruction
        with tf.variable_scope("gen", reuse=True):
            x_p = generator(z_p)  # sample generated purely from the prior
        with tf.variable_scope("dis", reuse=True):
            d_x, l_x = discriminator(x)  # positive examples
        with tf.variable_scope("dis", reuse=True):
            d_x_p, _ = discriminator(x_p)  # discriminator output on generated samples
    return z_x_mean, z_x_log_sigma_sq, z_x, x_tilde, l_x_tilde, x_p, d_x, l_x, d_x_p, z_p

Note: the variables computed above correspond to the diagram below.

Loss - define our various loss functions

  • SSE - we don't actually use this loss (it is really the MSE); it's just to see how close x is to x_tilde
  • KL Loss - our VAE Gaussian-distribution loss.
  • D_loss - our discriminator loss: how good the discriminator is at telling whether an image is real
  • G_loss - essentially the opposite of D_loss: how good the generator is at tricking the discriminator
  • Notice we clip our values to make sure learning rates don't explode
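The loss definitions themselves are omitted above. A hedged sketch of what they could look like, matching the tensor names from inference() (the clipping bounds and the 64*64 normalization are assumptions, not verified against the repo):

def loss(x, x_tilde, z_x_log_sigma_sq, z_x_mean, d_x, d_x_p, l_x, l_x_tilde):
    # SSE/MSE between pixels - only used for monitoring, not for training
    SSE_loss = tf.reduce_mean(tf.square(x - x_tilde))

    # KL divergence between q(z|x) and a unit gaussian, clipped for stability
    KL_loss = tf.reduce_sum(-0.5 * tf.reduce_sum(
        1 + tf.clip_by_value(z_x_log_sigma_sq, -10., 10.)
        - tf.square(tf.clip_by_value(z_x_mean, -10., 10.))
        - tf.exp(tf.clip_by_value(z_x_log_sigma_sq, -10., 10.)), 1)) / (64 * 64)

    # discriminator: push d_x toward 1 (real) and d_x_p toward 0 (generated)
    D_loss = tf.reduce_mean(-1. * (tf.log(tf.clip_by_value(d_x, 1e-5, 1.)) +
                                   tf.log(tf.clip_by_value(1. - d_x_p, 1e-5, 1.))))

    # generator: fool the discriminator on generated samples
    G_loss = tf.reduce_mean(-1. * tf.log(tf.clip_by_value(d_x_p, 1e-5, 1.)))

    # "learned similarity": reconstruct the discriminator's features, not pixels
    LL_loss = tf.reduce_sum(tf.square(l_x - l_x_tilde)) / (64 * 64)

    return SSE_loss, KL_loss, D_loss, G_loss, LL_loss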

Average the gradients between towers
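The averaging function is omitted above; it follows the standard pattern from TensorFlow's CIFAR-10 multi-GPU tutorial (written with the era-appropriate tf.concat(dim, values) argument order):

def average_gradients(tower_grads):
    """Average the (gradient, variable) lists produced by each tower."""
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # each grad_and_vars is ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN, var0_gpuN))
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(0, grads), 0)
        # variables are shared across towers, so the first tower's pointer suffices
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads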

Plot network output

  • This is just my ugly function to regularly plot the output of my network - TensorBoard would probably be a better option for this

(code omitted)

With your graph, define what a step is (needed for multi-GPU), and what your optimizers are for each of your networks

Dynamic learning rates: the actual rate values are computed on the fly later during training.
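The step/optimizer definitions are omitted above. A minimal sketch, assuming one Adam optimizer per sub-network with a feedable learning-rate placeholder (all names here are my assumptions):

with graph.as_default():
    # placeholders so each network's learning rate can be fed fresh every batch
    lr_E = tf.placeholder(tf.float32, shape=[])
    lr_G = tf.placeholder(tf.float32, shape=[])
    lr_D = tf.placeholder(tf.float32, shape=[])
    opt_E = tf.train.AdamOptimizer(lr_E, epsilon=1.0)
    opt_G = tf.train.AdamOptimizer(lr_G, epsilon=1.0)
    opt_D = tf.train.AdamOptimizer(lr_D, epsilon=1.0)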

Run all of the functions we defined above

  • tower_grads_e defines the list of gradients for the encoder for each tower
  • For each GPU we grab the parameters corresponding to each network, calculate the gradients, and add them to the towers to be averaged (see the sketch below)
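A hedged sketch of that tower loop, reusing inference() and the loss() sketch above (the all_input placeholder and the per-network loss weighting are assumptions):

# gradients gathered from every tower, one list per sub-network
tower_grads_e, tower_grads_g, tower_grads_d = [], [], []

# all_input is assumed to hold num_gpus*batch_size flattened images
for i in range(num_gpus):
    with tf.device('/gpu:%d' % i):
        with tf.name_scope('tower_%d' % i):
            # this tower's slice of the batch
            x_slice = all_input[i * batch_size:(i + 1) * batch_size]
            (z_x_mean, z_x_log_sigma_sq, z_x, x_tilde, l_x_tilde,
             x_p, d_x, l_x, d_x_p, z_p) = inference(x_slice)
            SSE_loss, KL_loss, D_loss, G_loss, LL_loss = loss(
                x_slice, x_tilde, z_x_log_sigma_sq, z_x_mean,
                d_x, d_x_p, l_x, l_x_tilde)

            # the variable_scope prefixes from inference() identify each sub-network
            t_vars = tf.trainable_variables()
            E_params = [v for v in t_vars if v.name.startswith('enc')]
            G_params = [v for v in t_vars if v.name.startswith('gen')]
            D_params = [v for v in t_vars if v.name.startswith('dis')]

            # share weights between towers from the second tower onward
            tf.get_variable_scope().reuse_variables()

            tower_grads_e.append(opt_E.compute_gradients(KL_loss + LL_loss, var_list=E_params))
            tower_grads_g.append(opt_G.compute_gradients(LL_loss + G_loss, var_list=G_params))
            tower_grads_d.append(opt_D.compute_gradients(D_loss, var_list=D_params))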

Now let's actually run our session

with graph.as_default():

    # Start the Session
    init = tf.initialize_all_variables()
    saver = tf.train.Saver() # initialize network saver
    sess = tf.InteractiveSession(graph=graph,config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))
    sess.run(init)

Get some example data to do visualizations with

example_data, _ = iter_.next()
np.shape(example_data)
(32, 12288)
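Each of the 32 rows is one flattened 64x64x3 face image, since 64 * 64 * 3 = 12288.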

Initialize our epoch number, and restore a saved network by uncommenting the #tf.train... line

epoch = 0
# uncomment to restore a previously saved model:
# tf.train.Saver.restore(saver, sess, 'models/faces_multiGPU_64_0000.tfmod')

Now we actually run the network

  • Importantly, notice how we define the learning rates
    • We calculate a sigmoid of how each network has been performing and squash its learning rate accordingly. So if the discriminator has been winning, its learning rate will be low on the next batch, and likewise for the generator; the learning rates are adjusted dynamically.
    • e_current_lr = e_learning_rate*sigmoid(np.mean(d_real),-.5,10)
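A sketch of that squashing logic (the sigmoid helper's exact form is an assumption; the constants mirror the line above):

import math

def sigmoid(x, shift, mult):
    """Squash x into (0, 1); discourages one network from overpowering the other."""
    return 1. / (1. + math.exp(-(x + shift) * mult))

# d_real / d_fake are the discriminator's outputs on the previous batch
e_current_lr = e_learning_rate * sigmoid(np.mean(d_real), -.5, 10)
g_current_lr = g_learning_rate * sigmoid(np.mean(d_real), -.5, 10)
d_current_lr = d_learning_rate * sigmoid(np.mean(d_fake), -.5, 10)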

Training code (omitted):
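A hedged sketch of one iteration of that loop (the train_E/train_G/train_D ops and placeholder names follow the sketches above and are assumptions):

# neutral initial values so the first batch uses the base learning rates
d_real = d_fake = .5

while epoch < num_epochs:  # num_epochs is an assumed stopping point
    next_batches, _ = iter_.next()

    # squash each learning rate by recent discriminator performance
    e_current_lr = e_learning_rate * sigmoid(np.mean(d_real), -.5, 10)
    g_current_lr = g_learning_rate * sigmoid(np.mean(d_real), -.5, 10)
    d_current_lr = d_learning_rate * sigmoid(np.mean(d_fake), -.5, 10)

    # one combined step: train all three networks and fetch diagnostics
    _, _, _, D_err, G_err, KL_err, LL_err, d_real, d_fake = sess.run(
        [train_E, train_G, train_D, D_loss, G_loss, KL_loss, LL_loss, d_x, d_x_p],
        {lr_E: e_current_lr, lr_G: g_current_lr, lr_D: d_current_lr,
         all_input: next_batches})
    epoch += 1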

Generated images

Viewing the images that correspond to points in the latent space

'Spike Triggered Average' style receptive fields.

(code omitted)
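A sketch of the spike-triggered-average idea behind the omitted code: average the input images, weighted by how strongly one latent unit responds to each (using example_data from earlier; the rest is an assumption):

# activations of every latent unit on a batch of real faces
acts = sess.run(z_x_mean, {all_input: example_data})  # shape: (batch, hidden_size)

unit = 0                                  # which latent neuron to visualize
w = acts[:, unit] - acts[:, unit].mean()  # mean-centered activation "spikes"

# activation-weighted average of the inputs approximates the unit's receptive field
rf = (example_data * w[:, None]).mean(axis=0).reshape(64, 64, 3)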

Now let's try some latent space algebra

(code omitted)
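A sketch of what the omitted code does: encode faces with and without an attribute, subtract the mean codes to get an attribute vector, add it to new faces, and decode (blond_faces, other_faces, and test_faces are hypothetical arrays shaped like example_data):

# mean latent code of faces with and without the attribute (e.g. blond hair)
z_with = sess.run(z_x_mean, {all_input: blond_faces}).mean(axis=0)
z_without = sess.run(z_x_mean, {all_input: other_faces}).mean(axis=0)
blond_vector = z_with - z_without

# add the attribute vector to another face's code, then decode it;
# feeding the intermediate z_x tensor overrides the encoder's output
z_new = sess.run(z_x_mean, {all_input: test_faces}) + blond_vector
generated = sess.run(x_tilde, {z_x: z_new})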

Adding blond hair


This implementation is based on a few other things:

Originally published 2016-11-29, shared from the CreateAMind WeChat public account.
