上节课讲了pytorch中的0维和1维的标量定义和代码表达形式。
其中一维标量主要用于Bias(偏差)中,如在构建神经元中多组数据导入到一个神经元中,由激活函数激活输出一个数值,则该神经元主要使用bias功能。线性层输入(Linear input)中也有用到一维标量,如一个28*28像素点的图片矩阵数据打平到一个784的一维标量中,便会用到Bias功能。
本节课将介绍pytorch中的二维和三维表达形式。
首先引入pytorch包
import torch
先构建一个2维的2行3列的呈高斯分布的矩阵
a = torch.randn(2, 3)
print(a)
输出矩阵为
tensor([[ 0.5747, -0.7920, 0.7112],
[-0.5065, 1.1655, 0.7981]])
查看a的shape为
print(a.shape)
torch.Size([2, 3])
查看a的size为
print(a.size())
torch.Size([2, 3])
但若想分别查看a在0维、1维上的size为:
print(a.size(0))
print(a.size(1))
输出分别为
2
3
在pytorch图像识别中,一般用第一个元素表示第几张照片,第二个元素用来表示具体的数据内容。
在3维的标量中
表示方法为
b = torch.randn(1, 2, 3)
print(b)
输出b为
tensor([[[-0.3856, -0.3062, 1.1586],
[-1.9733, 0.2003, -0.0352]]])
输出b的shape为
print(b.shape)
torch.Size([1, 2, 3])
对b的第1个元素进行索引,并输出第1个元素的shape
print(b[0])
print(b[0].shape)
输出为
tensor([[ 0.4352, 1.2551, -2.0541],
[-0.2046, -0.3027, 0.0742]])
torch.Size([2, 3])
由此可见,第一个元素对应的为[2, 3]
三维应用最多的是在RNN、input和Batch中。假设一句话有10个单词,每个单词用00维的one_hot向量来编码,矩阵为[10, 100],([10 word, 100 feature]),一次想送入多句话时(假设29句话),则应把batch插入到中间,即此时矩阵为[10, 20, 100]。、
4维向量
构建时用
e = torch.randn(2, 3, 28, 28)
print(e)
输出为
tensor([[[[ 4.6746e-02, -7.2555e-01, -1.0163e+00, ..., 6.3349e-01,
-6.9263e-01, 2.3174e+00],
[ 2.6572e-01, 2.0966e+00, 1.9693e-01, ..., -3.8246e-01,
5.1631e-01, 5.2878e-01],
[ 5.1315e-01, 9.7876e-01, -1.0334e-01, ..., -8.2572e-01,
4.0260e-01, 9.2650e-01],
...,
[-6.3234e-01, -6.4486e-02, 8.9819e-02, ..., -4.8805e-02,
-3.6226e-01, 2.8216e-01],
[ 8.8367e-01, 2.0588e-01, -1.4349e+00, ..., 3.3854e-01,
-9.6813e-01, 5.8934e-01],
[ 6.1359e-01, 6.7494e-01, -1.5057e+00, ..., -2.0673e+00,
1.5338e+00, 7.4172e-01]],
[[-5.7504e-01, -2.1989e-01, 9.9902e-01, ..., 1.4300e+00,
4.7485e-01, 9.6381e-01],
[ 2.2733e-01, -7.8228e-01, -4.9039e-01, ..., 3.5477e-01,
5.8703e-01, -3.5601e-01],
[-4.5116e-01, 6.0935e-01, -7.8103e-01, ..., -3.0234e-01,
-1.2936e+00, 1.0186e+00],
...,
[-1.8746e-01, -2.3330e-01, 8.8092e-01, ..., -4.8307e-01,
-7.2092e-01, -1.0095e-01],
[-1.2032e+00, 4.1019e-01, 8.7527e-01, ..., 3.5443e-01,
3.3754e+00, -3.9624e-01],
[ 8.8196e-02, -4.1996e-02, -6.7734e-01, ..., 1.6720e+00,
-5.3305e-01, -4.0128e-01]],
[[ 1.3417e+00, -8.7618e-01, 1.1194e-01, ..., -1.8591e-02,
1.2383e+00, 1.9580e+00],
[ 2.6057e+00, 1.7841e+00, -2.6864e+00, ..., 1.9403e-01,
-4.5544e-01, -7.0843e-01],
[ 4.1448e-01, -2.3968e+00, -9.8667e-01, ..., -6.6330e-02,
-6.7622e-01, 1.3603e+00],
...,
[ 1.2888e+00, 5.7021e-01, 1.7715e+00, ..., -1.0324e+00,
-5.5282e-01, 6.7588e-01],
[ 6.2775e-01, 1.1220e+00, 7.7026e-01, ..., -1.1497e+00,
-6.9084e-01, 4.5506e-01],
[-1.0029e+00, 1.7508e+00, 1.0402e+00, ..., -1.1906e-01,
-1.4788e+00, -2.3655e-01]]],
[[[-1.2357e+00, -7.7063e-01, 4.3775e-01, ..., 1.3476e+00,
4.6510e-01, 6.0015e-01],
[-7.4814e-01, 8.0978e-01, -5.8995e-01, ..., 4.0269e-03,
1.1107e+00, 1.3944e+00],
[ 5.9686e-03, -1.3973e+00, -4.2270e-01, ..., -1.3700e+00,
3.5190e-01, -8.1495e-01],
...,
[ 1.0281e-03, 9.5155e-02, -1.2899e+00, ..., 1.0035e+00,
6.9070e-01, 1.6811e+00],
[-1.2639e+00, -1.7418e+00, -6.9545e-01, ..., 3.5612e-01,
-4.6358e-01, 9.3933e-01],
[ 1.2552e+00, 8.6772e-01, -8.9920e-02, ..., -7.6013e-01,
-1.7622e+00, 1.2953e-01]],
[[-5.1453e-01, -8.9602e-01, -5.4858e-01, ..., 1.5385e+00,
1.6776e+00, -2.5184e-01],
[ 8.0618e-01, -1.2494e+00, -1.0520e+00, ..., 2.0146e+00,
-9.5918e-01, -1.2805e+00],
[ 7.9278e-03, -1.4394e-01, 1.1932e-01, ..., 9.7892e-01,
1.5099e-01, -1.0838e+00],
...,
[-4.6785e-01, -1.4917e+00, 3.5120e-02, ..., -4.6422e-01,
-9.9104e-02, -8.8785e-01],
[ 9.1863e-01, 1.1168e+00, -1.0764e+00, ..., 1.6113e+00,
-1.3860e+00, 5.3114e-02],
[-1.7069e-01, 2.7445e-01, -5.0318e-02, ..., 9.5631e-02,
6.2200e-01, 3.2184e-01]],
[[ 9.4690e-01, 2.0903e+00, 1.5828e+00, ..., -1.1553e+00,
1.3350e+00, -1.4522e+00],
[-1.0217e+00, 3.4814e-01, -6.1494e-02, ..., -6.8919e-01,
-1.9823e-01, -2.6472e-01],
[-2.1480e+00, 5.6693e-01, -3.5862e-01, ..., 1.8241e+00,
1.7414e+00, -6.1781e-01],
...,
[ 9.0595e-01, 8.2399e-01, -3.3188e-01, ..., -1.4197e-01,
-3.1074e-01, 5.8298e-01],
[-8.2013e-01, -6.1951e-01, 4.9327e-01, ..., -1.0709e+00,
1.8096e-01, 2.8510e-01],
[-6.1442e-01, 9.5634e-01, -1.0928e+00, ..., 4.4558e-01,
-9.9419e-01, 1.0345e+00]]]])
此时再输出e的shape
输出为
print(e.shape)
torch.Size([2, 3, 28, 28])
这里的[2, 3, 28, 28]分别表示[2张照片,一张彩色图片的R G B三通道,mnist数据集图片的长,mnist数据集图片的宽 ]。
因此4维数据直接对应于现实生活中的图片。这种4维表示形式特别适合于表达CNN(卷积神经网络)。
补充知识:
我们还可以表达出向量的具体大小
如
print(e.numel())
输出为
4704
即为2*3*28*28的数值相乘得来
另外除了使用len(e.shape)输出e的维度外,还可以使用e.dim()来输出
print(e.dim())
print(len(e.shape))
结果均为
4
4
下节将介绍如何创建Tensor的数据。
看后感觉还可以,赞赏一下可以吗?写这些东西每天都要牵扯大量的精力
本文分享自 python pytorch AI机器学习实践 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!