Inception Network on CIFAR-10 with TensorFlow 2.1: 87.88% Test-Set Accuracy, a Beginner's Hands-On Walkthrough


Environment
TensorFlow 2.1
A GPU is strongly recommended

Model
Inception

Training data
CIFAR-10 (CIFAR-100 can be used the same way)

Training-set accuracy: about 93%
Validation-set accuracy: about 88%
Test-set accuracy: 87.88%
Training time on a GPU: a little over an hour
Weight file size: 3.98 MB

How It Works

The salient part of an image can vary greatly in size. For example, an image of a dog can look like any of the cases below, with the dog occupying a different fraction of the frame in each one.

[Figure: from left to right, the dog occupies a progressively smaller area of the image (source: https://unsplash.com/).]

Because the relevant information can sit at such different locations and scales, choosing the right kernel size for the convolution is difficult: a larger kernel suits images where the information is distributed globally, while a smaller kernel suits images where it is concentrated locally.
Very deep networks are more prone to overfitting, and propagating gradient updates through the whole network is hard.
Simply stacking large convolutional layers is very expensive computationally.

The Inception module addresses this by running filters of several sizes (1x1, 3x3, 5x5) plus a pooling branch in parallel on the same input and concatenating their outputs, so the network grows wider rather than deeper. Deep neural networks are still computationally expensive, though. To keep the cost down, an extra 1x1 convolution is added before the 3x3 and 5x5 convolutions to limit the number of input channels. Although adding yet another convolution may seem counter-intuitive, a 1x1 convolution is far cheaper than a 5x5 one, and reducing the number of input channels lowers the compute cost further. Note, however, that on the pooling branch the 1x1 convolution comes after the max-pooling layer, not before it.
[Figure: Inception module with dimension reduction (1x1 convolutions before the 3x3 and 5x5 branches, and after the pooling branch).]
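To make the idea concrete, here is a minimal sketch of a single Inception module with dimension reduction, written against the same TensorFlow 2.x Keras API used in the full code below. The filter counts here are illustrative assumptions, not the exact values of the network trained in this post.

import tensorflow.keras as keras
import tensorflow.keras.layers as layers

def inception_block(x, channel_axis=3):
    # 1x1 branch: cheap, captures very local patterns
    branch1x1 = layers.Conv2D(16, (1, 1), padding='same', activation='relu')(x)
    # 1x1 reduction followed by 3x3: the 1x1 limits input channels before the costlier 3x3
    branch3x3 = layers.Conv2D(16, (1, 1), padding='same', activation='relu')(x)
    branch3x3 = layers.Conv2D(32, (3, 3), padding='same', activation='relu')(branch3x3)
    # 1x1 reduction followed by 5x5: same trick for the even costlier 5x5
    branch5x5 = layers.Conv2D(16, (1, 1), padding='same', activation='relu')(x)
    branch5x5 = layers.Conv2D(32, (5, 5), padding='same', activation='relu')(branch5x5)
    # pooling branch: max pooling first, then a 1x1 conv to control the channel count
    branch_pool = layers.MaxPooling2D((3, 3), strides=(1, 1), padding='same')(x)
    branch_pool = layers.Conv2D(16, (1, 1), padding='same', activation='relu')(branch_pool)
    # concatenate all branches along the channel axis: the module gets wider, not deeper
    return layers.concatenate([branch1x1, branch3x3, branch5x5, branch_pool], axis=channel_axis)

inputs = keras.Input(shape=(32, 32, 3))
outputs = inception_block(inputs)
keras.Model(inputs, outputs).summary()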

Complete Code

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import image_augument.image_augment as image_augment  # the author's local helper module (used only in predict_module)
import time
import tensorflow.keras.preprocessing.image as image
import matplotlib.pyplot as plt
import os


def conv2d_bn(x, filters, num_row, num_col, padding='same', strides=(1, 1), name=None):
    if name is not None:
        bn_name = name + '_bn'
        conv_name = name + '_conv'
    else:
        bn_name = None
        conv_name = None
    bn_axis = 3
    x = layers.Conv2D(
        filters, (num_row, num_col), strides=strides, padding=padding,
        use_bias=False, name=conv_name)(x)
    x = layers.BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x)
    x = layers.Activation('relu', name=name)(x)
    return x


def my_densenet():
    channel_axis = 3
    inputs = keras.Input(shape=(32, 32, 3), name='img')

    # mixed 0: 35 x 35 x 256
    x = inputs
    branch1x1 = conv2d_bn(x, 16, 1, 1)
    branch5x5 = conv2d_bn(x, 16, 5, 5)
    branch3x3dbl = conv2d_bn(x, 16, 3, 3)
    branch_pool = layers.AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
    x = layers.concatenate([branch1x1, branch5x5, branch3x3dbl, branch_pool],
                           axis=channel_axis, name='mixed0')

    # mixed 1: 35 x 35 x 288
    branch1x1 = conv2d_bn(x, 32, 1, 1)
    branch5x5 = conv2d_bn(x, 24, 1, 1)
    branch5x5 = conv2d_bn(branch5x5, 32, 5, 5)
    branch3x3dbl = conv2d_bn(x, 32, 1, 1)
    branch3x3dbl = conv2d_bn(branch3x3dbl, 54, 3, 3)
    branch_pool = layers.AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
    branch_pool = conv2d_bn(branch_pool, 32, 1, 1)
    x = layers.concatenate([branch1x1, branch5x5, branch3x3dbl, branch_pool],
                           axis=channel_axis, name='mixed1')

    # mixed 2: 35 x 35 x 288
    branch1x1 = conv2d_bn(x, 32, 1, 1)
    branch5x5 = conv2d_bn(x, 24, 1, 1)
    branch5x5 = conv2d_bn(branch5x5, 32, 5, 5)
    branch3x3dbl = conv2d_bn(x, 32, 1, 1)
    branch3x3dbl = conv2d_bn(branch3x3dbl, 64, 3, 3)
    branch3x3dbl = conv2d_bn(branch3x3dbl, 64, 3, 3)
    branch_pool = layers.AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
    branch_pool = conv2d_bn(branch_pool, 32, 1, 1)
    x = layers.concatenate([branch1x1, branch5x5, branch3x3dbl, branch_pool],
                           axis=channel_axis, name='mixed2')

    x = layers.MaxPooling2D(2, strides=2, name='1_pool')(x)

    # mixed 3: 17 x 17 x 768
    branch3x3 = conv2d_bn(x, 292, 3, 3, strides=(2, 2), padding='valid')
    branch3x3dbl = conv2d_bn(x, 32, 1, 1)
    branch3x3dbl = conv2d_bn(branch3x3dbl, 64, 3, 3)
    branch3x3dbl = conv2d_bn(branch3x3dbl, 64, 3, 3, strides=(2, 2), padding='valid')
    branch_pool = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)
    x = layers.concatenate([branch3x3, branch3x3dbl, branch_pool],
                           axis=channel_axis, name='mixed3')

    x = layers.MaxPooling2D(2, strides=2, name='2_pool')(x)

    # mixed 4: 17 x 17 x 768
    branch1x1 = conv2d_bn(x, 94, 1, 1)
    branch7x7 = conv2d_bn(x, 64, 1, 1)
    branch7x7 = conv2d_bn(branch7x7, 64, 1, 7)
    branch7x7 = conv2d_bn(branch7x7, 94, 7, 1)
    branch7x7dbl = conv2d_bn(x, 64, 1, 1)
    branch7x7dbl = conv2d_bn(branch7x7dbl, 64, 7, 1)
    branch7x7dbl = conv2d_bn(branch7x7dbl, 64, 1, 7)
    branch7x7dbl = conv2d_bn(branch7x7dbl, 64, 7, 1)
    branch7x7dbl = conv2d_bn(branch7x7dbl, 94, 1, 7)
    branch_pool = layers.AveragePooling2D((3, 3), strides=(1, 1), padding='same')(x)
    branch_pool = conv2d_bn(branch_pool, 94, 1, 1)
    x = layers.concatenate([branch1x1, branch7x7, branch7x7dbl, branch_pool],
                           axis=channel_axis, name='mixed4')

    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dense(10, activation='softmax', name='fc')(x)

    model = keras.Model(inputs, x, name='my_ResNet101')
    return model


def my_model():
    # denseNet = keras.applications.DenseNet121(input_shape=(32,32,3), include_top=True, weights=None, classes=10)
    denseNet = my_densenet()
    denseNet.compile(optimizer=keras.optimizers.Adam(),
                     loss=keras.losses.SparseCategoricalCrossentropy(),
                     # metrics=['accuracy'])
                     metrics=[keras.metrics.SparseCategoricalAccuracy()])
    denseNet.summary()
    # keras.utils.plot_model(denseNet, 'my_ResNet101.png', show_shapes=True)
    return denseNet


current_max_loss = 9999

weight_file = './weights7_2/model.h5'
log_file = 'logs7_2'


def train_my_model(deep_model):
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

    train_datagen = image.ImageDataGenerator(
        rescale=1 / 255,
        rotation_range=40,       # degrees (0-180): range for random image rotation
        width_shift_range=0.2,   # fraction of width for random horizontal shift (same idea below)
        height_shift_range=0.2,
        shear_range=0.2,         # shear angle for random shear transforms
        zoom_range=0.2,          # range for random zoom
        horizontal_flip=True,    # randomly flip half of the images horizontally
        fill_mode='nearest'      # how newly created pixels are filled
    )

    test_datagen = image.ImageDataGenerator(rescale=1 / 255)
    validation_datagen = image.ImageDataGenerator(rescale=1 / 255)

    train_generator = train_datagen.flow(x_train[:45000], y_train[:45000], batch_size=128)
    # train_generator = train_datagen.flow(x_train, y_train, batch_size=128)
    validation_generator = validation_datagen.flow(x_train[45000:], y_train[45000:], batch_size=128)
    test_generator = test_datagen.flow(x_test, y_test, batch_size=128)

    begin_time = time.time()

    if os.path.isfile(weight_file):
        print('load weight')
        deep_model.load_weights(weight_file)

    def save_weight(epoch, logs):
        global current_max_loss
        if logs['val_loss'] is not None and logs['val_loss'] < current_max_loss:
            current_max_loss = logs['val_loss']
            print('save_weight', epoch, current_max_loss)
            deep_model.save_weights(weight_file)

    batch_print_callback = keras.callbacks.LambdaCallback(
        on_epoch_end=save_weight
    )
    callbacks = [
        tf.keras.callbacks.EarlyStopping(patience=4, monitor='loss'),
        batch_print_callback,
        # keras.callbacks.ModelCheckpoint('./weights/model.h5', save_best_only=True),
        tf.keras.callbacks.TensorBoard(log_dir=log_file)
    ]

    print(train_generator[0][0].shape)

    history = deep_model.fit_generator(train_generator, steps_per_epoch=351, epochs=200,
                                       callbacks=callbacks,
                                       validation_data=validation_generator,
                                       validation_steps=39,
                                       initial_epoch=0)

    result = deep_model.evaluate_generator(test_generator, verbose=2)
    print(result)
    print('time', time.time() - begin_time)

    def show_result(history):
        plt.plot(history.history['loss'])
        plt.plot(history.history['val_loss'])
        plt.plot(history.history['sparse_categorical_accuracy'])
        plt.plot(history.history['val_sparse_categorical_accuracy'])
        plt.legend(['loss', 'val_loss', 'sparse_categorical_accuracy',
                    'val_sparse_categorical_accuracy'], loc='upper left')
        plt.show()
        print(history)

    show_result(history)


def test_module(deep_model):
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    test_datagen = image.ImageDataGenerator(rescale=1 / 255)
    test_generator = test_datagen.flow(x_test, y_test, batch_size=128)
    begin_time = time.time()
    if os.path.isfile(weight_file):
        print('load weight')
        deep_model.load_weights(weight_file)
    result = deep_model.evaluate_generator(test_generator, verbose=2)
    print(result)
    print('time', time.time() - begin_time)


def predict_module(deep_model):
    x_train, y_train, x_test, y_test = image_augment.get_all_train_data(False)
    import numpy as np
    if os.path.isfile(weight_file):
        print('load weight')
        deep_model.load_weights(weight_file)
    print(y_test[0:20])
    for i in range(20):
        img = x_test[i][np.newaxis, :] / 255
        y_ = deep_model.predict(img)
        v = np.argmax(y_)
        print(v, y_test[i])


if __name__ == '__main__':
    # my_densenet()
    deep_model = my_model()
    # train_my_model(deep_model)
    test_module(deep_model)
    # predict_module(deep_model)
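As a quick usage note: `predict_module` above relies on the author's local `image_augument` helper, which readers will not have. A minimal sketch of the same idea that loads the CIFAR-10 test set directly from `tf.keras.datasets` might look like the following; it assumes the model was built with `my_model()` and that the weights were already saved to `./weights7_2/model.h5` during training.

import numpy as np
import tensorflow as tf

# Rebuild the network and load the weights saved by train_my_model().
deep_model = my_model()
deep_model.load_weights('./weights7_2/model.h5')

(_, _), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

for i in range(5):
    img = x_test[i][np.newaxis, :] / 255   # same rescaling as the generators above
    pred = np.argmax(deep_model.predict(img))
    print('predicted:', class_names[pred], ' actual:', class_names[int(y_test[i])])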

Results on the Test Set

79/79 - 4s - loss: 0.4228 - sparse_categorical_accuracy: 0.8788
[0.4228407618931577, 0.8788]
time 3.7423994541168213

Parameter Counts

Total params: 986,920
Trainable params: 983,428
Non-trainable params: 3,492

Reference
https://baijiahao.baidu.com/s?id=1601882944953788623&wfr=spider&for=pc


Author: 茫茫人海一粒沙



