one hot编码:`torch.Tensor.scatter_()`函数用法详解

Angie ·
更新时间:2024-09-20
· 983 次阅读

torch.Tensor.scatter_()torch.gather()函数的方向反向操作。两个函数可以看成一对兄弟函数。gather用来解码one hot,scatter_用来编码one hot。

scatter_(dim, index, src) → Tensor

dim (python:int) – 用来寻址的坐标轴 index (LongTensor) – 索引 src(Tensor) –用来scatter的源张量,以防value未被指定。 value(python:float) – 用来scatter的源张量,以防src未被指定。

现在我们来看看具体这么用,看下面这个例子就一目了然了。

dim =0 import torch x = torch.tensor([[0.9413, 0.9476, 0.1104, 0.9898, 0.6443], [0.6913, 0.8924, 0.7530, 0.8874, 0.0557]]) result = torch.zeros(3, 5) indices = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) result.scatter_(0, indices, x)

输出为

tensor([[0.9413, 0.8924, 0.7530, 0.9898, 0.6443], [0.0000, 0.9476, 0.0000, 0.8874, 0.0000], [0.6913, 0.0000, 0.1104, 0.0000, 0.0557]]) dim = 0的情形:

比如上例中,dim=0,所以根据这个规则来self[index[i][j]][j] = src[i][j]来确定替换规则。

index中的值决定了src中的值在result中的放置位置。

dim=0时,则将列固定起来,先看第0列:

对于第0行,首先找到x的第0列第0行的值为0.9413,然后在用index[0][0]的值来找将要在result中放置的位置。

在这个例子中,index[0][0]=0, 所以0.9413将放置在result[0][0]这个位置。

对于result中的各项,他们的寻址过程如下:

x[0][1] = 0.9476 -> indices[0][1]=1 -> result[ index = 1 ][1] = 0.9476

x[1][3] = 0.8874 -> indices[1][3]=1 -> result[ index = 1 ][3] = 0.8874

依此类推。

以下为dim = 1的情形:

x[0][0] = 0.9413 -> indices[0][0]=0 -> result[0][index = 0] = 0.9413

x[0][3] = 0.9898 -> indices[0][3]=0 -> result[0][index = 0] = 0.9898 ## 将上一步的值覆盖了

x[0][4] = 0.6443 -> indices[0][4]=0 -> result[0][index = 0] = 0.6443 ## 再次将上一步的值覆盖了

因此result[0][0]的值为0.6443.

dim = 1

x = torch.tensor([[0.9413, 0.9476, 0.1104, 0.9898, 0.6443], [0.6913, 0.8924, 0.7530, 0.8874, 0.0557]]) result = torch.zeros(3, 5) indices = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) result.scatter_(1, indices, x)

输出为

tensor([[0.6443, 0.9476, 0.1104, 0.0000, 0.0000], [0.7530, 0.8874, 0.0557, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]) 用于产生one hot编码的向量

当没有src值时,则所有用于填充的值均为value值。

需要注意的时候,这个时候index.shape[dim]必须与result.shape[dim]相等,否则会报错。

result = torch.zeros(3, 5) indices = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 3, 1, 2], [2, 1, 3, 1, 4]]) result.scatter_(1, indices, value=1)

输出为

tensor([[1., 1., 1., 0., 0.], [1., 1., 1., 1., 0.], [0., 1., 1., 1., 1.]]) 例如 indices = [1,2,3,4,5],将他转换为one-hot的形式. indices = torch.tensor(list(range(5))).view(5,1) result = torch.zeros(5, 5) result.scatter_(1, indices, 1)

输出为

tensor([[1., 0., 0., 0., 0.], [0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.], [0., 0., 0., 1., 0.], [0., 0., 0., 0., 1.]])
作者:miguemath



函数 torch hot tensor

需要 登录 后方可回复, 如果你还没有账号请 注册新账号