pytorch+resnet18实现长尾数据集分类(一)

Lana ·
更新时间:2024-09-20
· 981 次阅读

实验基于论文: Class-Balanced Loss Based on Effective Number of Samples

Class-balanced-loss代码地址:https://github.com/vandit15/Class-balanced-loss-pytorch

resnet18代码参考链接:https://blog.csdn.net/sunqiande88/article/details/80100891

制作数据集

论文通过公式 $n = n_i \mu^i$ 制作长尾 CIFAR-10 数据集,其中 $i$ 为类索引(从 0 开始),$n_i$ 为第 $i$ 类的原始训练样本数,$\mu \in (0, 1)$ 为衰减系数。以下代码以不平衡比例 100 为例(即 $n_i = 5000$,$\mu = (1/100)^{1/9}$)。也可以通过科学上网从谷歌云链接下载。

loadcifar.py

import pickle

import numpy as np
import torch
import torch.utils.data as Data
import torchvision.transforms as transforms
from PIL import Image

# Directory holding the extracted "python version" CIFAR-10 batches.
DATA_DIR = 'data/cifar-10-batches-py'

# Long-tail settings: class i keeps floor(MAX_PER_CLASS * mu ** i) samples,
# with mu = (1 / IMBALANCE_RATIO) ** (1 / (NUM_CLASSES - 1)), so that the
# largest class is IMBALANCE_RATIO times bigger than the smallest one.
NUM_CLASSES = 10
MAX_PER_CLASS = 5000
IMBALANCE_RATIO = 100


def unpickle(file):
    """Load one pickled CIFAR batch file and return its dict.

    Keys of the returned dict are ``bytes`` (e.g. ``b'data'``, ``b'labels'``)
    because the file is decoded with ``encoding='bytes'``.

    NOTE(security): ``pickle.load`` can execute arbitrary code; only use this
    on batch files obtained from a trusted source.
    """
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')


def _long_tail_indices(labels,
                       num_classes=NUM_CLASSES,
                       max_per_class=MAX_PER_CLASS,
                       imbalance_ratio=IMBALANCE_RATIO):
    """Return sorted indices of the samples kept in the long-tailed subset.

    For every class the first ``floor(max_per_class * mu ** class_idx)``
    occurrences (in dataset order) are kept, which matches a sequential
    "take until the per-class quota is full" scan.
    """
    decay = (1 / imbalance_ratio) ** (1 / (num_classes - 1))
    kept = []
    for class_idx in range(num_classes):
        quota = int(np.floor(max_per_class * decay ** class_idx))
        kept.append(np.flatnonzero(labels == class_idx)[:quota])
    # Sorting restores the original dataset order across classes.
    return np.sort(np.concatenate(kept))


def get_data(train=False):
    """Read CIFAR-10 from the source batch files.

    Args:
        train: if true, load the five training batches and subsample them
            into the long-tailed training set; otherwise load the test batch
            unchanged.

    Returns:
        ``(data, labels)``. ``data`` is an array with 3072 values per row.
        For the training split ``labels`` is a 1-D NumPy integer array; for
        the test split it is returned exactly as stored in the batch file.
        With the default settings the training split has 12406 samples and
        the test split 10000.
    """
    if not train:
        batch = unpickle(DATA_DIR + '/test_batch')
        return batch[b'data'], batch[b'labels']

    batches = [unpickle('{}/data_batch_{}'.format(DATA_DIR, i))
               for i in range(1, 6)]
    data = np.concatenate([b[b'data'] for b in batches])
    labels = np.concatenate([np.asarray(b[b'labels']) for b in batches])

    # Select all kept rows at once instead of concatenating one row at a
    # time (which is quadratic in the number of kept samples).
    keep = _long_tail_indices(labels)
    return data[keep].reshape(-1, 3072), labels[keep]


# Image preprocessing; Compose chains several transforms together.
transform = transforms.Compose([
    # ToTensor converts a PIL image / H x W x C ndarray with values in
    # [0, 255] to a tensor with values in [0, 1].
    transforms.ToTensor(),
    # Normalize with mean 0.5 and std 0.5 per channel maps [0, 1] to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


def target_transform(label):
    """Convert a label to a ``torch.LongTensor``."""
    label = np.array(label)
    return torch.from_numpy(label).long()


class Cifar10_Dataset(Data.Dataset):
    """Custom ``Dataset`` serving long-tailed CIFAR-10 (train) or the
    regular CIFAR-10 test set.

    Args:
        train: selects the training split when true, the test split otherwise.
        transform: optional callable applied to each PIL image.
        target_transform: optional callable applied to each label.
    """

    def __init__(self, train=True, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.train = train

        data, labels = get_data(train)
        # Rows are stored channel-first; convert to
        # [height, width, channels] so PIL / torchvision can consume them.
        images = data.reshape((-1, 3, 32, 32)).transpose((0, 2, 3, 1))

        if self.train:
            self.train_data, self.train_labels = images, labels
        else:
            self.test_data, self.test_labels = images, labels

    def __getitem__(self, index):
        """Return the preprocessed ``(image, target)`` pair at ``index``."""
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        # Without a target_transform the raw label is returned as-is
        # (previously this path raised UnboundLocalError).
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the selected split."""
        if self.train:
            return len(self.train_data)
        return len(self.test_data)


if __name__ == '__main__':
    # Load the training and test sets and report their sizes.
    train_data = Cifar10_Dataset(True, transform, target_transform)
    print('size of train_data:{}'.format(len(train_data)))
    test_data = Cifar10_Dataset(False, transform, target_transform)
    print('size of test_data:{}'.format(len(test_data)))

第二步:定义损失函数
第三步:训练

景唯acr 原创文章 26获赞 32访问量 4万+ 关注 私信 展开阅读全文
作者:景唯acr



pytorch 数据集 数据 分类 resnet

需要 登录 后方可回复, 如果你还没有账号请 注册新账号