Loading [MathJax]/jax/output/CommonHTML/config.js
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >TensorFlow中基于三维卷积的批量归一化

TensorFlow中基于三维卷积的批量归一化
EN

Stack Overflow用户
提问于 2017-01-24 06:28:44
回答 1查看 8.2K关注 0票数 17

我正在实现一个依赖于3D卷积的模型(对于类似于动作识别的任务),我想使用批处理规范化(参见[Ioffe & Szegedy 2015])。我找不到任何关于3D凸体的教程,所以我在这里做一个简短的教程,我想和你一起复习一下。

下面的代码引用了TensorFlow r0.12及其显式实例变量--我的意思是,除了tf.contrib.layers.batch_norm()函数之外,我不使用tf.contrib.layers.batch_norm。我这样做是为了更好地理解事物在幕后是如何工作的,并且有更多的实现自由度(例如,变量摘要)。

我将顺利地获得三维卷积的情况,首先写一个完全连接的层的例子,然后为一个2D卷积,最后为三维的情况。在查看代码时,如果您能够检查所有操作是否正确--代码运行,但我不能100%肯定我应用批处理规范化的方式,这将是很好的。我以一个更详细的问题结束这篇文章。

代码语言:javascript
运行
AI代码解释
复制
import tensorflow as tf

# This flag is used to allow/prevent batch normalization params updates
# depending on whether the model is being trained or used for prediction.
training = tf.placeholder_with_default(True, shape=())

全连通(FC)案

代码语言:javascript
运行
AI代码解释
复制
# Input.
INPUT_SIZE = 512
u = tf.placeholder(tf.float32, shape=(None, INPUT_SIZE))

# FC params: weights only, no bias as per [Ioffe & Szegedy 2015].
FC_OUTPUT_LAYER_SIZE = 1024
w = tf.Variable(tf.truncated_normal(
    [INPUT_SIZE, FC_OUTPUT_LAYER_SIZE], dtype=tf.float32, stddev=1e-1))

# Layer output with no activation function (yet).
fc = tf.matmul(u, w)

# Batch normalization.
fc_bn = tf.contrib.layers.batch_norm(
    fc,
    center=True,
    scale=True,
    is_training=training,
    scope='fc-batch_norm')

# Activation function.
fc_bn_relu = tf.nn.relu(fc_bn)
print(fc_bn_relu)  # Tensor("Relu:0", shape=(?, 1024), dtype=float32)

二维卷积(CNN)层实例

代码语言:javascript
运行
AI代码解释
复制
# Input: 640x480 RGB images (whitened input, hence tf.float32).
INPUT_HEIGHT = 480
INPUT_WIDTH = 640
INPUT_CHANNELS = 3
u = tf.placeholder(tf.float32, shape=(None, INPUT_HEIGHT, INPUT_WIDTH, INPUT_CHANNELS))

# CNN params: wights only, no bias as per [Ioffe & Szegedy 2015].
CNN_FILTER_HEIGHT = 3  # Space dimension.
CNN_FILTER_WIDTH = 3  # Space dimension.
CNN_FILTERS = 128
w = tf.Variable(tf.truncated_normal(
    [CNN_FILTER_HEIGHT, CNN_FILTER_WIDTH, INPUT_CHANNELS, CNN_FILTERS],
    dtype=tf.float32, stddev=1e-1))

# Layer output with no activation function (yet).
CNN_LAYER_STRIDE_VERTICAL = 1
CNN_LAYER_STRIDE_HORIZONTAL = 1
CNN_LAYER_PADDING = 'SAME'
cnn = tf.nn.conv2d(
    input=u, filter=w,
    strides=[1, CNN_LAYER_STRIDE_VERTICAL, CNN_LAYER_STRIDE_HORIZONTAL, 1],
    padding=CNN_LAYER_PADDING)

# Batch normalization.
cnn_bn = tf.contrib.layers.batch_norm(
    cnn,
    data_format='NHWC',  # Matching the "cnn" tensor which has shape (?, 480, 640, 128).
    center=True,
    scale=True,
    is_training=training,
    scope='cnn-batch_norm')

# Activation function.
cnn_bn_relu = tf.nn.relu(cnn_bn)
print(cnn_bn_relu)  # Tensor("Relu_1:0", shape=(?, 480, 640, 128), dtype=float32)

