torch.Tensor.scatter_()
是torch.gather()
函数的方向反向操作。两个函数可以看成一对兄弟函数。gather
用来解码one hot,scatter_
用来编码one hot。
scatter_(dim, index, src) → Tensor
dim (python:int)
– 用来寻址的坐标轴
index (LongTensor)
– 索引
src(Tensor)
–用来scatter的源张量,以防value未被指定。
value(python:float)
– 用来scatter的源张量,以防src未被指定。
现在我们来看看具体这么用,看下面这个例子就一目了然了。
dim =0import torch
x = torch.tensor([[0.9413, 0.9476, 0.1104, 0.9898, 0.6443],
[0.6913, 0.8924, 0.7530, 0.8874, 0.0557]])
result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0],
[2, 0, 0, 1, 2]])
result.scatter_(0, indices, x)
输出为
tensor([[0.9413, 0.8924, 0.7530, 0.9898, 0.6443],
[0.0000, 0.9476, 0.0000, 0.8874, 0.0000],
[0.6913, 0.0000, 0.1104, 0.0000, 0.0557]])
dim = 0的情形:
比如上例中,dim=0,所以根据这个规则来self[index[i][j]][j] = src[i][j]来确定替换规则。
index中的值决定了src中的值在result中的放置位置。
dim=0时,则将列固定起来,先看第0列:
对于第0行,首先找到x的第0列第0行的值为0.9413,然后在用index[0][0]的值来找将要在result中放置的位置。
在这个例子中,index[0][0]=0, 所以0.9413将放置在result[0][0]这个位置。
对于result中的各项,他们的寻址过程如下:
x[0][1] = 0.9476 -> indices[0][1]=1 -> result[ index = 1 ][1] = 0.9476
x[1][3] = 0.8874 -> indices[1][3]=1 -> result[ index = 1 ][3] = 0.8874
依此类推。
以下为dim = 1的情形:x[0][0] = 0.9413 -> indices[0][0]=0 -> result[0][index = 0] = 0.9413
x[0][3] = 0.9898 -> indices[0][3]=0 -> result[0][index = 0] = 0.9898 ## 将上一步的值覆盖了
x[0][4] = 0.6443 -> indices[0][4]=0 -> result[0][index = 0] = 0.6443 ## 再次将上一步的值覆盖了
因此result[0][0]的值为0.6443.
dim = 1
x = torch.tensor([[0.9413, 0.9476, 0.1104, 0.9898, 0.6443],
[0.6913, 0.8924, 0.7530, 0.8874, 0.0557]])
result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0],
[2, 0, 0, 1, 2]])
result.scatter_(1, indices, x)
输出为
tensor([[0.6443, 0.9476, 0.1104, 0.0000, 0.0000],
[0.7530, 0.8874, 0.0557, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
用于产生one hot编码的向量
当没有src
值时,则所有用于填充的值均为value
值。
需要注意的时候,这个时候index.shape[dim]
必须与result.shape[dim]
相等,否则会报错。
result = torch.zeros(3, 5)
indices = torch.tensor([[0, 1, 2, 0, 0],
[2, 0, 3, 1, 2],
[2, 1, 3, 1, 4]])
result.scatter_(1, indices, value=1)
输出为
tensor([[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[0., 1., 1., 1., 1.]])
例如 indices = [1,2,3,4,5],将他转换为one-hot的形式.
indices = torch.tensor(list(range(5))).view(5,1)
result = torch.zeros(5, 5)
result.scatter_(1, indices, 1)
输出为
tensor([[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.]])