Tensor.scatter_
(dim, index, src, reduce=None) → Tensor Parameters
src
. When empty, the operation returns self
unchanged.'add'
or 'multiply'
. 用的相对较少。直接看例子,
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
# 从这个例子出发来简单说明:首先dim=0,意味着需要沿着axis=0的方向进行操作,即index每一列逐渐增大,按列找到对应的索引号,然后按顺序把src中的元素填进去。
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
[6, 7, 0, 0, 8],
[0, 0, 0, 0, 0]])
# dim=1, 按行找到对应的index,按顺序把src中的元素填进去
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),1.23)
tensor([[2.0000, 2.0000, 1.2300, 2.0000],
[2.0000, 2.0000, 2.0000, 1.2300]])
# dim=1, 按行找到对应的index,按顺序把src中的元素填进去,不用管原来的位置是什么数字。
**注意:**index可以不用满,src按顺序填充。
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
... 1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.0000*1.23, 2.0000],
[2.0000, 2.0000, 2.0000, 2.000*1.23]])
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
[2.0000, 2.0000, 2.0000, 2.4600]])
# dim=1, 按行找到对应的index,按顺序把src中的元素乘上去
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
... 1.23, reduce='add')
tensor([[2.0000, 2.0000, 2.0000+1.23, 2.0000],
[2.0000, 2.0000, 2.0000, 2.000+1.23]])
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
[2.0000, 2.0000, 2.0000, 3.2300]])
# dim=1, 按行找到对应的index,按顺序把src中的元素加上去上去
tps://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html)
扫码关注腾讯云开发者
领取腾讯云代金券
Copyright © 2013 - 2025 Tencent Cloud. All Rights Reserved. 腾讯云 版权所有
深圳市腾讯计算机系统有限公司 ICP备案/许可证号:粤B2-20090059 深公网安备号 44030502008569
腾讯云计算(北京)有限责任公司 京ICP证150476号 | 京ICP备11018762号 | 京公网安备号11010802020287
Copyright © 2013 - 2025 Tencent Cloud.
All Rights Reserved. 腾讯云 版权所有