前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >pytorch常用数据操作函数

pytorch常用数据操作函数

原创
作者头像
苏十四
修改2021-02-20 18:08:13
8020
修改2021-02-20 18:08:13
举报
文章被收录于专栏:攀攀的专栏

博主的研究方向是目标检测,深度学习框架使用Pytorch,在日常的使用过程中经常会碰到一些问题,因此整理一下pytorch的一些常用接口和使用技巧。

1、形状变换:

tensor.reshape( ) tensor.view( )

两个方法都是改变张量的形状,区别在于,view只能处理连续存储的张量,reshape可以处理任何张量。如果张量本身是连续存储的,这两个方法便没有区别。

2、维度变换:

tensor.transpose( ) tensor.permute( )

transpose一次只能进行两个维度的交换,permute一次可以进行多个维度的交换。

3、张量拼接:

torch.cat( ) torch.stack( )

cat方法在拼接的时候,维度会保持不变,按指定的维度进行拼接。stack方法在拼接的时候,会增加一个维度,同时可以按照指定维度进行拼接。

代码语言:python
代码运行次数:0
复制
import torch

>> x = torch.arange(12).reshape(2, 6)
"""============ dim=0 =========="""
>> torch.stack([x, x], dim=0)
tensor([[[ 0,  1,  2,  3,  4,  5],
                  [ 6,  7,  8,  9, 10, 11]],
                [[ 0,  1,  2,  3,  4,  5],
                  [ 6,  7,  8,  9, 10, 11]]])

>> torch.cat([x, x], dim=0)
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11],
        [ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])

"""============ dim=1 =========="""
>> torch.stack([x, x], dim=1)
tensor([[[ 0,  1,  2,  3,  4,  5],
                  [ 0,  1,  2,  3,  4,  5]],
                [[ 6,  7,  8,  9, 10, 11],
                  [ 6,  7,  8,  9, 10, 11]]])

>> torch.cat([x, x], dim=1)
tensor([[ 0,  1,  2,  3,  4,  5,  0,  1,  2,  3,  4,  5],
                [ 6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11]])

"""============ dim=2 =========="""
>> torch.stack([x, x], dim=2)
tensor([[[ 0,  0],
                  [ 1,  1],
                  [ 2,  2],
                  [ 3,  3],
                  [ 4,  4],
                  [ 5,  5]],
                [[ 6,  6],
                  [ 7,  7],
                  [ 8,  8],
                  [ 9,  9],
                  [10, 10],
                  [11, 11]]])
>> torch.cat([x, x], dim=2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

4、排序获取topk:

tensor.topk(k, dim, largest, sorted, out)

在维度dim=-1上,按照largest=True 方式进行排序,然后获取前k个排序的索引,如果sorted=True,则返回的是排序后的值。out为输出的张量。

代码语言:python
代码运行次数:0
复制
>> x = torch.randperm(20).reshape(2, -1)
tensor([[17,  5, 19,  6, 10, 12,  7, 18,  3,  2],
        [ 8,  1,  9, 14, 16,  4, 11, 13, 15,  0]])
>> x.topk(k=3)
torch.return_types.topk(
values=tensor([[19, 18, 17], [16, 15, 14]]),
indices=tensor([[2, 7, 0], [4, 8, 3]]))
 >> x.topk(k=3, sorted=False)
 torch.return_types.topk(
values=tensor([[19, 18, 17], [15, 16, 14]]),
indices=tensor([[2, 7, 0], [8, 4, 3]]))

5、获取随机整数排列:

torch.randperm(n, out)

返回一个数值范围从0到n-1的随机整数排列,长度为n。

代码语言:python
代码运行次数:0
复制
>> torch.randperm(10)
tensor([8, 0, 4, 2, 5, 9, 1, 6, 7, 3])

6、张量的切片方式*:

假设,现在有一个形状为(6, 4, 5, 7, 10)的张量,现在要根据一个随机生成的形状为(6, 4, 5, 7, 3)的索引张量,将最后一个维度的一部分数据提取出来:

代码语言:python
代码运行次数:0
复制
>> x = torch.rand(6, 4, 5, 7, 10)
>> keep = torch.round(torch.rand((6, 4, 5, 7, 3)) * 10)
>> keep = keep.to(torch.long)
>> batch_number = torch.arange(6).reshape(6, 1, 1, 1, 1)
>> channel_number = torch.arange(4).reshape(1, 4, 1, 1, 1)
>> seg_number = torch.arange(5).reshape(1, 1, 5, 1, 1)
>> word_number = torch.arange(7).reshape(1, 1, 1, 7, 1)
>> x[batch_number, channel_number, seg_number, word_number, keep]

7、广播机制:

在计算过程中使用广播机制,需要保持两个操作数张量的维数一致,且其中一个维度为1,其他维度和另一个张量保持一致。

代码语言:javascript
复制
>> x = torch.rand(2, 2, 2)
tensor([[[0.0043, 0.9672],
         [0.0853, 0.2951]],
        [[0.8648, 0.1455],
         [0.1597, 0.9439]]])
>> w1 = torch.rand(2, 2, 1)
>> x * w1 
tensor([[[0.0021, 0.4792],
         [0.0703, 0.2431]],
        [[0.4721, 0.0794],
         [0.1444, 0.8538]]])
>> w2 = torch.rand(2, 1, 2)
>> x * w2 
tensor([[[0.0033, 0.9109],
         [0.0654, 0.2779]],
        [[0.8263, 0.1080],
         [0.1526, 0.7007]]])
>> w3 = torch.rand(1, 2, 2)
>> x * w3
tensor([[[0.0022, 0.7960],
         [0.0775, 0.1908]],
        [[0.4384, 0.1198],
         [0.1450, 0.6102]]])
 >> w4 = torch.rand(2, 2, 4)
>> x * w4
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 2

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1、形状变换:
  • 2、维度变换:
  • 3、张量拼接:
  • 4、排序获取topk:
  • 5、获取随机整数排列:
  • 6、张量的切片方式*:
  • 7、广播机制:
相关产品与服务
对象存储
对象存储(Cloud Object Storage,COS)是由腾讯云推出的无目录层次结构、无数据格式限制,可容纳海量数据且支持 HTTP/HTTPS 协议访问的分布式存储服务。腾讯云 COS 的存储桶空间无容量上限,无需分区管理,适用于 CDN 数据分发、数据万象处理或大数据计算与分析的数据湖等多种场景。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档