torch.stcak()函数对多个张量在维度上进行叠加。
其中参数dim代表不同的维度。
具体如下代码所示:
>>> a = torch.ones(3,3)
>>> a
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])
>>> b = torch.ones(3,3) + 1
>>> b
tensor([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])
>>> c = torch.ones(3,3) + 2
>>> c
tensor([[3., 3., 3.],
[3., 3., 3.],
[3., 3., 3.]])
# 当dim=0时,不同的张量直接叠加
>>> d = torch.stack((a,b,c),0)
>>> d
tensor([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]],
[[3., 3., 3.],
[3., 3., 3.],
[3., 3., 3.]]])
#当dim=1时,不同的张量在第一维度组合,并叠加
>>> d = torch.stack((a,b,c),1)
>>> d
tensor([[[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.]],
[[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.]],
[[1., 1., 1.],
[2., 2., 2.],
[3., 3., 3.]]])
#当dim=2时,不同的张量在第二维度组合,并叠加
>>> d = torch.stack((a,b,c),2)
>>> d
tensor([[[1., 2., 3.],
[1., 2., 3.],
[1., 2., 3.]],
[[1., 2., 3.],
[1., 2., 3.],
[1., 2., 3.]],
[[1., 2., 3.],
[1., 2., 3.],
[1., 2., 3.]]])
# 当dim=-1时,就是在最后一个维度组合,并叠加
>>> d = torch.stack((a,b,c),-1)
>>> d
tensor([[[1., 2., 3.],
[1., 2., 3.],
[1., 2., 3.]],
[[1., 2., 3.],
[1., 2., 3.],
[1., 2., 3.]],
[[1., 2., 3.],
[1., 2., 3.],
[1., 2., 3.]]])
torch.cat()函数对多个张量进行某一维度的拼接,拼接后的总维度数不变。
其中参数dim代表了不同的维度。
解析来我们从代码中进行分析:
# 当dim=0时,从第一维度进行拼接
>>> d = torch.cat((a,b,c),0)
>>> d
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.],
[3., 3., 3.],
[3., 3., 3.],
[3., 3., 3.]])
# 当dim=1时,从第二维度进行拼接
>>> d = torch.cat((a,b,c),1)
>>> d
tensor([[1., 1., 1., 2., 2., 2., 3., 3., 3.],
[1., 1., 1., 2., 2., 2., 3., 3., 3.],
[1., 1., 1., 2., 2., 2., 3., 3., 3.]])
# 当dim=-1时,从对后一个维度进行拼接
>>> d = torch.cat((a,b,c),-1)
>>> d
tensor([[1., 1., 1., 2., 2., 2., 3., 3., 3.],
[1., 1., 1., 2., 2., 2., 3., 3., 3.],
[1., 1., 1., 2., 2., 2., 3., 3., 3.]])
注意:torch.cat()函数存在特例。若用torch.unsqueeze()函数对上述的a,b张量进行升维,在用torch.cat()函数可进行通道数叠加操作。
如下面的代码所示:
>>> a = a.unsqueeze(0)
>>> a
tensor([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]])
>>> a.size()
torch.Size([1, 3, 3])
>>> b = b.unsqueeze(0)
>>> b
tensor([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]])
>>> b.size()
torch.Size([1, 3, 3])
# 在没有指定dim时默认为通道数叠加
>>> c = torch.cat((a,b))
>>> c
tensor([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]])
# dim=0时是第一维拼接,即通道数叠加
>>> c = torch.cat((a,b),0)
>>> c
tensor([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]])
>>> c.size()
torch.Size([2, 3, 3])
# dim=1时第二维拼接
>>> c = torch.cat((a,b),1)
>>> c
tensor([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]])
>>> c.size()
torch.Size([1, 6, 3])