在深度学习中,在将原始数据进行清理、规范化和编码后,就需要将数据进行序列化和批量化,而Pytorch提供这两项功能的类分别为Dataset
和DataLoader
。
Dataset类是将数据进行序列化封装的类,我们在为每个具体问题定制合适的Dataset子类时,仅需要继承该父类,同时覆写__init__
、__getitem__
和__len__
三个魔鬼方法即可:
__init__
:类的初始化,在用来用于设置文件路径、导入文件和定义必要的变量。
__getitem__
:提供一个切片的方法,可以根据输入的index,获取对应的一个数据。
__len__
:用于统计数据样本的总量。
2. DataLoader类
在用Dataset类对数据封装完后进行训练和测试时,还需要对数据进行批量化处理,以供每个min-batch的数据。该类一般无需改写,直接加载对应的Dataset类,并设置相应的参数即可生成一个包含min-batch数据的可迭代对象。
MyDataLoader = DataLoader(dataset=MyDataset, batch_size=512, shuffle=True, num_workers=4)
如上面的例子,DataLoader的四个主要参数定义了数据批量化的主要属性,具体包括:
dataset
: Dataset子类,即序列化好的数据。
batch_size
: min-batch的尺寸。
shuffle
: 在每个epoch取样前,是否先打乱数据顺序。
num_workers
:所用的子进程数,默认为0,即仅用主进程。
除此之外,还有两个参数可能会用到:
sampler
: Sample子类,定义了数据进行采样的方式。之前的shuffle=True
其实也提供了一种采样方法,所以当设置sampler
参数时,必须设置shuffle=False
collate_fn
: 用于对Dataset中采样得到的每个mini-batch数据进行后处理,从而提供更好的模型输入数据,其取值为一个外部定义的可调用函数。也就是说,设置该值后,真正迭代输出的值是经过该函数处理后的返回值。该函数的具体使用可参照博文
3. 简单示例
import torch
from torch.utils.data import Dataset, DataLoader
A = torch.randn(128, 3)
C = torch.randn(128, 1)
# 1. 用Dataset封装数据集,仅做示范,实际可直接用TensorDataset封装
class MyDataset(Dataset):
def __init__(self, x, y):
assert x.size(0)==y.size(0)
self.x, self.y = x, y
def __getitem__(self, idx):
return (self.x[idx], self.y[idx])
def __len__(self):
return self.x.size(0)
# 2. 用DataLoader定义数据批量迭代器
MyDataLoader = DataLoader(dataset=dataset, shuffle=True, batch_size=4)
for data_iter in MyDataLoader:
# 进行训练或预测