前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >PyTorch入门笔记-复制数据repeat函数

PyTorch入门笔记-复制数据repeat函数

作者头像
触摸壹缕阳光
修改2021-02-12 01:28:04
5.9K0
修改2021-02-12 01:28:04
举报
文章被收录于专栏:AI机器学习与深度学习算法

repeat

前面提到过 input.expand(*sizes) 函数能够实现 input 输入张量中单维度(singleton dimension)上数据的复制操作。「对于非单维度上的复制操作,expand 函数就无能为力了,此时就需要使用 input.repeat(*sizes)。」

input.repeat(*sizes) 可以对 input 输入张量中的单维度和非单维度进行复制操作,并且会真正的复制数据保存到内存中。input.expand(*sizes)input.repeat(*sizes) 两个函数的区别如下表所示。

input.repeat(*sizes) 函数中的 *sizes 参数分别指定了各个维度上复制的倍数,对于不需要复制的维度需要指定为 1。(在expand函数中对于不需要(或非单维度)进行复制的维度,对应位置上可以写上原始维度的大小或者直接写 -1)

对单维度上的数据进行复制,repeat 函数和 expand 函数类似,和 expand 函数一样,repeat 函数也融合了插入批量维度并在新插入的批量维度上复制数据的操作。

代码语言:txt
复制
import torch

# 创建偏置b
b = torch.tensor([1, 2, 3])
# 为张量b插入新的维度
B = torch.unsqueeze(b, 0)

print(B.size())
# torch.Size([1, 3])

print(B)
# tensor([[1, 2, 3]])

在批量维度上复制数据 1 份,实现如下:

代码语言:txt
复制
# 1意味着不对对应维度进行复制
B = B.repeat([2, 1])
print(B)
# tensor([[1, 2, 3],
#         [1, 2, 3]])

由于 repeat 函数也融合了插入批量维度并在新插入的批量维度上复制数据的操作,所以对于上面的偏置 b,我们可以省略 torch.unsqueeze(b, dim = 0) 插入批量维度的操作,直接使用 repeat 函数。

代码语言:txt
复制
import torch

# 创建偏置b
b = torch.tensor([1, 2, 3])
# 直接插入批量维度并复制2份
B = b.repeat([2, 1])

print(B.size())
# torch.Size([2, 3])

print(B)
# tensor([[1, 2, 3],
#         [1, 2, 3]])

「使用 repeat 函数对非单维度进行复制,简单来说就是对非单维度的所有元素整体进行复制。」 以下面形状为 (2, 2) 的 2D 张量为例。

  • Step1: 将 dim = 0 维度上的数据复制 1 份,dim = 1 维度上的数据保持不变。
  • Step2: Step1 得到的形状为 (4, 2) 的 2D 张量的 dim = 0 维度上的数据保持不变,dim = 1 维度上的数据复制 1 份。

上面操作使用 repeat 函数的具体实现如下。

代码语言:txt
复制
import torch

a = torch.arange(4).reshape([2, 2])
print(a)
# tensor([[0, 1],
#         [2, 3]]) 


# dim=0维度的数据复制1份,dim=1维度的数据保持不变
step1_a = a.repeat([2, 1])
print(step1_a)
# tensor([[0, 1],
#         [2, 3],
#         [0, 1],
#         [2, 3]])


# 将dim=0维度的数据保持不变,dim=1维度的数据复制1份
step2_a = step1_a.repeat([1, 2])
print(step2_a)
# tensor([[0, 1, 0, 1],
#         [2, 3, 2, 3],
#         [0, 1, 0, 1],
#         [2, 3, 2, 3]])
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-01-18,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • repeat
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档