三维卷积(CNN3D)层壳

代码语言:javascript
运行
AI代码解释
复制
# Input: sequence of 9 160x120 RGB images (whitened input, hence tf.float32).
INPUT_SEQ_LENGTH = 9
INPUT_HEIGHT = 120
INPUT_WIDTH = 160
INPUT_CHANNELS = 3
u = tf.placeholder(tf.float32, shape=(None, INPUT_SEQ_LENGTH, INPUT_HEIGHT, INPUT_WIDTH, INPUT_CHANNELS))

# CNN params: wights only, no bias as per [Ioffe & Szegedy 2015].
CNN3D_FILTER_LENGHT = 3  # Time dimension.
CNN3D_FILTER_HEIGHT = 3  # Space dimension.
CNN3D_FILTER_WIDTH = 3  # Space dimension.
CNN3D_FILTERS = 96
w = tf.Variable(tf.truncated_normal(
    [CNN3D_FILTER_LENGHT, CNN3D_FILTER_HEIGHT, CNN3D_FILTER_WIDTH, INPUT_CHANNELS, CNN3D_FILTERS],
    dtype=tf.float32, stddev=1e-1))

# Layer output with no activation function (yet).
CNN3D_LAYER_STRIDE_TEMPORAL = 1
CNN3D_LAYER_STRIDE_VERTICAL = 1
CNN3D_LAYER_STRIDE_HORIZONTAL = 1
CNN3D_LAYER_PADDING = 'SAME'
cnn3d = tf.nn.conv3d(
    input=u, filter=w,
    strides=[1, CNN3D_LAYER_STRIDE_TEMPORAL, CNN3D_LAYER_STRIDE_VERTICAL, CNN3D_LAYER_STRIDE_HORIZONTAL, 1],
    padding=CNN3D_LAYER_PADDING)

# Batch normalization.
cnn3d_bn = tf.contrib.layers.batch_norm(
    cnn3d,
    data_format='NHWC',  # Matching the "cnn" tensor which has shape (?, 9, 120, 160, 96).
    center=True,
    scale=True,
    is_training=training,
    scope='cnn3d-batch_norm')

# Activation function.
cnn3d_bn_relu = tf.nn.relu(cnn3d_bn)
print(cnn3d_bn_relu)  # Tensor("Relu_2:0", shape=(?, 9, 120, 160, 96), dtype=float32)

我想确定的是,上面的代码是否准确地实现了批处理规范化,正如在证券交易委员会结束时在[Ioffe & Szegedy 2015]中描述的那样。3.2:

对于卷积层,我们还希望归一化服从卷积性质,从而使同一特征映射的不同元素在不同的位置以相同的方式规范化。为了实现这一点,我们联合标准化了所有小型批次中的所有激活,覆盖所有位置。..。阿尔格。2作了类似的修改,使得在推理过程中BN变换对给定特征映射中的每一次激活都应用相同的线性变换。

UPDATE --我想上面的代码对于3D conv情况也是正确的。事实上,当我定义我的模型时,如果我打印所有可训练变量,我也会看到预期的beta和gamma变量的数量。例如:

代码语言:javascript
运行
AI代码解释
复制
Tensor("conv3a/conv3d_weights/read:0", shape=(3, 3, 3, 128, 256), dtype=float32)
Tensor("BatchNorm_2/beta/read:0", shape=(256,), dtype=float32)
Tensor("BatchNorm_2/gamma/read:0", shape=(256,), dtype=float32)

这在我看来是可以的,因为由于BN,每一个特征图都要学习一对beta和gamma (总共256)。

Ioffe & Szegedy 2015:批量规范化:通过减少内部协变量转移加快深度网络培训

EN

回答 1

Stack Overflow用户

发布于 2017-10-07 05:07:41

这是一篇关于3D批范数的很棒的文章,人们常常没有注意到,批范数可以应用于任何等级大于1的张量。您的代码是正确的,但是我不得不在这里添加一些重要的注释:

  • 在tensorflow中,“标准”2D批规范(接受一个4D张量)比3D或更高的速度要快得多,因为它支持应用fused_batch_norm一核运算实现: 融合批处理规范将进行批处理规范化所需的多个操作合并到一个内核中。批处理规范是一个昂贵的过程,对某些模型来说,它占了操作时间的很大一部分。使用融合批处理范数可导致12%-30%的加速比。

