
TensorFlow in Practice (3): Implementing Batch Normalization

YoungTimes
Published 2022-04-28 13:00:11

The tf.nn.moments function

The function signature is:

```python
def moments(x, axes, name=None, keep_dims=False)
```
Inputs

- x: the input data, typically of shape [batchsize, height, width, kernels]
- axes: a list of the dimensions to compute over, e.g. [0, 1, 2]
- name: name of the operation
- keep_dims: whether to keep the reduced dimensions

Outputs

- mean: the mean
- variance: the variance

Example

```python
img = tf.Variable(tf.random_normal([128, 32, 32, 64]))
# Compute moments over all axes except the last (channel) axis
axis = list(range(len(img.get_shape()) - 1))  # [0, 1, 2]
mean, variance = tf.nn.moments(img, axis)
```
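The semantics of `tf.nn.moments` can be checked against plain NumPy (an illustrative equivalence sketch, not TensorFlow code: `np.mean`/`np.var` over the same axes compute the same per-channel statistics, and `np.var` is the biased population variance, matching `tf.nn.moments`):

```python
import numpy as np

# Stand-in for the TensorFlow example above: per-channel
# mean/variance over batch, height and width.
img = np.random.randn(128, 32, 32, 64).astype(np.float32)
axes = tuple(range(img.ndim - 1))   # (0, 1, 2): all but the channel axis

mean = img.mean(axis=axes)          # shape (64,), like keep_dims=False
variance = img.var(axis=axes)       # biased (population) variance

print(mean.shape, variance.shape)   # (64,) (64,)
```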

The tf.nn.batch_normalization function

The function signature is:

```python
def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)
```

When using batch_normalization, the bias terms in the network should be removed (the offset parameter plays the same role).

Inputs

- x: the input Tensor
- mean: the mean of the Tensor
- variance: the variance of the Tensor
- offset: the offset Tensor, usually initialized to 0, trainable
- scale: the scale Tensor, usually initialized to 1, trainable
- variance_epsilon: a small float to avoid division by zero, typically 0.001
- name: name of the operation

Algorithm
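Batch normalization standardizes x with the given statistics and then applies the learned affine transform: y = scale · (x − mean) / √(variance + ε) + offset. A minimal NumPy sketch of this computation (illustrative names; this is not the TensorFlow implementation):

```python
import numpy as np

def batch_norm(x, mean, variance, offset, scale, variance_epsilon):
    # y = scale * (x - mean) / sqrt(variance + epsilon) + offset
    inv_std = 1.0 / np.sqrt(variance + variance_epsilon)
    return scale * (x - mean) * inv_std + offset

x = np.random.randn(8, 4)
mean, variance = x.mean(axis=0), x.var(axis=0)
y = batch_norm(x, mean, variance,
               offset=np.zeros(4), scale=np.ones(4), variance_epsilon=1e-3)

# With offset=0 and scale=1, each column of y is zero-mean with
# variance slightly below 1 (because of the epsilon in the denominator).
```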

Full example

The following builds a convolutional layer with batch normalization, with separate paths for training (batch statistics plus a moving-average update) and inference (population statistics):

```python
def conv_layer(prev_layer, layer_depth, is_training):
    # Downsample every third layer
    strides = 2 if layer_depth % 3 == 0 else 1

    in_channels = prev_layer.get_shape().as_list()[3]
    out_channels = layer_depth * 4

    weights = tf.Variable(
        tf.truncated_normal([3, 3, in_channels, out_channels], stddev=0.05))

    # No bias term here: batch normalization's beta takes its place
    layer = tf.nn.conv2d(prev_layer, weights,
                         strides=[1, strides, strides, 1], padding='SAME')

    # Trainable scale (gamma) and offset (beta)
    gamma = tf.Variable(tf.ones([out_channels]))
    beta = tf.Variable(tf.zeros([out_channels]))

    # Non-trainable population statistics, updated by moving average during training
    pop_mean = tf.Variable(tf.zeros([out_channels]), trainable=False)
    pop_variance = tf.Variable(tf.ones([out_channels]), trainable=False)

    epsilon = 1e-3

    def batch_norm_training():
        # Per-channel statistics over batch, height and width
        batch_mean, batch_variance = tf.nn.moments(layer, [0, 1, 2], keep_dims=False)

        # Exponential moving average of the population statistics
        decay = 0.99
        train_mean = tf.assign(pop_mean,
                               pop_mean * decay + batch_mean * (1 - decay))
        train_variance = tf.assign(pop_variance,
                                   pop_variance * decay + batch_variance * (1 - decay))

        # Ensure the moving averages are updated before normalizing
        with tf.control_dependencies([train_mean, train_variance]):
            return tf.nn.batch_normalization(layer, batch_mean, batch_variance,
                                             beta, gamma, epsilon)

    def batch_norm_inference():
        # Use the accumulated population statistics at inference time
        return tf.nn.batch_normalization(layer, pop_mean, pop_variance,
                                         beta, gamma, epsilon)

    batch_normalized_output = tf.cond(is_training, batch_norm_training, batch_norm_inference)
    return tf.nn.relu(batch_normalized_output)
```
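The moving-average update in `batch_norm_training` can be sketched in NumPy to see how the population statistics converge toward the batch statistics (decay = 0.99 as above; the values are illustrative):

```python
import numpy as np

decay = 0.99
pop_mean = np.zeros(4)          # starts at 0, like tf.zeros
batch_mean = np.full(4, 2.0)    # pretend every batch has mean 2.0

# Repeated tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
for _ in range(1000):
    pop_mean = pop_mean * decay + batch_mean * (1 - decay)

print(pop_mean)  # close to 2.0 after many updates
```

With decay = 0.99, the population mean reaches a fraction 1 − 0.99ⁿ of the true batch mean after n updates, so roughly a thousand batches are needed before inference statistics are reliable.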
Originally published 2019-06-02; shared from the WeChat public account 半杯茶的小酒杯.
