拼接
在 PyTorch 中,可以通过 torch.cat(tensors, dim = 0) 函数拼接张量,其中参数 tensor 保存了所有需要合并张量的序列(任何Python的序列对象,比如列表、...元组等),dim 参数指定了需要合并的维度索引。...以包含批量维度的图像张量为例,设张量 A 保存了 4 张,长和宽为 32 的三通道像素矩阵,则张量 A 的形状为 [4, 3, 32, 32](PyTorch将通道维度放在前面,即 (batch_size...现在需要在批量维度上合并两个包含批量维度的图像张量,这里批量维度索引号为 0,即 dim = 0,合并张量 A 和 B 的代码如下:
import torch
# 模拟图像张量A
a = torch.randn...(4, 3, 32, 32)
# 模拟图像张量B
b = torch.randn(5, 3, 32, 32)
# 在批量维度上合并张量A和B
cat_ab = torch.cat([a, b], dim