深度学习框架_PyTorch_torch.stack()函数和torch.cat()函数

Faith ·
更新时间:2024-09-20
· 769 次阅读

torch.stcak()函数对多个张量在维度上进行叠加。
其中参数dim代表不同的维度。
具体如下代码所示:

>>> a = torch.ones(3,3) >>> a tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]) >>> b = torch.ones(3,3) + 1 >>> b tensor([[2., 2., 2.], [2., 2., 2.], [2., 2., 2.]]) >>> c = torch.ones(3,3) + 2 >>> c tensor([[3., 3., 3.], [3., 3., 3.], [3., 3., 3.]]) # 当dim=0时,不同的张量直接叠加 >>> d = torch.stack((a,b,c),0) >>> d tensor([[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]], [[2., 2., 2.], [2., 2., 2.], [2., 2., 2.]], [[3., 3., 3.], [3., 3., 3.], [3., 3., 3.]]]) #当dim=1时,不同的张量在第一维度组合,并叠加 >>> d = torch.stack((a,b,c),1) >>> d tensor([[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]], [[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]], [[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]]) #当dim=2时,不同的张量在第二维度组合,并叠加 >>> d = torch.stack((a,b,c),2) >>> d tensor([[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]], [[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]], [[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]]) # 当dim=-1时,就是在最后一个维度组合,并叠加 >>> d = torch.stack((a,b,c),-1) >>> d tensor([[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]], [[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]], [[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]])

torch.cat()函数对多个张量进行某一维度的拼接,拼接后的总维度数不变。
其中参数dim代表了不同的维度。

解析来我们从代码中进行分析:

# 当dim=0时,从第一维度进行拼接 >>> d = torch.cat((a,b,c),0) >>> d tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.], [2., 2., 2.], [2., 2., 2.], [2., 2., 2.], [3., 3., 3.], [3., 3., 3.], [3., 3., 3.]]) # 当dim=1时,从第二维度进行拼接 >>> d = torch.cat((a,b,c),1) >>> d tensor([[1., 1., 1., 2., 2., 2., 3., 3., 3.], [1., 1., 1., 2., 2., 2., 3., 3., 3.], [1., 1., 1., 2., 2., 2., 3., 3., 3.]]) # 当dim=-1时,从对后一个维度进行拼接 >>> d = torch.cat((a,b,c),-1) >>> d tensor([[1., 1., 1., 2., 2., 2., 3., 3., 3.], [1., 1., 1., 2., 2., 2., 3., 3., 3.], [1., 1., 1., 2., 2., 2., 3., 3., 3.]])

注意:torch.cat()函数存在特例。若用torch.unsqueeze()函数对上述的a,b张量进行升维,在用torch.cat()函数可进行通道数叠加操作。

如下面的代码所示:

>>> a = a.unsqueeze(0) >>> a tensor([[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]]) >>> a.size() torch.Size([1, 3, 3]) >>> b = b.unsqueeze(0) >>> b tensor([[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]]) >>> b.size() torch.Size([1, 3, 3]) # 在没有指定dim时默认为通道数叠加 >>> c = torch.cat((a,b)) >>> c tensor([[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]]) # dim=0时是第一维拼接,即通道数叠加 >>> c = torch.cat((a,b),0) >>> c tensor([[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]], [[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]]) >>> c.size() torch.Size([2, 3, 3]) # dim=1时第二维拼接 >>> c = torch.cat((a,b),1) >>> c tensor([[[1., 1., 1.], [1., 1., 1.], [1., 1., 1.], [1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]]) >>> c.size() torch.Size([1, 6, 3])
作者:CV-GANRocky



stack 深度学习框架 pytorch 学习 函数 深度学习 cat 框架

需要 登录 后方可回复, 如果你还没有账号请 注册新账号