keras_ LSTM 层和 GRU 层

Hedva ·
更新时间:2024-11-13
· 838 次阅读

6.2.2 理解 LSTM 层和 GRU 层

参考:
https://blog.csdn.net/qq_30614345/article/details/98714874
6.2.4 小结
现在你已经学会了以下内容。
‰ 循环神经网络(RNN)的概念及其工作原理。
‰ 长短期记忆(LSTM)是什么,为什么它在长序列上的效果要好于普通 RNN。
‰ 如何使用 Keras 的 RNN 层来处理序列数据。
接下来,我们将介绍 RNN 几个更高级的功能,这可以帮你有效利用深度学习序列模型
6.2.

2 理解 LSTM 层和 GRU 层
SimpleRNN 并不是 Keras 中唯一可用的循环层,还有另外两个: LSTM 和 GRU。在实践中总会用到其中之一,因为 SimpleRNN 通常过于简化,没有实用价值。 SimpleRNN 的最大问题是,在时刻 t,理论上来说,它应该能够记住许多时间步之前见过的信息,但实际上它是不可能学
到这种长期依赖的。其原因在于梯度消失问题(vanishing gradient problem),这一效应类似于在层数较多的非循环网络(即前馈网络)中观察到的效应:随着层数的增加,网络最终变得无法训练。 Hochreiter、 Schmidhuber 和 Bengio 在 20 世纪 90 年代初研究了这一效应的理论原因 a。LSTM 层和 GRU 层都是为了解决这个问题而设计的

6.2.3 Keras 中一个 LSTM 的具体例子
代码清单 6-27 使用 Keras 中的 LSTM 层

