Pytorch实现List Tensor转Tensor,reshape拼接等操作

Thadea ·
更新时间:2024-11-13
· 494 次阅读

目录

一、List Tensor转Tensor (torch.cat)

高维tensor

二、List Tensor转Tensor (torch.stack)

持续更新一些常用的Tensor操作,比如List,Numpy,Tensor之间的转换,Tensor的拼接,维度的变换等操作。

其它Tensor操作如 einsum等见:待更新。

用到两个函数:

torch.cat

torch.stack

一、List Tensor转Tensor (torch.cat)

// An highlighted block >>> t1 = torch.FloatTensor([[1,2],[5,6]]) >>> t2 = torch.FloatTensor([[3,4],[7,8]]) >>> l = [] >>> l.append(t1) >>> l.append(t2) >>> ta = torch.cat(l,dim=0) >>> ta = torch.cat(l,dim=0).reshape(2,2,2) >>> tb = torch.cat(l,dim=1).reshape(2,2,2) >>> ta tensor([[[1., 2.], [5., 6.]], [[3., 4.], [7., 8.]]]) >>> tb tensor([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]]) 高维tensor

** 如果理解了2D to 3DTensor,以此类推,不难理解3D to 4D,看下面代码即可明白:**

>>> t1 = torch.range(1,8).reshape(2,2,2) >>> t2 = torch.range(11,18).reshape(2,2,2) >>> l = [] >>> l.append(t1) >>> l.append(t2) >>> torch.cat(l,dim=2).reshape(2,2,2,2) tensor([[[[ 1., 2.], [11., 12.]], [[ 3., 4.], [13., 14.]]], [[[ 5., 6.], [15., 16.]], [[ 7., 8.], [17., 18.]]]]) >>> torch.cat(l,dim=1).reshape(2,2,2,2) tensor([[[[ 1., 2.], [ 3., 4.]], [[11., 12.], [13., 14.]]], [[[ 5., 6.], [ 7., 8.]], [[15., 16.], [17., 18.]]]]) >>> torch.cat(l,dim=0).reshape(2,2,2,2) tensor([[[[ 1., 2.], [ 3., 4.]], [[ 5., 6.], [ 7., 8.]]], [[[11., 12.], [13., 14.]], [[15., 16.], [17., 18.]]]]) 二、List Tensor转Tensor (torch.stack)

代码:

import torch t1 = torch.FloatTensor([[1,2],[5,6]]) t2 = torch.FloatTensor([[3,4],[7,8]]) l = [t1, t2] t3 = torch.stack(l, dim=2) print(t3.shape) print(t3) ## output: ## torch.Size([2, 2, 2]) ## tensor([[[1., 3.], ## [2., 4.]], ## [[5., 7.], ## [6., 8.]]])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持软件开发网。



pytorch reshape tensor list

需要 登录 后方可回复, 如果你还没有账号请 注册新账号