pytorch:pokemon+resnet详细代码+数据集

Ona ·
更新时间:2024-09-20
· 588 次阅读

文章目录一.定义一个Pokemon的类,用于获取图片以及对应的label二.构建resblock三.搭建resnet四.设置一些超参数五.载入数据六.初始化模型,设置loss_function/optimizer/evaluation七.开始训练,并进行检验 import torch from torch import nn from torch.nn import functional as F from torchvision import transforms from torch.utils.data import DataLoader,Dataset from torch import optim import os import csv from PIL import Image import warnings warnings.simplefilter('ignore') 一.定义一个Pokemon的类,用于获取图片以及对应的label

pokemon数据集请戳:
缦旋律的资源合集.

对于自定义数据集,并使用DataLoader划分batch不熟悉的,可以戳:
自定义数据集+DataLoader.

class Pokemon(Dataset): def __init__(self,root,resize,mode): #root是文件路径,resize是对原始图片进行裁剪,mode是选择模式(train、test、validation) super(Pokemon,self).__init__() self.root = root self.resize = resize self.name2label = {} #给每个种类分配一个数字,以该数字作为这一类别的label #name是宝可梦的种类,e.g:pikachu for name in sorted(os.listdir(os.path.join(self.root))): #listdir返回的顺序不固定,加上一个sorted使每一次的顺序都一样 if not os.path.isdir(os.path.join(self.root,name)):#os.path.isdir()用于判断括号中的内容是否是一个未压缩的文件夹 continue self.name2label[name] = len(self.name2label.keys()) print(self.name2label) self.images,self.labels = self.load_csv('images&labels.csv') #将全部数据分成train、validation、test if mode == 'train': #前60%作为训练集 self.images = self.images[:int(0.6*len(self.images))] self.labels = self.labels[:int(0.6*len(self.labels))] elif mode == 'val': #60%~80%作为validation self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))] self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))] else: #后20%作为test set self.images = self.images[int(0.8*len(self.images)):] self.labels = self.labels[int(0.8*len(self.labels)):] def load_csv(self,filename): #载入原始图片的路径,并保存到指定的CSV文件中,然后从该CSV文件中再次读入所有图片的存储路径和label。 #如果CSV文件已经存在,则直接读入该CSV文件的内容 #为什么保存的是图片的路径而不是图片?因为直接保存图片可能会造成内存爆炸 if not os.path.exists(os.path.join(self.root,filename)): #如果filename这个文件不存在,那么执行以下代码,创建file images = [] for name in self.name2label.keys(): #glob.glob()返回的是括号中的路径中的所有文件的路径 # += 是把glob.glob()返回的结果依次append到image中,而不是以一个整体append # 这里只用了png/jpg/jepg是因为本次实验的图片只有这三种格式,如果有其他格式请自行添加 images += glob.glob(os.path.join(self.root,name,'*.png')) images += glob.glob(os.path.join(self.root,name,'*.jpg')) images += glob.glob(os.path.join(self.root,name,'*.jpeg')) print(len(images)) random.shuffle(images) #把所有图片路径顺序打乱 with open(os.path.join(self.root,filename),mode='w',newline='') as f: #将图片路径及其对应的数字标签写到指定文件中 writer = csv.writer(f) for img in images: #img e.g:'./pokemon/pikachu\\00000001.png' name = img.split(os.sep)[-2] #即取出‘pikachu’ label = self.name2label[name] #根据name找到对应的数字标签 writer.writerow([img,label]) #把每张图片的路径和它对应的数字标签写到指定的CSV文件中 print('image paths and labels have been writen into csv file:',filename) #把数据读入(如果filename存在就直接执行这一步,如果不存在就先创建file再读入数据) images,labels = [],[] with open(os.path.join(self.root,filename)) as f: reader = csv.reader(f) for row in reader: img,label = row label = int(label) images.append(img) labels.append(label) assert len(images) == len(labels) #确保它们长度一致 return images,labels def __len__(self): return len(self.images) def __getitem__(self,idx): img,label = self.images[idx],self.labels[idx]#此时img还是路径字符串,要把它转化成tensor #将图片resize成224*224,并转化成tensor,这个tensor的size是3*224*224(3是因为有RGB3个通道) trans = transforms.Compose(( lambda x: Image.open(x).convert('RGB'), transforms.Resize((self.resize,self.resize)), #必须要把长宽都一起写上啊!!! transforms.ToTensor(), transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225]) #这个数据是根据resnet中的图片统计得到的,直接拿来用就好 )) img = trans(img) label = torch.tensor(label) return img,label 二.构建resblock