A concrete LSTM example in Keras Now let's switch to more practical concerns: we will set up a model using a LSTM layer and train it on the IMDB data. Here's the network, similar to the one with SimpleRNN that we just presented. We only specify the output dimensionality of the LSTM layer, and leave every other argument (there are lots) to the Keras defaults. Keras has good defaults, and things will almost always "just work" without you having to spend time tuning parameters by hand. 6.2.3 Keras 中一个 LSTM 的具体例子 # 现在我们来看一个更实际的问题:使用 LSTM 层来创建一个模型,然后在 IMDB 数据上 # 训练模型(见图 6-16 和图 6-17)。 这个网络与前面介绍的 SimpleRNN 网络类似。你只需指定 # LSTM 层的输出维度,其他所有参数(有很多)都使用 Keras 默认值。 Keras 具有很好的默认值, # 无须手动调参,模型通常也能正常运行。 # 代码清单 6-27 使用 Keras 中的 LSTM 层 # 6.2.3 Keras 中一个 LSTM 的具体例子 # 现在我们来看一个更实际的问题:使用 LSTM 层来创建一个模型,然后在 IMDB 数据上 # 训练模型(见图 6-16 和图 6-17)。 这个网络与前面介绍的 SimpleRNN 网络类似。你只需指定 # LSTM 层的输出维度,其他所有参数(有很多)都使用 Keras 默认值。 Keras 具有很好的默认值, # 无须手动调参,模型通常也能正常运行。 # 代码清单 6-27 使用 Keras 中的 LSTM 层 ​ from keras.layers import LSTM ​ model = Sequential() model.add(Embedding(max_features, 32)) model.add(LSTM(32)) model.add(Dense(1, activation='sigmoid')) ​ model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc']) history = model.fit(input_train, y_train, epochs=10, batch_size=128, validation_split=0.2) Train on 20000 samples, validate on 5000 samples Epoch 1/10 20000/20000 [==============================] - 108s - loss: 0.5038 - acc: 0.7574 - val_loss: 0.3853 - val_acc: 0.8346 Epoch 2/10 20000/20000 [==============================] - 108s - loss: 0.2917 - acc: 0.8866 - val_loss: 0.3020 - val_acc: 0.8794 Epoch 3/10 20000/20000 [==============================] - 107s - loss: 0.2305 - acc: 0.9105 - val_loss: 0.3125 - val_acc: 0.8688 Epoch 4/10 20000/20000 [==============================] - 107s - loss: 0.2033 - acc: 0.9261 - val_loss: 0.4013 - val_acc: 0.8574 Epoch 5/10 20000/20000 [==============================] - 107s - loss: 0.1749 - acc: 0.9385 - val_loss: 0.3273 - val_acc: 0.8912 Epoch 6/10 20000/20000 [==============================] - 107s - loss: 0.1543 - acc: 0.9457 - val_loss: 0.3505 - val_acc: 0.8774 Epoch 7/10 20000/20000 [==============================] - 107s - loss: 0.1417 - acc: 0.9493 - val_loss: 0.4485 - val_acc: 0.8396 Epoch 8/10 20000/20000 [==============================] - 106s - loss: 0.1331 - acc: 0.9522 - val_loss: 0.3242 - val_acc: 0.8928 Epoch 9/10 20000/20000 [==============================] - 106s - loss: 0.1147 - acc: 0.9618 - val_loss: 0.4216 - val_acc: 0.8746 Epoch 10/10 20000/20000 [==============================] - 106s - loss: 0.1092 - acc: 0.9628 - val_loss: 0.3972 - val_acc: 0.8758 这一次,验证精度达到了 89%。还不错,肯定比 SimpleRNN 网络好多了,这主要是因为 # LSTM 受梯度消失问题的影响要小得多。这个结果也比第 3 章的全连接网络略好,虽然使用的172  第 6 章 深度学习用于文本和序列 # 数据量比第 3 章要少。此处在 500 个时间步之后将序列截断,而在第 3 章是读取整个序列。 # 但对于一种计算量如此之大的方法而言,这个结果也说不上是突破性的。为什么 LSTM 不 # 能表现得更好?一个原因是你没有花力气来调节超参数,比如嵌入维度或 LSTM 输出维度。另 # 一个原因可能是缺少正则化。但说实话,主要原因在于,适用于评论分析全局的长期性结构(这 # 正是 LSTM 所擅长的),对情感分析问题帮助不大。对于这样的基本问题,观察每条评论中出现 # 了哪些词及其出现频率就可以很好地解决。这也正是第一个全连接方法的做法。但还有更加困 # 难的自然语言处理问题,特别是问答和机器翻译,这时 LSTM 的优势就明显了。 acc = history.history['acc'] val_acc = history.history['val_acc'] loss = history.history['loss'] val_loss = history.history['val_loss'] ​ epochs = range(len(acc)) ​ plt.plot(epochs, acc, 'bo', label='Training acc') plt.plot(epochs, val_acc, 'b', label='Validation acc') plt.title('Training and validation accuracy') plt.legend() ​ plt.figure() ​ plt.plot(epochs, loss, 'bo', label='Training loss') plt.plot(epochs, val_loss, 'b', label='Validation loss') plt.title('Training and validation loss') plt.legend() ​ plt.show() ​ # 这一次,验证精度达到了 89%。还不错,肯定比 SimpleRNN 网络好多了,这主要是因为 # LSTM 受梯度消失问题的影响要小得多。这个结果也比第 3 章的全连接网络略好,虽然使用的172  第 6 章 深度学习用于文本和序列 # 数据量比第 3 章要少。此处在 500 个时间步之后将序列截断,而在第 3 章是读取整个序列。 # 但对于一种计算量如此之大的方法而言,这个结果也说不上是突破性的。为什么 LSTM 不 # 能表现得更好?一个原因是你没有花力气来调节超参数,比如嵌入维度或 LSTM 输出维度。另 # 一个原因可能是缺少正则化。但说实话,主要原因在于,适用于评论分析全局的长期性结构(这 # 正是 LSTM 所擅长的),对情感分析问题帮助不大。对于这样的基本问题,观察每条评论中出现 # 了哪些词及其出现频率就可以很好地解决。这也正是第一个全连接方法的做法。但还有更加困 # 难的自然语言处理问题,特别是问答和机器翻译,这时 LSTM 的优势就明显了。

在这里插入图片描述


作者:御剑归一



keras lstm

需要 登录 后方可回复, 如果你还没有账号请 注册新账号
相关文章