当通过增加维度操作插入新维度后,可能希望在新维度上面复制若干份数据,满足后续算法的格式要求。考虑 Y = X@W + b 的例子,偏置 b 插入样本数的新维度后,需要在新维度上复制 Batch Size 份数据,将 shape 变为与 X@W 一致后,才能完成张量相加运算。
PyTorch 中常用于张量数据复制操作有 expand 和 repeat。「expand 和 repeat 两个函数只有 input.expand(\*sizes)
和 input.repeat(\*size)
一种定义方式。」 本小节主要介绍 input.expand(\*sizes)
input.expand(*sizes)
函数能够实现 input 输入张量中单维度(singleton dimension)上数据的复制操作,「其中 *sizes 分别指定了每个维度上复制的倍数,对于不需要(或非单维度)进行复制的维度,对应位置上可以写上原始维度的大小或者直接写 -1。」
“将张量中大小为 1 的维度称为单维度。比如形状为 [2, 3] 的张量就没有单维度,形状为 [1, 3] 中第 0 个维度上的大小为 1,因此第 0 个维度为张量的单维度。”以形状为 [2, 4] 的输入张量,输出为 3 个节点线性变换层为例,偏置 b 被定义为:
为了让偏置 b 具有单维度,需要通过 torch.unsqueeze(b, dim = 0)
插入新维度,变成矩阵:
此时张量 B 的形状为 [1, 3],我们需要在 dim = 0 批量维度上根据输入样本的数量复制若干次,由于输入的样本个数为 2(batch_size = 2),即复制一份,变成:
通过 b.expand([2, -1])
(或者b.expand(2, 3))即可在 dim = 0 维度复制 1 次,在 dim = 1 维度不复制。具体实现如下:
import torch
# 创建偏置b
b = torch.tensor([1, 2, 3])
# 为张量b插入新的维度
B = torch.unsqueeze(b, 0)
print(B.size())
# torch.Size([1, 3])
print(B)
# tensor([[1, 2, 3]])
在批量维度上复制数据 1 份,实现如下:
# -1意味着不改变对应维度的大小
B = B.expand([2, -1])
print(B)
# tensor([[1, 2, 3],
# [1, 2, 3]])
此时 B 的shape 变为 2,3,可以直接与 X@W 进行相加运算。通过 torch.unsqueeze(b, dim = 0)
为偏置 b 插入了一个批量维度,此时偏置 b 变成了形状为 [1, 3] 的 2D 张量 B,正是因为有了单维度才能够对 2D 张量 B 的第 0 个维度进行复制操作,因此只要张量中有单维度,就可以通过 expand 函数对相应的单维度进行复制操作。
import torch
A = torch.arange(12).view(3, 1, 4)
print(A.size())
# Size([3, 1, 4])
A = A.expand([3, 12, -1])
print(A.size())
# torch.Size([3, 12, 4])
在深度学习中插入批量维度并进行复制操作的场景非常多(比如偏置 b),简单来说就是为输入张量添加一个批量维度并在批量维度上复制输入张量多份。比如复制 10 份形状为 [28, 28, 3] 的图片张量,最后图片张量的形状为 [10, 28, 28, 3]。
「expand 函数中融合了插入批量维度并在新插入的批量维度上复制数据的操作。」 使用起来非常简单。
import torch
# 使用[0, 1)均匀分布模拟图片张量
# (channels, height, width)
img = torch.randn([28, 28, 3])
# 现在需要将图片复制4份
imgs = img.expand([4, -1, -1, -1])
print(imgs.size())
# torch.Size([4, 28, 28, 3])
对于上面的偏置 b,我们可以省略 torch.unsqueeze(b, dim = 0)
插入批量维度的操作,直接使用 expand 函数。
import torch
# 创建偏置b
b = torch.tensor([1, 2, 3])
# 直接插入批量维度并复制2份
B = b.expand([2, -1])
print(B.size())
# torch.Size([2, 3])
print(B)
# tensor([[1, 2, 3],
# [1, 2, 3]])
「还有一个需要注意:expand 函数并不会重新分配内存,返回的结果仅仅是原始张量上的一个视图。」
References: 1. 《TensorFlow深度学习》
本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!