关于GitHub的一个问题也支持3D过滤器,但目前还没有任何新的活动,目前这个问题还没有解决。

  • 尽管最初的文章规定在ReLU激活之前使用批处理规范(这就是您在上面的代码中所做的),但有证据表明,在激活之后使用批处理规范可能更好。以下是Francois对Keras GitHub的评论: ..。我可以保证,最近由基督教司仪写的代码适用于BN之前的relu。不过,这仍然是偶尔会引起争论的话题。
  • 对于任何有兴趣在实践中应用规范化思想的人来说,这一思想的最新研究发展,即权重归一化层归一化,弥补了原有批量规范的某些缺点,例如,它们更适合于LSTM和递归网络。
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/41830723

复制
相关文章
Google JS API 授权 失败
// 初始化OAuth2.0授权 const authenticate = () => { return gapi.auth2.getAuthInstance() .signIn({scope: "https://www.googleapis.com/auth/documents https://www.googleapis.com/auth/drive https://www.googleapis.com/auth/driv
拿我格子衫来
2022/01/24
4.1K0
Google JS API 授权 失败
使用Google JS api 创建 文档
https://developers.google.com/docs/api/reference/rest/v1/documents/request#Request
拿我格子衫来
2022/01/24
3.3K0
使用Google JS api 创建 文档
Google短网址的API
除了速度快,goo.gl还提供详细的点击统计。比如,Yahoo首页的短网址是http://goo.gl/QuXj,那么它的统计数据就在http://goo.gl/info/QuXj。加上后缀".qr",还能得到这个网址的二维条形码,Yahoo的就是http://goo.gl/QuXj.qr。
ruanyf
2018/09/21
4.4K1
Google短网址的API
Google JavaScript API
You can use the JavaScript client library to interact with Google APIs, such as People, Calendar, and Drive, from your web applications. Follow the instructions on this page to get started.
拿我格子衫来
2022/01/24
6140
Google 发布 Google Friend Connect API
Google Friend Connect 是 Google 推出的社会化网络工具,通过此工具你可以将各种支持 OpenSocial 的应用通过 Google Friend Connect 在你的网站上应用,并且可以和已有的社会化网络进行整合应用。今天 Google 更是开放了 Google Friend Connect 的 API,让你能够访问到更多 Google Friend Connect 核心的数据和功能。 Google Friend Connect 提供两种 API,JavaScript API 允许你能够直接集成社会化社区到你的网页中。REST API 能够允许你把网站的现有的登陆系统和数据集成新的社会化数据和活动,并能实现让你的网站实现通过 Gmail 账号,Yahoo 账号,OpenID 等方式实现单点登录。
Denis
2023/04/14
6630
Google JavaScript API 的使用
您可以使用JavaScript客户端库与Web应用程序中的Google API(例如,人物,日历和云端硬盘)进行交互。请按照此页面上的说明进行操作。
拿我格子衫来
2022/01/24
3K0
google maps api_js调用谷歌浏览器接口
1. 使用谷歌地图 API 的第一步就是要注册一个 API 密钥,需要注重一下两点:
全栈程序员站长
2022/09/20
5.8K0
JavaScript---网络编程(7)-Dom模型(节点间的层次关系,节点的增、删、改)
利用节点间的层次关系获取节点: 上一节讲了3中获取的方式: * ※※一、绝对获取,获取元素的3种方式:—Element * 1、getElementById(): 通过标签中的id属性值获来取该标签对象 * 2、getElementsByName(): 通过标签中的name属性值来获取该标签对象集合 * 3、getElementsByTagName(): 通过标签名来获取该标签对象集合
谙忆
2021/01/21
8560
JavaScript---网络编程(7)-Dom模型(节点间的层次关系,节点的增、删、改)
js|jq获取兄弟节点,父节点,子节点
08.19自我总结 js|jq获取兄弟节点,父节点,子节点 一.js var parent = test.parentNode; // 父节点 var chils = test.childNodes; // 全部子节点 var first = test.firstChild; // 第一个子节点 var last = test.lastChile; // 最后一个子节点  var previous = test.previousSibling; // 上一个兄弟节点 var next = test.next
小小咸鱼YwY
2019/09/11
15.2K0
折腾Google Docs API 的坑
快速开始 https://developers.google.cn/docs/api/quickstart/nodejs#step_2_install_the_client_library
拿我格子衫来
2022/01/24
1.3K0
折腾Google Docs API 的坑
使用Google翻译Api
将环境变量GOOGLE_APPLICATION_CREDENTIALS设置为包含服务帐户密钥的JSON文件的文件路径。在Linux或macOS系统中设置方法如下:
职场亮哥
2020/10/10
4.6K0
Facebook Ads广告业务API接口的源代码泄露漏洞
此前,我对“Windows NT” 和 “Windows Phone”模型有所研究,后来,我看到好多人参与了Facebook的漏洞赏金项目并收获了奖励,所以,我想那我也来试试吧,看看能不能入围Facebook的白帽致谢榜,想当年我也两次入围微软操作系统漏洞安全名人堂呢。
FB客服
2018/12/28
1.2K0
怎么解决google ads广告被拒登 存在恶意软件或垃圾软件的问题
2020年google adwords上线了最新的安全算法,针对客户网站存在恶意软件以及垃圾软件的情况,将会直接拒绝推广,显示已拒登:恶意软件或垃圾软件的提示。导致国内大部分做外贸以及google推广的客户受到影响,很多客户找到我们SINE安全公司寻求技术上的支持,帮忙解决问题,促使goole广告尽快上线。像这种问题该如何解决处理呢?
网站安全专家
2020/04/24
1.3K0
怎么解决google ads广告被拒登 存在恶意软件或垃圾软件的问题
ADS1115IDGSR
生产厂家:TEXAS INSTRUMENTS 型号参数:ADS1115IDGSR参数Brand NameTexas Instruments是否无铅不含铅是否Rohs认证符合生命周期ActiveIHS 制造商TEXAS INSTRUMENTS INC零件包装代码MSOP包装说明MSOP-10针数10Reach Compliance CodecompliantECCN代码EAR99HTS代码8542.39.00.01Factory Lead Time1 week风险等级1.21Samacsys Confide
电子交流圈
2022/03/20
6030
ADS振铃仿真
T=35um,表面导体厚度,1oz铜为35um,0.5oz铜为17um,此处设置为1oz;
黑马Amos
2023/03/21
1.1K0
ADS振铃仿真
js控制节点小结
DOM节点操作 <style> *{ margin: 0; padding: 0; } ul{ list-style: none; } a{ text-decoration: none; color: #333; } div{ margin-left
天天_哥
2018/09/29
5.9K0
js创建img节点
我们需要使用document对象的createElement方法创建了一个img元素:
IT工作者
2022/01/05
8.2K0
ADS1.2破解
ads1.2 license 1.拷贝{}内容到文本文档里面( 不包括{} ) 2.改成 .dat 3.按着向导导入即可 {PACKAGE ads armlmd 1.200 E32F0DE5161D COMPONENTS="armasm compiler \ bats armulate axd adwu fromelf armlink codewarrior armsd" INCREMENT ads armlmd 1.200 permanent uncounted 612C53EF47C7 \ HOSTID=ANY ISSUER="Full License by armer, only for educational purpose!" ck=0 }
TSINGEYE清眸物联
2023/01/04
6860
Js 类型转换
JavaScript 是一种弱类型或者说动态语言。这意味着你不用提前声明变量的类型,在程序运行过程中,类型会被自动确定。这也意味着你可以使用同一个变量保存不同类型的数据:
hss
2022/02/25
20.4K0
js时间转换
//时间戳格式化 //时间转换 function stamptime(time) { var date = new Date(time) var Y = date.g
阿超
2022/08/16
12.2K0

相似问题

Google Ads API -许可

11

Google Ads API集成

29

Google Ads API帐户丢失

10

更新预算Google Ads API

10

将Google-Ads API GoogleAdsRow转换为CSV?

114
添加站长 进交流群

领取专属 10元无门槛券

AI混元助手 在线答疑

扫码加入开发者社群
关注 腾讯云开发者公众号

洞察 腾讯核心技术

剖析业界实践案例

扫码关注腾讯云开发者公众号
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
查看详情【社区公告】 技术创作特训营有奖征文