看完秒懂torch.stack()

Ginger ·
更新时间:2024-09-20
· 554 次阅读

torch.stack ()在这里插入图片描述一、准备数据二、dim=0三、dim=1四、dim=2 在这里插入图片描述 一、准备数据

首先把基本的数据准备好:

import torch import numpy as np # 创建3*3的矩阵,a、b a=np.array([[1,2,3],[4,5,6],[7,8,9]]) b=np.array([[10,20,30],[40,50,60],[70,80,90]]) # 将矩阵转化为Tensor a = torch.from_numpy(a) b = torch.from_numpy(b) # 打印a、b、c print(a) print(b) output: tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.int32) tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]], dtype=torch.int32) 二、dim=0

首先,一起来看看dim=0的时候,结果会是怎么样

d = torch.stack((a, b), dim=0) print(d) print(d.size()) tensor([[[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9]], [[10, 20, 30], [40, 50, 60], [70, 80, 90]]], dtype=torch.int32) torch.Size([2, 3, 3])

观察结果,可以得出结论:

dim = 0,原来的每一个矩阵也变成了一个维度 一个矩阵看做一个整体,有几个矩阵,新的维度就是几,第几个矩阵就是第几维;

如下,取出第1维度的矩阵(下标从0开始):

d[0] tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.int32)

可以很清楚的看到这就是stack前的第一个矩阵。

三、dim=1

那么,dim=1的时候,结果会是怎么样

d = torch.stack((a, b), dim=1) print(d) print(d.size()) tensor([[[ 1, 2, 3], [10, 20, 30]], [[ 4, 5, 6], [40, 50, 60]], [[ 7, 8, 9], [70, 80, 90]]], dtype=torch.int32) torch.Size([3, 2, 3])

为了观察方便,把原始数据也再拿过来对照。

# 原始数据 tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.int32) tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]], dtype=torch.int32)

可以得出结论:

将每个矩阵的第一行组成第一维矩阵,依次下去,每个矩阵的第n行组成第n维矩阵。size=(n,i,y)

四、dim=2

最后,dim=1的时候,结果会是怎么样

d = torch.stack((a, b), dim=2) print(d) print(d.size()) tensor([[[ 1, 10], [ 2, 20], [ 3, 30]], [[ 4, 40], [ 5, 50], [ 6, 60]], [[ 7, 70], [ 8, 80], [ 9, 90]]], dtype=torch.int32) torch.Size([3, 3, 2]) # 原始数据 tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.int32) tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]], dtype=torch.int32)

凭感觉画了一张图:
在这里插入图片描述
感觉像羊肉串哈哈哈哈。看图,看图
左上角第一行第一根橙色的线串起来的数字代表第一维的第一行也就是[1, 10]
横向第一行第二根橙色的线串起来的数字代表第一维的第二行也就是[2, 20]

如果有更多,如下:
在这里插入图片描述

依次过去,第一行第n根橙色的线组成第一维度的第n行(这)
以此类推,再回头看看数据集的结果,是不是很清晰啦。

最后,还是老样子,鸡汤走起。
不要假装自己很努力,结果不会陪你演戏。


作者:不堪沉沦



stack torch

需要 登录 后方可回复, 如果你还没有账号请 注册新账号