首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >Pytorch:如何找到2D张量每行中第一个非零元素的索引?

Pytorch:如何找到2D张量每行中第一个非零元素的索引?
EN

Stack Overflow用户
提问于 2019-05-11 15:31:36
回答 3查看 3.2K关注 0票数 7

我有一个2D张量,每一行都有一些非零元素,如下所示:

代码语言:javascript
运行
复制
import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)

我想要一个包含每行第一个非零元素索引的张量:

代码语言:javascript
运行
复制
indices = tensor([2],
                 [3])

我如何在Pytorch中计算它?

EN

回答 3

Stack Overflow用户

回答已采纳

发布于 2019-05-12 22:00:35

我可以为我的问题找到一个棘手的答案:

代码语言:javascript
运行
复制
  tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                     [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
  idx = reversed(torch.Tensor(range(1,8)))
  print(idx)

  tmp2= torch.einsum("ab,b->ab", (tmp, idx))

  print(tmp2)

  indices = torch.argmax(tmp2, 1, keepdim=True)
  print(indeces)

结果是:

代码语言:javascript
运行
复制
tensor([7., 6., 5., 4., 3., 2., 1.])
tensor([[0., 0., 5., 0., 3., 0., 0.],
       [0., 0., 0., 4., 3., 0., 0.]])
tensor([[2],
        [3]])
票数 5
EN

Stack Overflow用户

发布于 2020-02-13 15:46:35

我简化了Iman的方法来做以下事情:

代码语言:javascript
运行
复制
idx = torch.arange(tmp.shape[1], 0, -1)
tmp2= tmp * idx
indices = torch.argmax(tmp2, 1, keepdim=True)
票数 11
EN

Stack Overflow用户

发布于 2021-11-15 07:59:44

所有非零值都相等,因此argmax返回第一个索引。

代码语言:javascript
运行
复制
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]])
indices = tmp.argmax(1)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56088189

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档