前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >Pytorch常用张量操作以及归一化算法实现

Pytorch常用张量操作以及归一化算法实现

作者头像
算法工程师之路
发布2019-08-05 20:19:42
8.6K0
发布2019-08-05 20:19:42
举报
文章被收录于专栏:算法工程师之路

本篇文章将要总结下Pytorch常用的一些张量操作,并说明其作用,接着使用这些操作实现归一化操作的算法,如BN,GN,LN,IN等!

1

Pytorch中常用张量操作

torch.cat

对数据沿着某一维度进行拼接,cat后的总维度数不变,需要注意两个张量进行cat时某一维的维数要相同,否则会报错!

代码语言:javascript
复制
import torch
x = torch.randn(2,3)
y = torch.randn(1,3)
torch.cat((x, y), 0)   # 维度为(3, 3)
z = torch.randn(1, 4)
torch.cat((x, z), 0)  # 报错

stack

相比于Cat,Stack则会增加新的维度,并且将两个矩阵在新的维度上进行堆叠,一般要求两个矩阵的维度是相同的!

代码语言:javascript
复制
import torch
x = torch.randn(1,2)
y = torch.randn(1,2)
torch.stack((x, y), 0)   # 在0维度进行堆叠,维度为(2, 1, 2)
torch.stack((x, y), 1)   # 维度为(1, 2, 2)

transpose

其作用为交换两个维度,类似于二维矩阵的转置作用!

代码语言:javascript
复制
import torch
x = torch.randn(2,3)
x.transpose(0, 1)  # 维度为(3, 2)

permute

其相当于增强版的transpose,适合于多维数据,更加灵活一点!

代码语言:javascript
复制
import torch
x = torch.randn(1,2,3,4)
x_p = x.permute(1,0,2,3)  # 维度变为(2,1,3,4)

squeeze和unsqueeze

squeeze(dim)为压缩的意思,即去掉维度数为1的dim,默认是去掉所有为1的,当然也可以自己指定,但如果指定的维度数不为1,则不会发生任何改变。unsqueeze(dim)则与squeeze(dim)正好相反,为添加一个维度的作用。

代码语言:javascript
复制
import torch
x = torch.randn(2,1)
x.squeeze()   # 维度(2,)
x.squeeze(1)  # 维度(2,)
x.unsqueeze(2) # 维度(2,1,1)
x.unsqueeze(0) # 维度(1,2,1)

view、contigous和reshape

有些tensor并不是占用一整块内存,而是由不同的数据块组成,而tensor的view()操作依赖于内存是整块的,这时只需要执行contiguous()这个函数,把tensor变成在内存中连续分布的形式。特别的在Pytorch0.4中,在使用了permute和transpose后,内存就不连续了,因此不能直接使用view函数,应该先contigous()变成连续内存后,再使用view。 Pytorch0.4中,增加了一个reshape函数,就相当于contigous().view()的功能了!

2

归一化操作的实现

我们今天只来考虑如何实现,至于归一化的原理我们就不再赘述,知乎和博客都写的很多了,对于这几种归一化的方法,比如BN(Batch),LN(Layer),IN(Instance),GN(Group)这四种,在GN的论文中有一幅图可以清晰的描述,我们不用看公式,只要把下面这个图记住就好了!(蓝色区域即为其归一化的区域,说白了我们每个归一化时使用的均值和方差就是由蓝色区域计算得来的,然后作用到这个蓝色区域进行归一化,从而对整体X进行归一化)。

那么我们可以看下简单实现(仅归一化)

Batch Normalization

代码语言:javascript
复制
import torch
from torch import nn
bn = nn.BatchNorm2d(num_features=3, eps=0, affine=False, track_running_stats=False)
x = torch.rand(10, 3, 5, 5)*10000
official_bn = bn(x)   # 官方代码

x1 = x.permute(1, 0, 2, 3).reshape(3, -1) # 对(N, H, W)计算均值方差
mean = x1.mean(dim=1).reshape(1, 3, 1, 1)
# x1.mean(dim=1)后维度为(3,)
std = x1.std(dim=1, unbiased=False).reshape(1, 3, 1, 1)
my_bn = (x - mean)/std
print((official_bn-my_bn).sum())  # 输出误差

Layer Normalization

代码语言:javascript
复制
import torch
from torch import nn
ln = nn.LayerNorm(normalized_shape=[3, 5, 5], eps=0, elementwise_affine=False)
x = torch.rand(10, 3, 5, 5)*10000
official_ln = ln(x)   # 官方代码

x1 = x.reshape(10, -1)  # 对(C,H,W)计算均值方差
mean = x1.mean(dim=1).reshape(10, 1, 1, 1)
std = x1.std(dim=1, unbiased=False).reshape(10, 1, 1, 1)
my_ln = (x - mean)/std
print((official_ln-my_ln).sum())

Instance Normalization

代码语言:javascript
复制
import torch
from torch import nn
In = nn.InstanceNorm2d(num_features=3, eps=0, affine=False, track_running_stats=False)
x = torch.rand(10, 3, 5, 5)*10000
official_In = In(x)   # 官方代码

x1 = x.reshape(30, -1)  # 对(H,W)计算均值方差
mean = x1.mean(dim=1).reshape(10, 3, 1, 1)
std = x1.std(dim=1, unbiased=False).reshape(10, 3, 1, 1)
my_In = (x - mean)/std
print((official_In-my_In).sum())

Group Normalization

代码语言:javascript
复制
import torch
from torch import nn
gn = nn.GroupNorm(num_groups=4, num_channels=20, eps=0, affine=False)
# 分成了4组,也就是说蓝色区域为(5,5, 5)
x = torch.rand(10, 20, 5, 5)*10000
official_gn = gn(x)   # 官方代码

x1 = x.reshape(10,4,-1)  # 对(H,W)计算均值方差
mean = x1.mean(dim=2).reshape(10, 4, -1)
std = x1.std(dim=2, unbiased=False).reshape(10, 4, -1)
my_gn = ((x1 - mean)/std).reshape(10, 20, 5, 5)
print((official_gn-my_gn).sum())

以上代码参考并修改自知乎专栏文章(https://zhuanlan.zhihu.com/p/69659844)

3

资源分享

欢迎关注我的个人公众号 (算法工程师之路),回复"左神算法基础CPP"即可获得大量算法源码,并实时更新!希望大家多多支持哦~

公众号简介:分享算法工程师必备技能,谈谈那些有深度有意思的算法,主要范围:C++数据结构与算法/深度学习(CV),立志成为Offer收割机!坚持分享算法题目和解题思路(Day By Day)

更多精彩推荐,请关注我们

长按二维码关注算法工程师之路

算法工程师

我们一起努力,For World!

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2019-08-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 算法工程师之路 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 欢迎关注我的个人公众号 (算法工程师之路),回复"左神算法基础CPP"即可获得大量算法源码,并实时更新!希望大家多多支持哦~
相关产品与服务
批量计算
批量计算(BatchCompute,Batch)是为有大数据计算业务的企业、科研单位等提供高性价比且易用的计算服务。批量计算 Batch 可以根据用户提供的批处理规模,智能地管理作业和调动其所需的最佳资源。有了 Batch 的帮助,您可以将精力集中在如何分析和处理数据结果上。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档