# -*- coding: utf-8 -*-
"""
Spyder Editor
This is a temporary script file.
"""
import sys

# Make the TF-Slim code (the `nets` and `datasets` packages imported below)
# importable from the 'slim' directory, resolved against the current working
# directory. Only add it once, so re-running the script does not stack entries.
nets_path = r'slim'
if nets_path in sys.path:
    print('already add slim')
else:
    sys.path.insert(0, nets_path)
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
from nets.nasnet import pnasnet
import numpy as np
from datasets import imagenet

# This script uses the TensorFlow 1.x graph/session API
# (tf.contrib, tf.placeholder, tf.Session).
slim = tf.contrib.slim
# Clear the default graph so that re-running the script in the same
# interpreter builds the model into a fresh graph.
tf.reset_default_graph()
# Get the image size the network expects; it is used as both the height and
# the width of the input placeholder.
image_size = pnasnet.build_pnasnet_large.default_image_size
# English ImageNet class names from the slim datasets helper. `labels` is
# reassigned below with the Chinese label list read from CSV.
labels = imagenet.create_readable_name_for_imagenet_labels()
print(len(labels), labels)
def getone(onestr):
    """Turn one line of the label CSV into a clean label string.

    Removes every comma and the trailing line terminator. Lines yielded by
    iterating a text file keep their newline, which would otherwise end up
    inside the label when it is printed.
    """
    return onestr.replace(',', '').rstrip('\r\n')
# Replace the ImageNet labels with Chinese names, one label per CSV line.
# The file is only read, so open it read-only instead of 'r+'.
# NOTE(review): this uses the platform default encoding; pin `encoding=` once
# the CSV's actual encoding (UTF-8 vs GBK) is confirmed.
with open('中文标签.csv', 'r') as f:
    labels = [getone(line) for line in f]
print(len(labels), type(labels), labels[:5])
# Recognize the sample images with the model.
sampe_images = ['hy.jpg', 'ps.jpg', '72.jpg']
# Batch of 3-channel images with pixel values presumably in [0, 255].
input_imgs = tf.placeholder(tf.float32, [None, image_size, image_size, 3])
# Rescale [0, 255] -> [-1, 1] before feeding the network.
x1 = 2 * (input_imgs / 255.0) - 1.0
arg_scope = pnasnet.pnasnet_large_arg_scope()
with slim.arg_scope(arg_scope):
    # NOTE(review): 1001 outputs, presumably 1000 ImageNet classes plus a
    # background class at index 0. Confirm against the label file.
    logits, end_points = pnasnet.build_pnasnet_large(
        x1, num_classes=1001, is_training=False)
prob = end_points['Predictions']
# Predicted class index per image.
y = tf.argmax(prob, axis=1)
# Pretrained PNASNet-5 Large checkpoint to restore.
checkpoint_file = r'pnasnet-5_large_2017_12_13/model.ckpt'
# Define the saver used to load the model. It must be an instance (note the
# call) so that saver.restore(sess, path) is a bound-method call. It is
# created after the model graph above has been built.
saver = tf.train.Saver()
def preimg(img):
    """Resize a PIL image to the network input size as a float32 RGB array.

    Converting to RGB first gives exactly 3 channels for every mode
    (grayscale 'L', palette 'P', 'RGBA', ...), so the results stack into the
    (None, image_size, image_size, 3) placeholder. Pixel values stay in
    0..255; the graph normalizes them itself.
    """
    rgb = img.convert('RGB').resize((image_size, image_size))
    return np.asarray(rgb, dtype=np.float32)


def showresult(yv, img_norm, img_org):
    """Show an original image beside the network input and print its label.

    Args:
        yv: predicted class index for this image.
        img_norm: the normalized network input for this image, values in
            [-1, 1] (one element of the fetched `x1` tensor).
        img_org: the original PIL image.
    """
    plt.figure()
    p1 = plt.subplot(121)
    p2 = plt.subplot(122)
    p1.imshow(img_org)
    p1.axis('off')
    p1.set_title('original image')
    # Map [-1, 1] back to [0, 255]. Clip against float round-off before the
    # uint8 cast so values cannot wrap around.
    p2.imshow(np.clip((img_norm + 1.0) * 127.5, 0, 255).astype(np.uint8))
    p2.axis('off')
    p2.set_title("input image")
    plt.show()
    print(yv, labels[yv])


with tf.Session() as sess:
    saver.restore(sess, checkpoint_file)
    # Open each file once; the same image objects are used for preprocessing
    # and for display.
    orgImg = [Image.open(imgfilename) for imgfilename in sampe_images]
    batchImg = [preimg(img) for img in orgImg]
    yv, img_norm = sess.run([y, x1], feed_dict={input_imgs: batchImg})
    print(yv, np.shape(yv))
    for yy, img1, img2 in zip(yv, img_norm, orgImg):
        showresult(yy, img1, img2)