【Pytorch】简析数据批量化处理类Dataset和DataLoader

Rhoda ·
更新时间:2024-11-10
· 530 次阅读

在深度学习中,在将原始数据进行清理、规范化和编码后,就需要将数据进行序列化和批量化,而Pytorch提供这两项功能的类分别为DatasetDataLoader

1. Dataset类

Dataset类是将数据进行序列化封装的类,我们在为每个具体问题定制合适的Dataset子类时,仅需要继承该父类,同时覆写__init____getitem____len__三个魔鬼方法即可:

__init__:类的初始化,在用来用于设置文件路径、导入文件和定义必要的变量。 __getitem__:提供一个切片的方法,可以根据输入的index,获取对应的一个数据。 __len__:用于统计数据样本的总量。 2. DataLoader类

在用Dataset类对数据封装完后进行训练和测试时,还需要对数据进行批量化处理,以供每个min-batch的数据。该类一般无需改写,直接加载对应的Dataset类,并设置相应的参数即可生成一个包含min-batch数据的可迭代对象。

MyDataLoader = DataLoader(dataset=MyDataset, batch_size=512, shuffle=True, num_workers=4)

如上面的例子,DataLoader的四个主要参数定义了数据批量化的主要属性,具体包括:

dataset: Dataset子类,即序列化好的数据。 batch_size: min-batch的尺寸。 shuffle: 在每个epoch取样前,是否先打乱数据顺序。 num_workers:所用的子进程数,默认为0,即仅用主进程。

除此之外,还有两个参数可能会用到:

sampler: Sample子类,定义了数据进行采样的方式。之前的shuffle=True其实也提供了一种采样方法,所以当设置sampler参数时,必须设置shuffle=False collate_fn: 用于对Dataset中采样得到的每个mini-batch数据进行后处理,从而提供更好的模型输入数据,其取值为一个外部定义的可调用函数。也就是说,设置该值后,真正迭代输出的值是经过该函数处理后的返回值。该函数的具体使用可参照博文 3. 简单示例 import torch from torch.utils.data import Dataset, DataLoader A = torch.randn(128, 3) C = torch.randn(128, 1) # 1. 用Dataset封装数据集,仅做示范,实际可直接用TensorDataset封装 class MyDataset(Dataset): def __init__(self, x, y): assert x.size(0)==y.size(0) self.x, self.y = x, y def __getitem__(self, idx): return (self.x[idx], self.y[idx]) def __len__(self): return self.x.size(0) # 2. 用DataLoader定义数据批量迭代器 MyDataLoader = DataLoader(dataset=dataset, shuffle=True, batch_size=4) for data_iter in MyDataLoader: # 进行训练或预测
作者:guofei_fly



pytorch Dataset 数据

需要 登录 后方可回复, 如果你还没有账号请 注册新账号