pytorch中函数tensor.numpy()的数据类型解析

Isoke ·
更新时间:2024-11-13
· 1366 次阅读

目录

函数tensor.numpy()的数据类型

tensor数据和numpy数据转换中注意的一个问题

函数tensor.numpy()的数据类型

今天写代码的时候,要统计一下标签数据里出现的类别总数和要分类的分类数是不是一致的。

我的做法是把tensor类型的数据转变成list,然后用Counter函数做统计。

代码如下:

from collections import Counter List_counter = Counter(List1) #List1就是待统计的数据,是一维的列表。生成的List_counter是一个字典,键是数据, #对应的值是数据出现的频率

在做这个统计的时候,突然发现,我的数据是float的类型,这是不应该出现的,因为标签数据在处理的时候都是整型数据。

经过一番查找后,发现是tensor.numpy()返回值数据类型的原因。这个函数的返回值是float类型的

tensor数据和numpy数据转换中注意的一个问题

在pytorch中,把numpy.array数据转换到张量tensor数据的常用函数是torch.from_numpy(array)或者torch.Tensor(array)

第一种函数更常用,然而在pytorch0.4中已经舍弃了这种函数

下面一个简单的编程实验说明这两种方法的区别

实验在pytorch0.4框架下进行

运行程序之后,结果是

可以看出修改数组a的元素值,张量b的元素值也改变了,但是张量c却不变。

修改张量c的元素值,数组a和张量b的元素值都不变。

这说明torch.from_numpy(array)是做数组的浅拷贝,torch.Tensor(array)是做数组的深拷贝

以上为个人经验,希望能给大家一个参考,也希望大家多多支持软件开发网。



pytorch NumPy 数据类型 数据 tensor

需要 登录 后方可回复, 如果你还没有账号请 注册新账号