对于resnet的构建不熟悉的,可以戳:
cifar-10+resnet 详细代码+解释.

class resblock(nn.Module): def __init__(self,ch_in,ch_out,stride=1): super(resblock,self).__init__() self.conv_1 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1) self.bn_1 = nn.BatchNorm2d(ch_out) self.conv_2 = nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1) self.bn_2 = nn.BatchNorm2d(ch_out) self.ch_trans = nn.Sequential() if ch_in != ch_out: self.ch_trans = nn.Sequential(nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),nn.BatchNorm2d(ch_out)) #ch_trans表示通道数转变。因为要做short_cut,所以x_pro和x_ch的size应该完全一致 def forward(self,x): x_pro = F.relu(self.bn_1(self.conv_1(x))) x_pro = self.bn_2(self.conv_2(x_pro)) #short_cut: x_ch = self.ch_trans(x) out = x_pro + x_ch out = F.relu(out) return out 三.搭建resnet class Resnet18(nn.Module): def __init__(self,num_class): super(Resnet18,self).__init__() self.conv_1 = nn.Sequential( nn.Conv2d(3,16,kernel_size=3,stride=3,padding=0), nn.BatchNorm2d(16)) self.block1 = resblock(16,32,3) self.block2 = resblock(32,64,3) self.block3 = resblock(64,128,2) self.block4 = resblock(128,256,2) self.outlayer = nn.Linear(256*3*3,num_class)#这个256*3*3是根据forward中x经过4个resblock之后来决定的 def forward(self,x): x = F.relu(self.conv_1(x)) x = self.block1(x) x = self.block2(x) x = self.block3(x) x = self.block4(x) x = x.reshape(x.size(0),-1) result = self.outlayer(x) return result 四.设置一些超参数 batch_size = 32 lr = 1e-3 device = torch.device('cuda') torch.manual_seed(1234) 五.载入数据 train_db = Pokemon('./pokemon',224,'train') #将所有图片(顺序已打乱)的前60%作为train_set val_db = Pokemon('./pokemon',224,'val') #60%~80%作为validation_set test_db = Pokemon('./pokemon',224,'test') #80%~100%作为test_set train_loader = DataLoader(train_db,batch_size=batch_size,shuffle=True) #之后调用一次train_loader就会把train_db划分成很多batch val_loader = DataLoader(val_db,batch_size=batch_size,shuffle=True) test_loader = DataLoader(test_db,batch_size=batch_size,shuffle=True) {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4} {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4} {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4} 六.初始化模型,设置loss_function/optimizer/evaluation model = Resnet18(5).to(device) #模型初始化,5代表一共有5种类别 print('模型需要训练的参数共有{}个'.format(sum(map(lambda p:p.numel(),model.parameters())))) loss_fn = nn.CrossEntropyLoss() #选择loss_function optimizer = optim.Adam(model.parameters(),lr=lr) #选择优化方式 # evaluate用于检测模型的预测效果,validation_set和test_set是同样的evaluate方法 def evaluate(model,loader): correct_num = 0 total_num = len(loader.dataset) for img,label in loader: #lodaer中包含了很多batch,每个batch有32张图片 img,label = img.to(device),label.to(device) with torch.no_grad(): logits = model(img) pre_label = logits.argmax(dim=1) correct_num += torch.eq(pre_label,label).sum().float().item() return correct_num/total_num 七.开始训练,并进行检验 best_epoch,best_acc = 0,0 for epoch in range(10): #时间关系,我们只训练10个epoch for batch_num,(img,label) in enumerate(train_loader): #img.size [b,3,224,224] label.size [b] img,label = img.to(device),label.to(device) logits = model(img) loss = loss_fn(logits,label) if batch_num%5 == 0: print('这是第{}次迭代的第{}个batch,loss是{}'.format(epoch+1,batch_num+1,loss.item())) optimizer.zero_grad() loss.backward() optimizer.step() if epoch%2==0: #这里设置的是每训练两次epoch就进行一次validation val_acc = evaluate(model,val_loader) #如果val_acc比之前的好,那么就把该epoch保存下来,并把此时模型的参数保存到指定txt文件里 if val_acc>best_acc: print('验证集上的准确率是:{}'.format(val_acc)) best_epoch = epoch best_acc = val_acc torch.save(model.state_dict(),'pokemon_ckp.txt') print('best_acc:{},best_epoch:{}'.format(best_acc,best_epoch)) model.load_state_dict(torch.load('pokemon_ckp.txt')) print('模型训练完毕,已将参数设置成训练过程中的最优值,现在开始测试test_set') test_acc = evaluate(model,test_loader) print('测试集上的准确率是:{}'.format(test_acc)) 这是第1次迭代的第1个batch,loss是0.09284534305334091 这是第1次迭代的第6个batch,loss是0.048977259546518326 这是第1次迭代的第11个batch,loss是0.012712553143501282 这是第1次迭代的第16个batch,loss是0.29784664511680603 这是第1次迭代的第21个batch,loss是0.04697053134441376 验证集上的准确率是:0.9055793991416309 这是第2次迭代的第1个batch,loss是0.1201871708035469 这是第2次迭代的第6个batch,loss是0.17532214522361755 这是第2次迭代的第11个batch,loss是0.4216356873512268 这是第2次迭代的第16个batch,loss是0.25439736247062683 这是第2次迭代的第21个batch,loss是0.08713945746421814 这是第3次迭代的第1个batch,loss是0.29146236181259155 这是第3次迭代的第6个batch,loss是0.054789893329143524 这是第3次迭代的第11个batch,loss是0.4522630572319031 这是第3次迭代的第16个batch,loss是0.08970004320144653 这是第3次迭代的第21个batch,loss是0.3429204821586609 这是第4次迭代的第1个batch,loss是0.06742320954799652 这是第4次迭代的第6个batch,loss是0.12285976856946945 这是第4次迭代的第11个batch,loss是0.15735840797424316 这是第4次迭代的第16个batch,loss是0.07834229618310928 这是第4次迭代的第21个batch,loss是0.20532763004302979 这是第5次迭代的第1个batch,loss是0.00593993067741394 这是第5次迭代的第6个batch,loss是0.03216344118118286 这是第5次迭代的第11个batch,loss是0.03481002524495125 这是第5次迭代的第16个batch,loss是0.15314869582653046 这是第5次迭代的第21个batch,loss是0.08527624607086182 这是第6次迭代的第1个batch,loss是0.05515890568494797 这是第6次迭代的第6个batch,loss是0.036611974239349365 这是第6次迭代的第11个batch,loss是0.007195517420768738 这是第6次迭代的第16个batch,loss是0.05695120990276337 这是第6次迭代的第21个batch,loss是0.15042126178741455 这是第7次迭代的第1个batch,loss是0.1088687926530838 这是第7次迭代的第6个batch,loss是0.002063468098640442 这是第7次迭代的第11个batch,loss是0.01613890379667282 这是第7次迭代的第16个batch,loss是0.012490876019001007 这是第7次迭代的第21个batch,loss是0.48446154594421387 验证集上的准确率是:0.9141630901287554 这是第8次迭代的第1个batch,loss是0.10298655182123184 这是第8次迭代的第6个batch,loss是0.05644068121910095 这是第8次迭代的第11个batch,loss是0.0563386008143425 这是第8次迭代的第16个batch,loss是0.00903283804655075 这是第8次迭代的第21个batch,loss是0.08256962895393372 这是第9次迭代的第1个batch,loss是0.014249928295612335 这是第9次迭代的第6个batch,loss是0.013826802372932434 这是第9次迭代的第11个batch,loss是0.0016943514347076416 这是第9次迭代的第16个batch,loss是0.1954154521226883 这是第9次迭代的第21个batch,loss是0.056067951023578644 这是第10次迭代的第1个batch,loss是0.014393903315067291 这是第10次迭代的第6个batch,loss是0.26919856667518616 这是第10次迭代的第11个batch,loss是0.03811478987336159 这是第10次迭代的第16个batch,loss是0.18677780032157898 这是第10次迭代的第21个batch,loss是0.018675178289413452 best_acc:0.9141630901287554,best_epoch:6 模型训练完毕,已将参数设置成训练过程中的最优值,现在开始测试test_set 测试集上的准确率是:0.9017094017094017

可以看到,最优的epoch是6,即第7次迭代时候的模型,该模型在validation_set上的正确率是0.914,在test_set上的准确率是0.902


在这里插入图片描述


作者:缦旋律



pokemon pytorch 数据集 数据 resnet

需要 登录 后方可回复, 如果你还没有账号请 注册新账号