VGG is much more demanding on hardware than AlexNet; it runs very slowly on a CPU, so a GPU is strongly recommended.
First, import the required libraries:
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import ZeroPadding2D, Convolution2D, MaxPooling2D, Dropout, Activation, Flatten
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
from tensorflow.keras.applications.vgg19 import preprocess_input  # only needed for 3-channel (RGB) inputs
import numpy as np
import cv2
Next, load the MNIST dataset. The standard VGG input is a 224×224 image, so the 4D input tensor is (None, 224, 224, 1) for grayscale images and (None, 224, 224, 3) for color images. For color images, preprocess_input can also be applied as a simple preprocessing step. The original MNIST images are 28×28, so they must be resized. The labels are converted to one-hot vectors so that a classification model can be trained.
# Load MNIST and convert the images to float32
(x_Train, y_Train), (x_Test, y_Test) = mnist.load_data()
x_Train = x_Train.astype('float32')
x_Test = x_Test.astype('float32')
# Resize each 28x28 image to the 224x224 size expected by VGG
x_Train_resized = np.array([cv2.resize(i, (224, 224)) for i in x_Train])
x_Test_resized = np.array([cv2.resize(i, (224, 224)) for i in x_Test])
# Add the channel dimension: (None, 224, 224, 1)
x_Train4D = x_Train_resized.reshape(-1, 224, 224, 1)
x_Test4D = x_Test_resized.reshape(-1, 224, 224, 1)
# Convert the labels to one-hot vectors
y_Train_One_Hot = to_categorical(y_Train)
y_Test_One_Hot = to_categorical(y_Test)
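One caveat: resizing the whole dataset up front is memory hungry; the 60,000 resized training images alone take about 60000 × 224 × 224 × 4 bytes ≈ 12 GB as float32. If that does not fit in RAM, one workaround is a keras Sequence that resizes a single batch at a time. The sketch below is my own addition, not part of the original post; the class name and batch size are arbitrary choices.

import math
from tensorflow.keras.utils import Sequence

class VGGMnistSequence(Sequence):
    """Yield (images, labels) batches, resizing 28x28 images to 224x224 on the fly."""

    def __init__(self, images, labels, batch_size=100):
        self.images = images          # shape (N, 28, 28), float32
        self.labels = labels          # shape (N, 10), one-hot
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.images) / self.batch_size)

    def __getitem__(self, idx):
        sl = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        batch = np.array([cv2.resize(img, (224, 224)) for img in self.images[sl]])
        return batch.reshape(-1, 224, 224, 1), self.labels[sl]

# Usage: model.fit(VGGMnistSequence(x_Train, y_Train_One_Hot), epochs=10)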
Now we build the VGG11 network:
model = Sequential()
# Block 1: one 3x3 conv with 64 filters, then 2x2 max pooling
model.add(ZeroPadding2D((1,1),input_shape=(224,224, 1)))
model.add(Convolution2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2,2), strides=(2,2)))
# Block 2: one 3x3 conv with 128 filters
model.add(ZeroPadding2D((1,1)))
model.add(Convolution2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D((2,2), strides=(2,2)))
# Block 3: two 3x3 convs with 256 filters
model.add(ZeroPadding2D((1,1)))
model.add(Convolution2D(256, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1,1)))
model.add(Convolution2D(256, (3, 3), activation='relu'))
model.add(MaxPooling2D((2,2), strides=(2,2)))
# Block 4: two 3x3 convs with 512 filters
model.add(ZeroPadding2D((1,1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1,1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(MaxPooling2D((2,2), strides=(2,2)))
# Block 5: two 3x3 convs with 512 filters
model.add(ZeroPadding2D((1,1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1,1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(MaxPooling2D((2,2), strides=(2,2)))
# Classifier: the fully connected layers written as convolutions
model.add(Convolution2D(4096, (7, 7), activation='relu'))
model.add(Dropout(0.5))
model.add(Convolution2D(4096, (1, 1), activation='relu'))
model.add(Dropout(0.5))
model.add(Convolution2D(10, (1, 1)))
model.add(Flatten())
model.add(Activation('softmax'))
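In this network the classifier is written with convolutions: the 7×7 convolution over the final 7×7×512 feature map plays the role of the first fully connected layer, and the two 1×1 convolutions play the role of the remaining ones. For a fixed 224×224 input the same classifier can be written in the more familiar Flatten + Dense form; the snippet below is my own illustration (not from the original post) and would replace the seven layers from the first 4096-filter convolution onward. The parameter counts are identical, as the model summary further down confirms.

from tensorflow.keras.layers import Dense

# Alternative classifier head (would replace the 7x7x4096, 1x1x4096 and 1x1x10
# convolutions plus Flatten/Activation above); parameter counts are unchanged.
model.add(Flatten())                         # 7 * 7 * 512 = 25088 features
model.add(Dense(4096, activation='relu'))    # (25088 + 1) * 4096 = 102,764,544 params
model.add(Dropout(0.5))
model.add(Dense(4096, activation='relu'))    # (4096 + 1) * 4096 = 16,781,312 params
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))   # (4096 + 1) * 10 = 40,970 params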
The VGG11 network is now built. To inspect the model's layers and parameter counts, run
model.summary()
_________________________________________________________________
Layer (type)                      Output Shape             Param #
=================================================================
zero_padding2d (ZeroPadding2D)    (None, 226, 226, 1)      0
_________________________________________________________________
conv2d (Conv2D)                   (None, 224, 224, 64)     640
_________________________________________________________________
max_pooling2d (MaxPooling2D)      (None, 112, 112, 64)     0
_________________________________________________________________
zero_padding2d_1 (ZeroPadding2D)  (None, 114, 114, 64)     0
_________________________________________________________________
conv2d_1 (Conv2D)                 (None, 112, 112, 128)    73856
_________________________________________________________________
max_pooling2d_1 (MaxPooling2D)    (None, 56, 56, 128)      0
_________________________________________________________________
zero_padding2d_2 (ZeroPadding2D)  (None, 58, 58, 128)      0
_________________________________________________________________
conv2d_2 (Conv2D)                 (None, 56, 56, 256)      295168
_________________________________________________________________
zero_padding2d_3 (ZeroPadding2D)  (None, 58, 58, 256)      0
_________________________________________________________________
conv2d_3 (Conv2D)                 (None, 56, 56, 256)      590080
_________________________________________________________________
max_pooling2d_2 (MaxPooling2D)    (None, 28, 28, 256)      0
_________________________________________________________________
zero_padding2d_4 (ZeroPadding2D)  (None, 30, 30, 256)      0
_________________________________________________________________
conv2d_4 (Conv2D)                 (None, 28, 28, 512)      1180160
_________________________________________________________________
zero_padding2d_5 (ZeroPadding2D)  (None, 30, 30, 512)      0
_________________________________________________________________
conv2d_5 (Conv2D)                 (None, 28, 28, 512)      2359808
_________________________________________________________________
max_pooling2d_3 (MaxPooling2D)    (None, 14, 14, 512)      0
_________________________________________________________________
zero_padding2d_6 (ZeroPadding2D)  (None, 16, 16, 512)      0
_________________________________________________________________
conv2d_6 (Conv2D)                 (None, 14, 14, 512)      2359808
_________________________________________________________________
zero_padding2d_7 (ZeroPadding2D)  (None, 16, 16, 512)      0
_________________________________________________________________
conv2d_7 (Conv2D)                 (None, 14, 14, 512)      2359808
_________________________________________________________________
max_pooling2d_4 (MaxPooling2D)    (None, 7, 7, 512)        0
_________________________________________________________________
conv2d_8 (Conv2D)                 (None, 1, 1, 4096)       102764544
_________________________________________________________________
dropout (Dropout)                 (None, 1, 1, 4096)       0
_________________________________________________________________
conv2d_9 (Conv2D)                 (None, 1, 1, 4096)       16781312
_________________________________________________________________
dropout_1 (Dropout)               (None, 1, 1, 4096)       0
_________________________________________________________________
conv2d_10 (Conv2D)                (None, 1, 1, 10)         40970
_________________________________________________________________
flatten (Flatten)                 (None, 10)                0
_________________________________________________________________
activation (Activation)           (None, 10)                0
=================================================================
Total params: 128,806,154
Trainable params: 128,806,154
Non-trainable params: 0
_________________________________________________________________
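The Param # column can be checked by hand: a Conv2D layer has (kernel height × kernel width × input channels + 1) × filters parameters, where the +1 accounts for the bias. A quick sanity check of a few rows from the table above:

# Conv2D parameters = (kernel_h * kernel_w * in_channels + 1) * filters
print((3 * 3 * 1 + 1) * 64)        # conv2d:    640
print((3 * 3 * 64 + 1) * 128)      # conv2d_1:  73856
print((7 * 7 * 512 + 1) * 4096)    # conv2d_8:  102764544
print((1 * 1 * 4096 + 1) * 10)     # conv2d_10: 40970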
Now we can set the loss function and optimizer and train the model:
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
train_history = model.fit(x=x_Train4D, y=y_Train_One_Hot, epochs=10,
                          batch_size=100, verbose=2,
                          validation_data=(x_Test4D, y_Test_One_Hot))
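After training, the per-epoch validation metrics are available in train_history.history, and the trained model can also be evaluated on the test set directly. A brief usage example (my addition, not from the original post; with tensorflow.keras the history key is 'val_accuracy', while older Keras versions use 'val_acc'):

# Validation accuracy recorded by fit() for each epoch
print(train_history.history['val_accuracy'])

# Evaluate the trained model on the test set
test_loss, test_acc = model.evaluate(x_Test4D, y_Test_One_Hot, verbose=0)
print('test accuracy:', test_acc)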