Recognizing MNIST Handwritten Digits with VGG11 in Keras

Irina · Updated: 2024-11-10

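VGG is considerably more demanding on hardware than AlexNet. Training it on an ordinary CPU is very slow, so a GPU is strongly recommended.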

First, import the required libraries:

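    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import ZeroPadding2D, Convolution2D, MaxPooling2D, Dropout, Activation, Flatten
    from tensorflow.keras.utils import to_categorical  # one-hot encoding of integer labels
    from tensorflow.keras.applications.vgg19 import preprocess_input  # only applicable to 3-channel (color) images
    from tensorflow.keras.datasets import mnist
    import numpy as np
    import cv2

Next, load the MNIST dataset. Standard VGG takes 224×224 images as input, i.e. a 4D tensor of shape (None, 224, 224, 1) for grayscale images or (None, 224, 224, 3) for color images. Color images can additionally be run through preprocess_input for simple preprocessing; it is meant for 3-channel input, so it is not applied to MNIST here. MNIST images are only 28×28, so each one is resized to 224×224 and its pixel values are scaled to [0, 1]. The labels are converted to one-hot vectors so the network can be trained as a 10-class classifier with categorical cross-entropy.

    (x_Train, y_Train), (x_Test, y_Test) = mnist.load_data()
    # scale pixel values from 0-255 to 0-1
    x_Train = x_Train.astype('float32') / 255.0
    x_Test = x_Test.astype('float32') / 255.0
    # upsample every 28x28 image to the 224x224 input size VGG expects
    x_Train_resized = np.array([cv2.resize(img, (224, 224)) for img in x_Train])
    x_Test_resized = np.array([cv2.resize(img, (224, 224)) for img in x_Test])
    # add the channel axis: (N, 224, 224) -> (N, 224, 224, 1)
    x_Train4D = x_Train_resized.reshape(-1, 224, 224, 1)
    x_Test4D = x_Test_resized.reshape(-1, 224, 224, 1)
    # integer labels 0-9 -> one-hot vectors of length 10
    y_Train_One_Hot = to_categorical(y_Train)
    y_Test_One_Hot = to_categorical(y_Test)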
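To make the last two preprocessing steps concrete, here is a small NumPy sketch. The helper names `one_hot` and `resized_bytes` are hypothetical: `one_hot` reproduces the values `to_categorical` returns for integer labels, and `resized_bytes` estimates how much RAM the eagerly resized arrays occupy, assuming float32 values and a single channel.

```python
import numpy as np


def one_hot(labels, num_classes=10):
    """Integer labels -> one-hot rows (same values as to_categorical for labels 0..num_classes-1)."""
    return np.eye(num_classes, dtype="float32")[labels]


def resized_bytes(num_images, size=224, channels=1, bytes_per_value=4):
    """Bytes needed to hold num_images arrays of shape (size, size, channels) as float32."""
    return num_images * size * size * channels * bytes_per_value


if __name__ == "__main__":
    print(one_hot(np.array([3, 0, 9])))
    # MNIST has 60,000 training images and 10,000 test images
    print(f"train: {resized_bytes(60000) / 1e9:.1f} GB")
    print(f"test:  {resized_bytes(10000) / 1e9:.1f} GB")
```

For 224×224 float32 images this comes to about 12 GB for the training set and 2 GB for the test set, not counting the temporary list built by the list comprehension. If that is too much memory, resizing batch by batch (for example inside a tf.data input pipeline) is a common alternative.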

Now build the VGG11 network:

    model = Sequential()
    # Block 1: 64 filters. ZeroPadding2D((1,1)) followed by a 3x3 'valid' conv keeps the spatial size
    model.add(ZeroPadding2D((1, 1), input_shape=(224, 224, 1)))
    model.add(Convolution2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2)))          # 224 -> 112
    # Block 2: 128 filters
    model.add(ZeroPadding2D((1, 1)))
    model.add(Convolution2D(128, (3, 3), activation='relu'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2)))          # 112 -> 56
    # Block 3: two 256-filter convs
    model.add(ZeroPadding2D((1, 1)))
    model.add(Convolution2D(256, (3, 3), activation='relu'))
    model.add(ZeroPadding2D((1, 1)))
    model.add(Convolution2D(256, (3, 3), activation='relu'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2)))          # 56 -> 28
    # Block 4: two 512-filter convs
    model.add(ZeroPadding2D((1, 1)))
    model.add(Convolution2D(512, (3, 3), activation='relu'))
    model.add(ZeroPadding2D((1, 1)))
    model.add(Convolution2D(512, (3, 3), activation='relu'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2)))          # 28 -> 14
    # Block 5: two 512-filter convs
    model.add(ZeroPadding2D((1, 1)))
    model.add(Convolution2D(512, (3, 3), activation='relu'))
    model.add(ZeroPadding2D((1, 1)))
    model.add(Convolution2D(512, (3, 3), activation='relu'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2)))          # 14 -> 7
    # Fully connected layers written as convolutions:
    # a 7x7 conv over the 7x7x512 map gives 1x1x4096, then two 1x1 convs
    model.add(Convolution2D(4096, (7, 7), activation='relu'))
    model.add(Dropout(0.5))
    model.add(Convolution2D(4096, (1, 1), activation='relu'))
    model.add(Dropout(0.5))
    model.add(Convolution2D(10, (1, 1)))                     # one output per digit class
    model.add(Flatten())                                     # (1, 1, 10) -> (10,)
    model.add(Activation('softmax'))
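Each ZeroPadding2D((1,1)) plus 3×3 convolution pair preserves the spatial size, and each 2×2, stride-2 pooling layer halves it, so five pooling stages take 224 down to 7 and the 7×7 convolution then yields a 1×1×4096 map, playing the role of VGG's first fully connected layer. The pure-Python sketch below traces those sizes; `CFG` and `trace_shapes` are illustrative names, not part of Keras.

```python
# VGG11 (configuration "A" in the VGG paper) as built above:
# integers are 3x3 conv filter counts, "M" is a 2x2 max-pool with stride 2
CFG = [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"]


def trace_shapes(size=224, in_channels=1, cfg=CFG):
    """Return (height, width, channels) after every conv/pool stage and after the 7x7 conv."""
    channels = in_channels
    trace = []
    for item in cfg:
        if item == "M":
            size = (size - 2) // 2 + 1   # 2x2 pool, stride 2, 'valid'
        else:
            size = (size + 2) - 3 + 1    # pad 1 on each side, then 3x3 'valid' conv
            channels = item
        trace.append((size, size, channels))
    size = size - 7 + 1                  # 7x7 'valid' conv acting as the first FC layer
    trace.append((size, size, 4096))
    return trace


for shape in trace_shapes():
    print(shape)
```

The printed sizes should line up with the Output Shape column of model.summary(). Calling `trace_shapes(28)` also shows why the images are resized: with a 28×28 input the feature map is already 1×1 before the last pooling layer, leaving nothing for that layer or the 7×7 convolution to operate on.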

The VGG11 network is now built. To inspect its layers and parameter counts, run:

    model.summary()

    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    zero_padding2d (ZeroPadding2 (None, 226, 226, 1)       0
    _________________________________________________________________
    conv2d (Conv2D)              (None, 224, 224, 64)      640
    _________________________________________________________________
    max_pooling2d (MaxPooling2D) (None, 112, 112, 64)      0
    _________________________________________________________________
    zero_padding2d_1 (ZeroPaddin (None, 114, 114, 64)      0
    _________________________________________________________________
    conv2d_1 (Conv2D)            (None, 112, 112, 128)     73856
    _________________________________________________________________
    max_pooling2d_1 (MaxPooling2 (None, 56, 56, 128)       0
    _________________________________________________________________
    zero_padding2d_2 (ZeroPaddin (None, 58, 58, 128)       0
    _________________________________________________________________
    conv2d_2 (Conv2D)            (None, 56, 56, 256)       295168
    _________________________________________________________________
    zero_padding2d_3 (ZeroPaddin (None, 58, 58, 256)       0
    _________________________________________________________________
    conv2d_3 (Conv2D)            (None, 56, 56, 256)       590080
    _________________________________________________________________
    max_pooling2d_2 (MaxPooling2 (None, 28, 28, 256)       0
    _________________________________________________________________
    zero_padding2d_4 (ZeroPaddin (None, 30, 30, 256)       0
    _________________________________________________________________
    conv2d_4 (Conv2D)            (None, 28, 28, 512)       1180160
    _________________________________________________________________
    zero_padding2d_5 (ZeroPaddin (None, 30, 30, 512)       0
    _________________________________________________________________
    conv2d_5 (Conv2D)            (None, 28, 28, 512)       2359808
    _________________________________________________________________
    max_pooling2d_3 (MaxPooling2 (None, 14, 14, 512)       0
    _________________________________________________________________
    zero_padding2d_6 (ZeroPaddin (None, 16, 16, 512)       0
    _________________________________________________________________
    conv2d_6 (Conv2D)            (None, 14, 14, 512)       2359808
    _________________________________________________________________
    zero_padding2d_7 (ZeroPaddin (None, 16, 16, 512)       0
    _________________________________________________________________
    conv2d_7 (Conv2D)            (None, 14, 14, 512)       2359808
    _________________________________________________________________
    max_pooling2d_4 (MaxPooling2 (None, 7, 7, 512)         0
    _________________________________________________________________
    conv2d_8 (Conv2D)            (None, 1, 1, 4096)        102764544
    _________________________________________________________________
    dropout (Dropout)            (None, 1, 1, 4096)        0
    _________________________________________________________________
    conv2d_9 (Conv2D)            (None, 1, 1, 4096)        16781312
    _________________________________________________________________
    dropout_1 (Dropout)          (None, 1, 1, 4096)        0
    _________________________________________________________________
    conv2d_10 (Conv2D)           (None, 1, 1, 10)          40970
    _________________________________________________________________
    flatten (Flatten)            (None, 10)                0
    _________________________________________________________________
    activation (Activation)      (None, 10)                0
    =================================================================
    Total params: 128,806,154
    Trainable params: 128,806,154
    Non-trainable params: 0
    _________________________________________________________________
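The Param # column follows directly from the Conv2D formula (kernel_h × kernel_w × in_channels + 1) × filters, where the +1 is each filter's bias. The sketch below uses a layer list transcribed by hand from the model above (the names `CONV_LAYERS` and `conv_params` are illustrative); it reproduces the total and shows that the single 7×7 convolution holds roughly 80% of all weights, which goes a long way toward explaining VGG's hardware demands.

```python
# (kernel size, input channels, filters) for every Conv2D layer in the model above
CONV_LAYERS = [
    (3, 1, 64),
    (3, 64, 128),
    (3, 128, 256), (3, 256, 256),
    (3, 256, 512), (3, 512, 512),
    (3, 512, 512), (3, 512, 512),
    (7, 512, 4096),   # conv2d_8: stands in for the first fully connected layer
    (1, 4096, 4096),  # conv2d_9
    (1, 4096, 10),    # conv2d_10: one output per digit class
]


def conv_params(kernel, in_ch, filters):
    """Weights plus one bias per filter for a square Conv2D kernel."""
    return (kernel * kernel * in_ch + 1) * filters


counts = [conv_params(*layer) for layer in CONV_LAYERS]
total = sum(counts)
print(f"total parameters: {total:,}")
print(f"share in the 7x7 conv: {counts[8] / total:.1%}")
```

Padding, pooling, dropout, flatten and activation layers have no weights, which is why they show 0 in the summary.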

Now set the loss function and optimizer, and train the model:

    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    # 10 epochs with mini-batches of 100 images; the test set is also used as validation data
    train_history = model.fit(x=x_Train4D, y=y_Train_One_Hot,
                              epochs=10, batch_size=100, verbose=2,
                              validation_data=(x_Test4D, y_Test_One_Hot))
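For reference, categorical cross-entropy compares each one-hot label y with the softmax output p as −Σ y·log(p), averaged over the batch, and with this loss the `accuracy` metric counts how often the arg-max of the prediction matches the label. The NumPy functions below are a simplified re-implementation for illustration only, not the Keras source; Keras additionally handles details such as sample weights.

```python
import numpy as np


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Batch mean of -sum(y_true * log(y_pred)) for one-hot y_true."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)   # avoid log(0)
    per_sample = -np.sum(y_true * np.log(y_pred), axis=-1)
    return per_sample.mean()


def categorical_accuracy(y_true, y_pred):
    """Fraction of samples whose highest-probability class is the true class."""
    return np.mean(np.argmax(y_true, axis=-1) == np.argmax(y_pred, axis=-1))


y_true = np.array([[0, 1, 0], [1, 0, 0]], dtype="float32")
y_pred = np.array([[0.1, 0.8, 0.1], [0.3, 0.6, 0.1]], dtype="float32")
print(categorical_crossentropy(y_true, y_pred))
print(categorical_accuracy(y_true, y_pred))
```

With batch_size=100, each epoch over the 60,000 training images performs 600 weight updates.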
Author: sigtem



