在上一节写到的 Pytorch完整训练自己的数据集 ,我们可以将训练好的模型和演示代码写入网站中,方便演示。还是以pokeman数据分类为例子。
因为整个用到的是python代码,所以构建网页我们也是采用python语言,用的框架是:Flask
这一步先要把训练好的加载,直接去预测未见过的图片。代码命名为demp.py如下:
import torch
from torch import optim, nn
from torchvision import transforms
from torchvision.models import resnet18
from utils import Flatten, softmax
from PIL import Image
import os
import numpy as np
def predicts(img):
device = torch.device('cuda')
torch.manual_seed(1234)
resize = 224
className = {
'0': 'bulbasaur',
'1': 'charmander',
'2': 'mewtw',
'3': 'pikachu',
'4': 'squirtle'}
# model = ResNet18(5).to(device)
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1], # [b, 512, 1, 1]
Flatten(), # [b, 512, 1, 1] => [b, 512]
nn.Linear(512, 5)
).to(device)
# x = torch.randn(2, 3, 224, 224)
# print(model(x).shape)
basepath = os.path.dirname(__file__) # 当前文件所在路径
ckpt_path = os.path.join(basepath, 'best.mdl')
print(ckpt_path)
model.load_state_dict(torch.load(ckpt_path))
print('loaded from ckpt!')
tf = transforms.Compose([
lambda x: Image.open(x).convert('RGB'), # string path= > image data
transforms.Resize(
(int(resize), int(resize))),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# img = 'pokeman\\pikachu\\00000003.jpg'
x = tf(img)
x = x.unsqueeze(0)
x = x.to(device)
model.eval()
with torch.no_grad():
logits = model(x)
pred = logits.argmax(dim=1).item()
prob = np.max(softmax(logits.cpu().numpy()), axis=1)[0]
# print('Our model predicts : %s'%className[str(pred)])
return className[str(pred)], str(round(prob * 100, 2)) + '%'
if __name__ == '__main__':
img = r'C:\spyder\imgshow\static\images\00000009.png'
pres = predicts(img)
需要注意的是由于不需要训练,只是测试,需要添加:model.eval()
,同时我们不需要求导求梯度,因此在模型运算的前面加上with torch.no_grad():
上面函数返回了预测的类别,以及置信度,这个需要显示在网页上面
2、构建一个简单的网站这一步我们采用Flask写一个很简单的网站,代码命名为resnet_class.py:
# coding:utf-8
from flask import Flask, render_template, request, redirect, url_for, make_response, jsonify
from werkzeug.utils import secure_filename
import os
import cv2
import time
from demo import predicts
from datetime import timedelta
# 设置允许的文件格式
ALLOWED_EXTENSIONS = set(['png', 'jpg', 'JPG', 'PNG', 'bmp'])
def allowed_file(filename):
return '.' in filename and filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS
app = Flask(__name__)
# 设置静态文件缓存过期时间
app.send_file_max_age_default = timedelta(seconds=1)
# @app.route('/resnet', methods=['POST', 'GET'])
@app.route('/', methods=['POST', 'GET']) # 添加路由
def upload():
if request.method == 'POST':
f = request.files['file']
if not (f and allowed_file(f.filename)):
return jsonify({"error": 1001,
"msg": "请检查上传的图片类型,仅限于png、PNG、jpg、JPG、bmp"})
user_input = request.form.get("name")
basepath = os.path.dirname(__file__) # 当前文件所在路径
upload_path = os.path.join(
basepath,
'static/images',
secure_filename(
f.filename)) # 注意:没有的文件夹一定要先创建,不然会提示没有该路径
# upload_path = os.path.join(basepath, 'static/images','test.jpg')
# #注意:没有的文件夹一定要先创建,不然会提示没有该路径
f.save(upload_path)
# 使用Opencv转换一下图片格式和名称
img = cv2.imread(upload_path)
cv2.imwrite(os.path.join(basepath, 'static/images', 'test.jpg'), img)
pres, pro = predicts(upload_path)
print(upload_path)
return render_template(
'upload_ok.html',
userinput=user_input,
classresult=pres,
classpro=pro,
val1=time.time())
return render_template('upload.html')
if __name__ == '__main__':
# app.debug = True
app.run(host='127.0.0.1', port=5000, debug=True)
说明:
1、在运行该代码的时候,需要在终端运行:
set FLASK_APP=resnet_class.py
flask run
便可以运行该代码。
2、下面这句话是个装饰器,可以看之前的写什么是装饰器:https://blog.csdn.net/lifei1229/article/details/105757933
@app.route('/', methods=['POST', 'GET']) # 添加路由
在route(’/’) 中传入’/'表示根目录,即在输入网站不需要加上后面目录:
如果我们写成了
@app.route(’/resnet’, methods=[‘POST’, ‘GET’])
那么网站后面需要添加resnet:
这一步就是构建上传图片,显示图片,显示分类结果的网页了。
下面是原始网页upload.html
使用ResNet分类图像演示平台
使用ResNet分类图像演示平台
请输入你认为这张图片的分类标签:
如果图片上传成功,则会用到下面这个网页,命名为 upload_ok.html
使用ResNet分类图像演示平台
使用ResNet分类图像演示平台
请输入你认为这张图片的分类标签:
阁下认为这张照片是:{{userinput}}!
我们使用ResNet模型预测,有{{classpro}}概率认为它是 {{classresult}}
4、整个文件的层级结构
由于网页和python代码交互需要用到
from flask import Flask, render_template,
且文件的存放位置也有要求,在本例子中:
templates是存放网页代码的文件夹
再看一个例子:
将网站代码放在服务器上,且外网能访问,便可以向别人演示你的深度学习模型的效果了,不只是分类,图像去噪,增强,检测,分割都可以弄一个简单的网站,向别人展示你的优秀的模型效果。