动手实现LSTM代码的时候,因为忘记了各种门的原理,总感觉磕磕绊绊的。在这里重新巩固一下。
普通RNN无法学习到距离当前位置较远的信息,LSTM通过修改每个隐藏层内部的结构,来解决长距离依赖问题,实现对“记忆”的控制。如下图所示:
这种特殊结构的主要思想是:将信息存储在一个个记忆细胞中,不同隐藏层的记忆细胞之间通过少量线性交互形成一条传送带(图中红线),实现信息的流动。同时引入一种“门”的结构,用来新增或删除记忆细胞中的信息,控制信息的流动。
这种结构的构成如下:
三个输入:
当前位置的特征x
,前一隐藏层的状态 a
,前一个记忆细胞c
三个输出:
当前隐藏层状态a
,当前记忆细胞c
,当前位置预测y
(一些任务中只需前两个输出)
三种门:
遗忘门:控制对历史信息c
的遗忘程度。
更新门:控制新增到当前记忆细胞中的信息。
输出门:控制记忆细胞c
中的哪些信息需要作为输出。
三种门都会输出一个介于0-1之间的数字,表示让信息通过的比例。0表示不让任何信息通过,1表示让全部信息通过。
其他:
候选值c~
:新增到当前记忆细胞c
中的候选信息
整体工作如下:
输入x
,a
和c
;
对x
和a
进行线性变换和sigmoid操作,分别得到遗忘门的值f
、更新门的值i
和输出门的值o
;
对x
和a
进行线性变换和tanh操作,得到候选值c~
;
将f
和c
进行逐元素相乘,得到保留下来的历史信息;将i
和c~
进行逐元素相乘,得到新增到记忆细胞中的信息;将两者相加得到新的记忆细胞c
;
对c
进行tanh操作,并与o
进行逐元素相乘,得到过滤后的信息a
用来输出;
如果在当前位置还需要输出预测值,则对a
进行softmax操作,输出预测值y
。
具体计算公式如图: