前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >【收藏】简单易用 TensorFlow 代码集,GAN通用框架、函数

【收藏】简单易用 TensorFlow 代码集,GAN通用框架、函数

作者头像
新智元
发布2019-03-07 14:08:15
1.1K0
发布2019-03-07 14:08:15
举报
文章被收录于专栏:新智元


新智元报道

来源:GitHub 作者:Junho Kim

【新智元导读】今天为大家推荐一个实用的GitHub项目:TensorFlow-Cookbook。 这是一个易用的TensorFlow代码集,包含了对GAN有用的一些通用架构和函数。

今天为大家推荐一个实用的GitHub项目:TensorFlow-Cookbook

这是一个易用的TensorFlow代码集,作者是来自韩国的AI研究科学家Junho Kim,内容涵盖了谱归一化卷积、部分卷积、pixel shuffle、几种归一化函数、 tf-datasetAPI,等等。

作者表示,这个repo包含了对GAN有用的一些通用架构和函数。

项目正在进行中,作者将持续为其他领域添加有用的代码,目前正在添加的是 tf-Eager mode的代码。欢迎提交pull requests和issues。

Github地址 :

https://github.com/taki0112/Tensorflow-Cookbook

如何使用

Import

  • ops.py
    • operations
    • from ops import *
  • utils.py
    • image processing
    • from utils import *

Network template

代码语言:javascript
复制
def network(x, is_training=True, reuse=False, scope="network"):    with tf.variable_scope(scope, reuse=reuse):
        x = conv(...)        
        ...
        
        return logit

使用DatasetAPI向网络插入数据

代码语言:javascript
复制
Image_Data_Class = ImageData(img_size, img_ch, augment_flag)
trainA = trainA.map(Image_Data_Class.image_processing, num_parallel_calls=16)
trainA = trainA.shuffle(buffer_size=10000).prefetch(buffer_size=batch_size).batch(batch_size).repeat()

trainA_iterator = trainA.make_one_shot_iterator()
data_A = trainA_iterator.get_next()
logit = network(data_A)
  • 了解更多,请阅读: https://github.com/taki0112/Tensorflow-DatasetAPI

Option

  • padding='SAME'
    • pad = ceil[ (kernel - stride) / 2 ]
  • pad_type
    • 'zero' or 'reflect'
  • sn
    • use spectral_normalization or not
  • Ra
    • use relativistic gan or not
  • loss_func
    • gan
    • lsgan
    • hinge
    • wgan
    • wgan-gp
    • dragan

注意

  • 如果你不想共享变量,请以不同的方式设置所有作用域名称。

权重(Weight)

代码语言:javascript
复制
weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001)
weight_regularizer_fully = tf.contrib.layers.l2_regularizer(0.0001)

初始化(Initialization)

  • Xavier : tf.contrib.layers.xavier_initializer()
  • He : tf.contrib.layers.variance_scaling_initializer()
  • Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
  • Truncated_normal : tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
  • Orthogonal : tf.orthogonal_initializer(1.0) / # if relu = sqrt(2), the others = 1.0

正则化(Regularization)

  • l2_decay : tf.contrib.layers.l2_regularizer(0.0001)
  • orthogonal_regularizer : orthogonal_regularizer(0.0001) & orthogonal_regularizer_fully(0.0001)

卷积(Convolution)

basic conv

代码语言:javascript
复制
x = conv(x, channels=64, kernel=3, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=True, scope='conv')

partial conv (NVIDIA Partial Convolution)

代码语言:javascript
复制
x = partial_conv(x, channels=64, kernel=3, stride=2, use_bias=True, padding='SAME', sn=True, scope='partial_conv')

dilated conv

代码语言:javascript
复制
x = dilate_conv(x, channels=64, kernel=3, rate=2, use_bias=True, padding='SAME', sn=True, scope='dilate_conv')

Deconvolution

basic deconv

代码语言:javascript
复制
x = deconv(x, channels=64, kernel=3, stride=2, padding='SAME', use_bias=True, sn=True, scope='deconv')

Fully-connected

代码语言:javascript
复制
x = fully_conneted(x, units=64, use_bias=True, sn=True, scope='fully_connected')

Pixel shuffle

代码语言:javascript
复制
x = conv_pixel_shuffle_down(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_down')
x = conv_pixel_shuffle_up(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_up')
  • down ===> [height, width] -> [height // scale_factor, width // scale_factor]
  • up ===> [height, width] -> [height * scale_factor, width * scale_factor]

Block

residual block

代码语言:javascript
复制
x = resblock(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block')
x = resblock_down(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_down')
x = resblock_up(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_up')
  • down ===> [height, width] -> [height // 2, width // 2]
  • up ===> [height, width] -> [height * 2, width * 2]

attention block

代码语言:javascript
复制
x = self_attention(x, channels=64, use_bias=True, sn=True, scope='self_attention')
x = self_attention_with_pooling(x, channels=64, use_bias=True, sn=True, scope='self_attention_version_2')

x = squeeze_excitation(x, channels=64, ratio=16, use_bias=True, sn=True, scope='squeeze_excitation')

x = convolution_block_attention(x, channels=64, ratio=16, use_bias=True, sn=True, scope='convolution_block_attention')

Normalization

Normalization

代码语言:javascript
复制
x = batch_norm(x, is_training=is_training, scope='batch_norm')
x = instance_norm(x, scope='instance_norm')
x = layer_norm(x, scope='layer_norm')
x = group_norm(x, groups=32, scope='group_norm')

x = pixel_norm(x)

x = batch_instance_norm(x, scope='batch_instance_norm')

x = condition_batch_norm(x, z, is_training=is_training, scope='condition_batch_norm'):

x = adaptive_instance_norm(x, gamma, beta):
  • 如何使用 condition_batch_norm,请参考:

https://github.com/taki0112/BigGAN-Tensorflow

  • 如何使用 adaptive_instance_norm,请参考:

https://github.com/taki0112/MUNIT-Tensorflow

Activation

代码语言:javascript
复制
x = relu(x)
x = lrelu(x, alpha=0.2)
x = tanh(x)
x = sigmoid(x)
x = swish(x)

Pooling & Resize

代码语言:javascript
复制
x = up_sample(x, scale_factor=2)

x = max_pooling(x, pool_size=2)
x = avg_pooling(x, pool_size=2)

x = global_max_pooling(x)
x = global_avg_pooling(x)

x = flatten(x)
x = hw_flatten(x)

Loss

classification loss

代码语言:javascript
复制
loss, accuracy = classification_loss(logit, label)

pixel loss

代码语言:javascript
复制
loss = L1_loss(x, y)
loss = L2_loss(x, y)
loss = huber_loss(x, y)
loss = histogram_loss(x, y)
  • histogram_loss 表示图像像素值在颜色分布上的差异。

gan loss

代码语言:javascript
复制
d_loss = discriminator_loss(Ra=True, loss_func='wgan-gp', real=real_logit, fake=fake_logit)
g_loss = generator_loss(Ra=True, loss_func='wgan_gp', real=real_logit, fake=fake_logit)
  • 如何使用 gradient_penalty,请参考:

https://github.com/taki0112/BigGAN-Tensorflow/blob/master/BigGAN_512.py#L180

kl-divergence (z ~ N(0, 1))

代码语言:javascript
复制
loss = kl_loss(mean, logvar)

Author

Junho Kim

Github地址 :

https://github.com/taki0112/Tensorflow-Cookbook

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-02-18,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 新智元 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • Import
  • Network template
  • 使用DatasetAPI向网络插入数据
  • Option
  • 注意
  • 权重(Weight)
    • 初始化(Initialization)
      • 正则化(Regularization)
        • basic conv
          • partial conv (NVIDIA Partial Convolution)
            • dilated conv
              • basic deconv
              • Fully-connected
              • Pixel shuffle
                • residual block
                • Activation
                • Pooling & Resize
                  • classification loss
                    • pixel loss
                      • gan loss
                        • kl-divergence (z ~ N(0, 1))
                        • Author
                        领券
                        问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档