Home Blogs Readings Notes Jupyter Seafile

理解并推导LSTM

03 April 2018

Table of Content

===========

LSTM, Long Short Memory Networks.其实质是RNN的变种.其关键点是对隐含层之间状态迁移计算的精细化.

要理解LSTM首先必须得非常清楚的理解了RNN才能更好的继续,因为LSTM的诞生正是为了解决 RNN的局限性,梯度消失(或者爆炸).那么LSTM对这个问题是怎么解决的?

首先,RNN的局限的根源是对任何历史信息和当前数据信息都不加区分的直接传递到后面的序列节点中。所以LSTM 的解决方式是对历史信息和输入信息的的传递加上一个权重开关,来控制信息的传递过程。

其次,RNN的计算过程中其实是没有用到上次计算的输出的。而LSTM则综合考虑历史的输出和当前的输入作为状态 迁移的一部分信息进行处理。

最后,LSTM创造性的将模型结构概念和人的直觉认知更进一步的深化结合。人会遗忘也会长期 的实践过程反复强化一些有用的信息。刚发生的事情肯定印象比较深。那么这两个方面作用的过程 最终形成了人的经验。类比于LSTM结构,加上遗忘门(forget gate)给历史信息和当前信息一个 衰减的过程。在根据上一步状态和输入计算当前状态时,再综合考虑进上一步的输出。最终当前状态的信息来源 包含三个方面:

前向计算

门函数

这个合并计算的地方其实并不是啥新东西,在RNN计算的时候也可以,不过因为 RNN就一个类似公式,而这实际上有4个。这样写好处就是方便。

状态更新及计算输出

损失函数

反向求导

需要特殊注意的是之所以

其原因是用了截断的BPTT,只考虑前一步,所以不会有递归。另外其实际计算时

则是因为的梯度贡献不仅来自于当前步骤的计算还有前一步的结果。

小结

在lstm的整个推导过程中,耗时最久的是对cell state的求导计算。前向过程在colah的博客中解释的 很清楚。但是关于截断的BPTT和RTRL(real time recurrent learning)这两个在lstm求导中的应用 却不得不翻论文,结果却是翻了好久也没有几个人能真正把这个以比较好理解的方式明确说出来。所以结果就是 我知道这样做,但是却并不知道为什么。留下了疑点。现在虽然有所认识,但依然不是很明确。尽管如此还是 实现了一个100行不到的lstm。还是值得欣慰的。

Refs

-->
Fork me on GitHub