RNN通过时间反向传播
循环神经网络的梯度分析
循环神经网络(RNN)的前向传播满足: \[ \begin{aligned} &\mathbf{h}_t = \phi(\mathbf{x}_t\mathbf{W}_{xh} + \mathbf{h}_{t-1}\mathbf{W}_{hh} + \mathbf{b}_h)\\ &\mathbf{o}_t = \mathbf{h}_t\mathbf{W}_{o} + \mathbf{b}_o \end{aligned} \] 即,\(\mathbf{h}_t\)是关于\(\mathbf{x}_t, \mathbf{h}_{t-1}\)的函数,令 \[ \mathbf{W}_h = \left[ \begin{matrix} \mathbf{W}_{xh}\\ \mathbf{W}_{hh}\\ \mathbf{b}_h \end{matrix} \right] \] 则, \[ \mathbf{h}_t = [ \begin{matrix} \mathbf{x}_t&\mathbf{h}_{t-1}&\mathbf{1} \end{matrix}]\cdot \mathbf{W}_h \]
可以写成: \[ \begin{aligned} \mathbf{h}_t = f(\mathbf{x}_t, \mathbf{h}_{t-1}, \mathbf{W}_h)\\ \mathbf{o}_t = g(\mathbf{h}_t, \mathbf{W}_o) \end{aligned} \] 其中\(f\)和\(g\)分别是隐藏层和输出层的变换。
损失函数: \[ L(\mathbf{x}_1, \cdots, \mathbf{x}_T, \mathbf{y}_1, \cdots, \mathbf{y}_T, \mathbf{W}_h, \mathbf{W}_o) = \frac{1}{T}\sum_{t=1}^T l(\mathbf{y}_t, \mathbf{o}_t) \] 对于反向传播: \[ \begin{aligned} \frac{\partial L}{\partial \mathbf{W}_h} &= \frac{1}{T}\sum_{t=1}^T\frac{\partial l(\mathbf{y_t}, \mathbf{o}_t)}{\partial \mathbf{W}_h}\\ &= \frac{1}{T}\sum_{t=1}^T\frac{\partial l(\mathbf{y}_t, \mathbf{o}_t)}{\partial \mathbf{o}_t}\frac{\partial g(\mathbf{h}_t, \mathbf{W}_o)}{\partial \mathbf{h}_t}\frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_h} \end{aligned} \] 该式子中的第一项和第二项很容易计算,而第三项\(\partial \mathbf{h}_t / \partial \mathbf{W}_t\)的计算则很麻烦,因为\(\mathbf{h}_t = f(\mathbf{x}_t, \mathbf{h}_{t-1}, \mathbf{W}_h)\),而\(\mathbf{h}_{t-1}\)又与\(\mathbf{W}_h\)和\(\mathbf{h}_{t-2}\)有关,依此类推,我们需要循环地计算参数\(\mathbf{W}_h\)对\(\mathbf{h}_t\)的影响,使用链式法则: \[ \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_h} = \frac{\partial f(\mathbf{x}_t, \mathbf{h}_{t-1}, \mathbf{W}_h)}{\partial \mathbf{W}_h} + \frac{\partial f(\mathbf{x}_t, \mathbf{h}_{t-1}, \mathbf{W}_h)}{\partial \mathbf{h}_{t-1}}\frac{\partial \mathbf{h}_{t-1}}{\partial \mathbf{W}_h} + \frac{\partial f(\mathbf{x}_t, \mathbf{h}_{t-1}, \mathbf{W}_h)}{\partial \mathbf{h}_{t-1}}\frac{\partial f(\mathbf{x}_{t-1}, \mathbf{h}_{t-2}, \mathbf{W}_h)}{\partial \mathbf{h}_{t-2}}\frac{\partial \mathbf{h}_{t-2}}{\partial \mathbf{W}_h} + \cdots\cdots \] 该式子中第\(i+1\)项可以写为: \[ \begin{aligned} &\frac{\partial f(\mathbf{x}_t, \mathbf{h}_{t-1}, \mathbf{W}_h)}{\partial \mathbf{h}_{t-1}}\times \cdots \times \frac{f(\mathbf{x}_{t-i+1}, \mathbf{h}_{t-i}, \mathbf{W}_h)}{\partial \mathbf{W}_h}\\ &= \Bigg(\prod_{j=i+1}^{t}\frac{\partial f(\mathbf{x}_j, \mathbf{h}_{j-1}, \mathbf{W}_h)}{\partial \mathbf{h}_{j-1}}\Bigg)\frac{\partial f(\mathbf{x}_{t-i+1}, \mathbf{h}_{t-i}, \mathbf{W}_h)}{\partial \mathbf{W}_h} \end{aligned} \] 则 \[ \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_h} = \frac{\partial f(\mathbf{x}_t, \mathbf{h}_{t-1}, \mathbf{W}_h)}{\partial \mathbf{W}_h} + \sum_{i=1}^{t-1}\Bigg(\prod_{j=i+1}^{t}\frac{\partial f(\mathbf{x}_j, \mathbf{h}_{j-1}, \mathbf{W}_h)}{\partial \mathbf{h}_{j-1}}\Bigg)\frac{\partial f(\mathbf{x}_{t-i+1}, \mathbf{h}_{t-i}, \mathbf{W}_h)}{\partial \mathbf{W}_h} \] 虽然我们可以用链式法则递归地计算\(\partial \mathbf{h}_t / \partial \mathbf{W}_t\),但当\(t\)很大时这个链就会变得很长,计算量很大。
循环神经网络的反向传播
完全计算
对于一个很长的序列,如果我们使用完整的计算图进行反向传播更新参数,这样的计算会非常缓慢,并且可能会发生梯度爆炸,因为初始条件的微小变化就会导致结果发生不成比例的变化。这对于我们想要估计的模型而言是非常不可取的。因此,在实践中,这种方法几乎从未使用。
截断时间步
截断时间步(Truncated Backpropagation Through Time,TBPTT),我们可以在\(\tau\)步后截断\(\partial \mathbf{h}_t / \partial \mathbf{W}_t\)式子中的求和计算: \[ \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_h} = \frac{\partial f(\mathbf{x}_t, \mathbf{h}_{t-1}, \mathbf{W}_h)}{\partial \mathbf{W}_h} + \sum_{i=1}^{\tau}\Bigg(\prod_{j=i+1}^{t}\frac{\partial f(\mathbf{x}_j, \mathbf{h}_{j-1}, \mathbf{W}_h)}{\partial \mathbf{h}_{j-1}}\Bigg)\frac{\partial f(\mathbf{x}_{t-i+1}, \mathbf{h}_{t-i}, \mathbf{W}_h)}{\partial \mathbf{W}_h} \] 即将求和终止为\(\partial \mathbf{h}_{t-\tau}/\partial \mathbf{W}_h\),在实践中,这种方式工作得很好,它通常被称为截断的通过时间反向传播。这样做导致该模型主要侧重于短期影响,而不是长期影响。这在现实中是可取的,因为它会将估值偏向更简单和更稳定的模型。
截断时间步的实现:
与直接将长序列分成很多个固定短序列不同,截断时间步会按顺序分成短序列,上一个短序列输入模型得到输出和最终隐藏状态后,会将隐藏状态传递给下一个短序列模型输入的隐藏状态,但这个隐藏状态不会在两个短序列之间保留梯度。
1 |
|
随机截断
随机截断时间步(Stochastic Truncated Backpropagation Through Time, STBPTT),在传统截断时间步反向传播(TBPTT)的基础上引入随机性,以增强模型的泛化能力和训练效率。
比较策略
第一行采用随机截断,方法是将文本划分为不同长度的片段。
第二行采用常规截断,方法是将文本分解为相同长度的子序列。
第三行采用通过时间的完全反向传播,结果是产生在计算上不可行的表达式。
虽然随截断在理论上具有吸引力,但很可能是由于多种因素在实践中并不比常规截断更好。首先,在对过去若干个时间步经过反向传播后,观测结果足以捕获实际的依赖关系。其次,增加的方差抵消了时间步数越多梯度越精确的事实。第三,我们真正想要的是只有短范围交互的模型。因此,模型需要的正是截断的通过时间反向传播方法所具备的轻度正则化效果。