将ResNet分类器做成一个小网站界面

Welcome ·
更新时间:2024-09-20
· 639 次阅读

在上一节写到的 Pytorch完整训练自己的数据集 ,我们可以将训练好的模型和演示代码写入网站中,方便演示。还是以pokeman数据分类为例子。

因为整个用到的是python代码,所以构建网页我们也是采用python语言,用的框架是:Flask

1、首先写分类demo代码

这一步先要把训练好的加载,直接去预测未见过的图片。代码命名为demp.py如下:

import torch from torch import optim, nn from torchvision import transforms from torchvision.models import resnet18 from utils import Flatten, softmax from PIL import Image import os import numpy as np def predicts(img): device = torch.device('cuda') torch.manual_seed(1234) resize = 224 className = { '0': 'bulbasaur', '1': 'charmander', '2': 'mewtw', '3': 'pikachu', '4': 'squirtle'} # model = ResNet18(5).to(device) trained_model = resnet18(pretrained=True) model = nn.Sequential(*list(trained_model.children())[:-1], # [b, 512, 1, 1] Flatten(), # [b, 512, 1, 1] => [b, 512] nn.Linear(512, 5) ).to(device) # x = torch.randn(2, 3, 224, 224) # print(model(x).shape) basepath = os.path.dirname(__file__) # 当前文件所在路径 ckpt_path = os.path.join(basepath, 'best.mdl') print(ckpt_path) model.load_state_dict(torch.load(ckpt_path)) print('loaded from ckpt!') tf = transforms.Compose([ lambda x: Image.open(x).convert('RGB'), # string path= > image data transforms.Resize( (int(resize), int(resize))), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # img = 'pokeman\\pikachu\\00000003.jpg' x = tf(img) x = x.unsqueeze(0) x = x.to(device) model.eval() with torch.no_grad(): logits = model(x) pred = logits.argmax(dim=1).item() prob = np.max(softmax(logits.cpu().numpy()), axis=1)[0] # print('Our model predicts : %s'%className[str(pred)]) return className[str(pred)], str(round(prob * 100, 2)) + '%' if __name__ == '__main__': img = r'C:\spyder\imgshow\static\images\00000009.png' pres = predicts(img)

需要注意的是由于不需要训练,只是测试,需要添加:model.eval(),同时我们不需要求导求梯度,因此在模型运算的前面加上with torch.no_grad():

上面函数返回了预测的类别,以及置信度,这个需要显示在网页上面

2、构建一个简单的网站

这一步我们采用Flask写一个很简单的网站,代码命名为resnet_class.py:

# coding:utf-8 from flask import Flask, render_template, request, redirect, url_for, make_response, jsonify from werkzeug.utils import secure_filename import os import cv2 import time from demo import predicts from datetime import timedelta # 设置允许的文件格式 ALLOWED_EXTENSIONS = set(['png', 'jpg', 'JPG', 'PNG', 'bmp']) def allowed_file(filename): return '.' in filename and filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS app = Flask(__name__) # 设置静态文件缓存过期时间 app.send_file_max_age_default = timedelta(seconds=1) # @app.route('/resnet', methods=['POST', 'GET']) @app.route('/', methods=['POST', 'GET']) # 添加路由 def upload(): if request.method == 'POST': f = request.files['file'] if not (f and allowed_file(f.filename)): return jsonify({"error": 1001, "msg": "请检查上传的图片类型,仅限于png、PNG、jpg、JPG、bmp"}) user_input = request.form.get("name") basepath = os.path.dirname(__file__) # 当前文件所在路径 upload_path = os.path.join( basepath, 'static/images', secure_filename( f.filename)) # 注意:没有的文件夹一定要先创建,不然会提示没有该路径 # upload_path = os.path.join(basepath, 'static/images','test.jpg') # #注意:没有的文件夹一定要先创建,不然会提示没有该路径 f.save(upload_path) # 使用Opencv转换一下图片格式和名称 img = cv2.imread(upload_path) cv2.imwrite(os.path.join(basepath, 'static/images', 'test.jpg'), img) pres, pro = predicts(upload_path) print(upload_path) return render_template( 'upload_ok.html', userinput=user_input, classresult=pres, classpro=pro, val1=time.time()) return render_template('upload.html') if __name__ == '__main__': # app.debug = True app.run(host='127.0.0.1', port=5000, debug=True)

说明:
1、在运行该代码的时候,需要在终端运行:

set FLASK_APP=resnet_class.py flask run

便可以运行该代码。

2、下面这句话是个装饰器,可以看之前的写什么是装饰器:https://blog.csdn.net/lifei1229/article/details/105757933

@app.route('/', methods=['POST', 'GET']) # 添加路由

route(’/’) 中传入’/'表示根目录,即在输入网站不需要加上后面目录:

如果我们写成了

@app.route(’/resnet’, methods=[‘POST’, ‘GET’])

那么网站后面需要添加resnet:

3、写入网页

这一步就是构建上传图片,显示图片,显示分类结果的网页了。
下面是原始网页upload.html

使用ResNet分类图像演示平台

使用ResNet分类图像演示平台


请输入你认为这张图片的分类标签:

如果图片上传成功,则会用到下面这个网页,命名为 upload_ok.html

使用ResNet分类图像演示平台

使用ResNet分类图像演示平台


请输入你认为这张图片的分类标签:

阁下认为这张照片是:{{userinput}}!

你的图片被外星人劫持了~~

我们使用ResNet模型预测,有{{classpro}}概率认为它是 {{classresult}}

4、整个文件的层级结构

由于网页和python代码交互需要用到

from flask import Flask, render_template,

且文件的存放位置也有要求,在本例子中:
在这里插入图片描述
templates是存放网页代码的文件夹

5、演示结果

再看一个例子:

6、总结

将网站代码放在服务器上,且外网能访问,便可以向别人演示你的深度学习模型的效果了,不只是分类,图像去噪,增强,检测,分割都可以弄一个简单的网站,向别人展示你的优秀的模型效果。


作者:极简机器学习



界面 resnet

需要 登录 后方可回复, 如果你还没有账号请 注册新账号
相关文章