博主的研究方向是目标检测,深度学习框架使用Pytorch,在日常的使用过程中经常会碰到一些问题,因此整理一下pytorch的一些常用接口和使用技巧。
tensor.reshape( ) tensor.view( )
两个方法都是改变张量的形状,区别在于,view只能处理连续存储的张量,reshape可以处理任何张量。如果张量本身是连续存储的,这两个方法便没有区别。
tensor.transpose( ) tensor.permute( )
transpose一次只能进行两个维度的交换,permute一次可以进行多个维度的交换。
torch.cat( ) torch.stack( )
cat方法在拼接的时候,维度会保持不变,按指定的维度进行拼接。stack方法在拼接的时候,会增加一个维度,同时可以按照指定维度进行拼接。
import torch
>> x = torch.arange(12).reshape(2, 6)
"""============ dim=0 =========="""
>> torch.stack([x, x], dim=0)
tensor([[[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11]],
[[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11]]])
>> torch.cat([x, x], dim=0)
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11]])
"""============ dim=1 =========="""
>> torch.stack([x, x], dim=1)
tensor([[[ 0, 1, 2, 3, 4, 5],
[ 0, 1, 2, 3, 4, 5]],
[[ 6, 7, 8, 9, 10, 11],
[ 6, 7, 8, 9, 10, 11]]])
>> torch.cat([x, x], dim=1)
tensor([[ 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11]])
"""============ dim=2 =========="""
>> torch.stack([x, x], dim=2)
tensor([[[ 0, 0],
[ 1, 1],
[ 2, 2],
[ 3, 3],
[ 4, 4],
[ 5, 5]],
[[ 6, 6],
[ 7, 7],
[ 8, 8],
[ 9, 9],
[10, 10],
[11, 11]]])
>> torch.cat([x, x], dim=2)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
tensor.topk(k, dim, largest, sorted, out)
在维度dim=-1上,按照largest=True 方式进行排序,然后获取前k个排序的索引,如果sorted=True,则返回的是排序后的值。out为输出的张量。
>> x = torch.randperm(20).reshape(2, -1)
tensor([[17, 5, 19, 6, 10, 12, 7, 18, 3, 2],
[ 8, 1, 9, 14, 16, 4, 11, 13, 15, 0]])
>> x.topk(k=3)
torch.return_types.topk(
values=tensor([[19, 18, 17], [16, 15, 14]]),
indices=tensor([[2, 7, 0], [4, 8, 3]]))
>> x.topk(k=3, sorted=False)
torch.return_types.topk(
values=tensor([[19, 18, 17], [15, 16, 14]]),
indices=tensor([[2, 7, 0], [8, 4, 3]]))
torch.randperm(n, out)
返回一个数值范围从0到n-1的随机整数排列,长度为n。
>> torch.randperm(10)
tensor([8, 0, 4, 2, 5, 9, 1, 6, 7, 3])
假设,现在有一个形状为(6, 4, 5, 7, 10)的张量,现在要根据一个随机生成的形状为(6, 4, 5, 7, 3)的索引张量,将最后一个维度的一部分数据提取出来:
>> x = torch.rand(6, 4, 5, 7, 10)
>> keep = torch.round(torch.rand((6, 4, 5, 7, 3)) * 10)
>> keep = keep.to(torch.long)
>> batch_number = torch.arange(6).reshape(6, 1, 1, 1, 1)
>> channel_number = torch.arange(4).reshape(1, 4, 1, 1, 1)
>> seg_number = torch.arange(5).reshape(1, 1, 5, 1, 1)
>> word_number = torch.arange(7).reshape(1, 1, 1, 7, 1)
>> x[batch_number, channel_number, seg_number, word_number, keep]
在计算过程中使用广播机制,需要保持两个操作数张量的维数一致,且其中一个维度为1,其他维度和另一个张量保持一致。
>> x = torch.rand(2, 2, 2)
tensor([[[0.0043, 0.9672],
[0.0853, 0.2951]],
[[0.8648, 0.1455],
[0.1597, 0.9439]]])
>> w1 = torch.rand(2, 2, 1)
>> x * w1
tensor([[[0.0021, 0.4792],
[0.0703, 0.2431]],
[[0.4721, 0.0794],
[0.1444, 0.8538]]])
>> w2 = torch.rand(2, 1, 2)
>> x * w2
tensor([[[0.0033, 0.9109],
[0.0654, 0.2779]],
[[0.8263, 0.1080],
[0.1526, 0.7007]]])
>> w3 = torch.rand(1, 2, 2)
>> x * w3
tensor([[[0.0022, 0.7960],
[0.0775, 0.1908]],
[[0.4384, 0.1198],
[0.1450, 0.6102]]])
>> w4 = torch.rand(2, 2, 4)
>> x * w4
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 2
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。