首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >PyTorch入门笔记-堆叠stack函数

PyTorch入门笔记-堆叠stack函数

作者头像
触摸壹缕阳光
发布于 2021-02-26 08:24:22
发布于 2021-02-26 08:24:22
6.8K00
代码可运行
举报
运行总次数:0
代码可运行

堆叠

torch.cat(tensors, dim = 0) 函数拼接操作是在现有维度上合并数据,并不会创建新的维度。如果在合并数据时,希望创建一个新的维度,则需要使用 torch.stack 操作。

torch.stack(tensors, dim = 0) 函数可以使用堆叠的方式合并多个张量,参数 tensors 保存了所有需要合并张量的序列(任何Python的序列对象,比如列表、元组等),参数 dim 指定新维度插入的位置,torch.stack 函数中的 dim 参数与 torch.unsqueeze 函数(增加长度为 1 的新维度)中的 dim 参数用法一致:

  • dim ≥ 0 时,在 dim 之前插入新维度;
  • dim < 0 时,在 dim 之后插入新维度;

例如,对于形状为

[b, c, h, w]

的张量,在不同位置通过 torch.stack 操作插入新维度,dim 参数对应的插入位置设置如下图所示。

比如张量

A

是形状为

[3, 32, 32]

的 3 通道图片张量,张量

B

是另外一个形状为

[3, 32, 32]

的 3 通道图片张量。使用 torch.stack 合并这两个图片张量,批量维度插入在 dim = 0 的位置上,具体代码如下。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch

# 模拟图像张量A
a = torch.randn(3, 32, 32)
# 模拟图像张量B
b = torch.randn(3, 32, 32)

# 堆叠合并为2个图片张量,批量的维度插在最前面
stack_ab = torch.stack([a, b], dim = 0)
print(stack_ab.size())
# torch.Size([2, 3, 32, 32])

同样可以在其它位置上插入新的维度,比如,最末尾插入批量维度。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch

# 模拟图像张量A
a = torch.randn(3, 32, 32)
# 模拟图像张量B
b = torch.randn(3, 32, 32)

# 堆叠合并为2个图片张量,批量的维度插在最末尾
stack_ab = torch.stack([a, b], dim = -1)
print(stack_ab.size())
# torch.Size([3, 32, 32, 2])

torch.cat(tensors, dim = 0) 函数有两个约束:

  • 参数 tensors 中所有需要合并的张量必须是相同的数据类型;
  • 非合并维度的长度必须一致

显然 torch.cat 函数也能够拼接合并

A

B

两个图片张量。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch

# 模拟图像张量A
a = torch.randn(3, 32, 32)
# 模拟图像张量B
b = torch.randn(3, 32, 32)

# 拼接合并为2个图片张量,没有批量维度的概念
cat_ab = torch.cat([a, b], dim = 0)
print(cat_ab.size())
# torch.Size([6, 32, 32])

形状都是

[3, 32, 32]

A

B

两个图片张量,沿着第 0 个维度进行合并(通道维度)后的张量形状为

[6, 32, 32]

。虽然 torch.cat 函数能够顺利的拼接合并,但是在理解时,需要按照前 3 个通道来自第一张图片,后 3 个通道来自第二张图片的方式理解张量。对于这个例子,明显通过 torch.stack 的方式创建新维度的方式更为合理,得到的形状为

[2, 3, 32, 32]

的张量也更容易理解。

torch.stack(tensors, dim = 0) 使用个 torch.cat 函数一样同样需要一些约束,这也是在使用 torch.stack(tensors, dim = 0) 函数时需要注意的地方。

  • 参数 tensors 中所有需要合并的张量必须是相同的数据类型
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch

# 模拟图像张量A
a = torch.randint(0, 255, (3, 32, 32))
# 模拟图像张量B
b = torch.randn(3, 32, 32)

print(a.dtype)
# torch.int64

print(b.dtype)
# torch.float32

# 非法堆叠,张量AB的数据类型不相同
stack_ab = torch.stack([a, b], dim = 0)
print(stack_ab.size())
# Traceback (most recent call last):
#   File "/home/chenkc/code/pytorch/test01.py", line 12, in <module>
#     stack_ab = torch.stack([a, b], dim = 0)
# RuntimeError: Expected object of scalar type long int but got scalar type float for sequence element 1.
  • 所有待合并的张量形状必须完全一致

torch.stack 也需要满足张量堆叠合并的条件,它需要所有待合并的张量形状完全一致才可以进行合并。如果待合并张量的形状不一致时,进行堆叠合并会发生错误。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
import torch

# 模拟图像张量A
a = torch.randn(3, 32, 32)
# 模拟图像张量B
b = torch.randn(1, 32, 32)

# 非法堆叠操作,张量的形状不相同
stack_ab = torch.stack([a, b], dim = 0)
print(stack_ab.size())
# Traceback (most recent call last):
#   File "/home/chenkc/code/pytorch/test01.py", line 9, in <module>
#     stack_ab = torch.stack([a, b], dim = 0)
# RuntimeError: stack expects each tensor to be equal size, but got [3, 32, 32] at entry 0 and [1, 32, 32] at entry 1

References:

  1. 《TensorFlow深度学习
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2021-02-02,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI机器学习与深度学习算法 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验