利用torch.utils.data.Dataset自定义数据加载类

Emily ·
更新时间:2024-09-20
· 894 次阅读

import os

import numpy as np
import torch as t
import torchvision.transforms as T
from PIL import Image
from torch.utils import data

# Default preprocessing: resize the shorter side to 224, center-crop to
# 224x224, convert to a float tensor in [0, 1], then map each of the three
# channels to [-1, 1] via (x - 0.5) / 0.5.
transforms = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


# A custom Dataset subclass must override __getitem__() and __len__().
class CatDog(data.Dataset):
    """Cat/dog image dataset backed by a flat directory of image files.

    Every regular file directly under ``root`` is one sample. The label is
    derived from the file name: 1 if it contains ``"dog"``, otherwise 0.

    Args:
        root: Directory containing the image files.
        transforms: Optional callable applied to each RGB PIL image (e.g.
            the module-level ``transforms`` pipeline). When None,
            ``__getitem__`` returns the PIL image itself.
    """

    def __init__(self, root, transforms=None):
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always refers to the same file. Skip subdirectories, which
        # Image.open cannot read.
        self.imgs = [
            os.path.join(root, name)
            for name in sorted(os.listdir(root))
            if os.path.isfile(os.path.join(root, name))
        ]
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        ``image`` is the output of ``self.transforms`` if one was given,
        otherwise an RGB PIL image; ``label`` is 1 for dog files, else 0.
        """
        img_path = self.imgs[index]
        # Label comes from the file name (dog -> 1, anything else -> 0).
        label = 1 if 'dog' in os.path.basename(img_path) else 0
        # Open inside ``with`` so the file handle is released promptly;
        # convert() returns a fully loaded copy, so it stays valid after the
        # file is closed. RGB conversion keeps grayscale/RGBA/palette images
        # compatible with the 3-channel Normalize above.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label

    def __len__(self):
        """Return the number of image files found under ``root``."""
        return len(self.imgs)
作者:枫叶



Dataset 数据 torch

需要登录后方可回复，如果你还没有账号请注册新账号。