9.9 KiB
循环神经网络RNN
循环神经网络(Recurrent Neural Network,简称RNN)是一类专门用于处理序列数据的神经网络模型。与传统的前馈神经网络不同,RNN具有“记忆”功能,能够捕捉数据序列中的时间依赖关系。
基本结构
RNN的核心在于它的循环结构,这个结构使得信息可以沿着时间步流动。一个典型的RNN单元在时间步 t
接收输入向量 x_t
和前一时刻的隐藏状态 $h_{t-1}$,然后计算当前时刻的隐藏状态 $h_t$。这种循环过程允许模型利用之前的状态信息来影响当前的预测。
隐藏状态的更新
隐藏状态更新通常通过如下公式实现:
h_t = f(W_{xh} \cdot x_t + W_{hh} \cdot h_{t-1} + b_h)
其中:
h_t
表示时间步t
的隐藏状态(所有隐藏层神经元激活值的集合。)。x_t
是时间步t
的输入向量。W_{xh}
是输入到隐藏状态的权重矩阵。W_{hh}
是隐藏状态之间的递归连接权重矩阵。b_h
是偏置项。f
是激活函数,通常会选择非线性函数如tanh或ReLU,以引入非线性变换。
在这种更新过程中,当前的隐藏状态 h_t
同时依赖于当前的输入 x_t
和之前的隐藏状态 $h_{t-1}$,这使得RNN能够捕捉长时间序列中的上下文关系。
输出层
有时RNN还会在每个时间步产生输出,输出计算方式通常为:
y_t = g(W_{hy} \cdot h_t + b_y)
其中:
y_t
是时间步t
的输出。W_{hy}
是隐藏状态到输出的权重矩阵。b_y
是输出层的偏置项。g
是输出层激活函数(例如softmax用于分类任务)。
困惑度
假设我们有一个测试序列,其中包含 3 个单词,模型对每个单词的预测概率分别为:
P(w_1) = 0.5
P(w_2|w_1) = 0.2
P(w_3|w_1, w_2) = 0.1
根据困惑度的公式:
\text{Perplexity} = \exp\left(-\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | \text{context})\right)
当模型对每个单词都能百分之百预测(即概率为1),则平均交叉熵为0,困惑度为 $\exp(0)=1$。这表示模型没有任何不确定性,是理想状态。
我们这里 $N=3$。下面是具体的计算步骤:
- 计算每个单词的对数概率
\log P(w_1) = \log(0.5) \approx -0.6931
\log P(w_2|w_1) = \log(0.2) \approx -1.6094
\log P(w_3|w_1, w_2) = \log(0.1) \approx -2.3026
- 求和并求平均
将这些对数值相加:
\sum_{i=1}^{3} \log P(w_i|\text{context}) = -0.6931 - 1.6094 - 2.3026 \approx -4.6051
然后求平均:
\text{平均对数概率} = \frac{-4.6051}{3} \approx -1.5350
- 计算困惑度
取负值再求指数:
\text{Perplexity} = \exp\left(1.5350\right) \approx 4.64
训练过程与挑战
整体训练流程可以总结为下面几个步骤,每个 epoch 都会重复这些步骤:
- 前向传播
- 对于一个完整的句子(或者一个批次中的多个句子),模型按顺序处理所有时间步,生成每个时间步的输出。
- 比如,对于句子“我 爱 编程”,模型会依次处理“我”、“爱”、“编程”,得到对应的输出(例如每个时间步预测下一个词的概率分布)。
- 计算损失
- 将模型在所有时间步的输出与真实目标序列(也就是每个时间步的正确答案)进行比较,计算整体损失。
- 损失通常是所有时间步损失的总和或平均值,例如均方误差或交叉熵损失。
- 反向传播(BPTT)
- 对整个句子进行反向传播,即通过时间(Back Propagation Through Time,BPTT)计算所有时间步的梯度。
- 这一步会利用链式法则,把整个序列中各个时间步的梯度累积起来,形成每个参数的总梯度。
- 参数更新
- 使用优化器(如 Adam、SGD 等)根据计算得到的梯度更新模型参数。
- 重复整个过程
- 以上步骤构成了一个训练迭代周期(一个 epoch),在一个 epoch 中,所有训练样本都会被送入模型进行训练。
- 然后在下一个 epoch 中,再次重复整个流程,直到达到预设的 epoch 数或满足其他停止条件。
在训练过程中,RNN通过反向传播算法(具体为“反向传播通过时间”(BPTT))来更新参数。然而,由于梯度在长序列上传播时可能出现梯度消失或梯度爆炸问题,使得RNN在捕捉长程依赖关系时面临挑战。为此,后来发展出了如长短时记忆网络(LSTM)和门控循环单元(GRU)等改进模型,它们在结构上增加了门控机制,有效缓解了这一问题。
门控循环单元GRU
GRU(Gated Recurrent Unit,门控循环单元)是一种常用的循环神经网络变种,旨在解决标准 RNN 中梯度消失或梯度爆炸的问题,同时比 LSTM 结构更简单。
基本结构
GRU 通过两个门(gate)来控制信息的流动:
- 更新门 $z_t$:
控制当前隐藏状态需要保留多少来自过去的信息以及引入多少新的信息。 - 重置门 $r_t$:
决定如何结合新输入和过去的记忆,尤其是在产生候选隐藏状态时。
另外,GRU 计算一个候选隐藏状态 $\tilde{h}_t$,并结合更新门 z_t
的信息,更新最终的隐藏状态 $h_t$。
隐藏状态更新公式
对于每个时间步 $t$,GRU 的计算过程通常包括以下步骤:
-
更新门 $z_t$
z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z)
其中:
x_t
是当前时间步的输入;h_{t-1}
是上一时刻的隐藏状态;W_z
和U_z
是权重矩阵;b_z
是偏置向量;\sigma(\cdot)
是 sigmoid 函数,用于将输出限制在[0, 1]
区间。
-
重置门 $r_t$
r_t = \sigma(W_{xr} x_t + W_{hr} h_{t-1} + b_r)
其中参数意义与更新门类似,重置门决定忘记多少过去的信息。
-
候选隐藏状态 $\tilde{h}_t$
\tilde{h}_t = \tanh(W_{xh} x_t + W_{hh} (r_t \odot h_{t-1}) + b_h)
这里:
r_t \odot h_{t-1}
表示重置门r_t
和上一时刻隐藏状态的逐元素相乘(Hadamard 乘积),用以调制历史信息的影响;\tanh(\cdot)
用来生成候选隐藏状态,将输出限制在 $[-1, 1]$。
-
最终隐藏状态 $h_t$
GRU 结合更新门和候选隐藏状态更新最终隐藏状态:h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t.
这表明更新门
z_t
决定了新信息\tilde{h}_t
与旧信息h_{t-1}
的比例。
公式
GRU 更新公式如下:
\begin{aligned}
z_t &= \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z), \\
r_t &= \sigma(W_{xr} x_t + W_{hr} h_{t-1} + b_r), \\
\tilde{h}_t &= \tanh(W_{xh} x_t + W_{hh}(r_t \odot h_{t-1}) + b_h), \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t.
\end{aligned}
长短时记忆网络LSTM
LSTM 是一种常用的循环神经网络变种,专门为解决标准 RNN 中的梯度消失问题而设计。它通过引入额外的“记忆单元”和多个门控机制,有效地控制信息的保存、遗忘和输出,从而捕捉长距离的依赖关系。
基本结构
LSTM 的核心在于其“细胞状态”(cell state),这是一个贯穿整个序列传递的信息流,同时有三个主要的门(gate)来控制细胞状态的更新过程:
-
遗忘门 $f_t$
决定当前时间步需要遗忘多少之前的记忆信息。 -
输入门 $i_t$
决定当前时间步有多少新的信息写入细胞状态。 -
输出门 $o_t$
决定当前时间步从细胞状态中输出多少信息作为隐藏状态。
此外,还引入了一个候选细胞状态 \tilde{c}_t
用于更新细胞状态。
隐藏状态更新公式
对于每个时间步 $t$,LSTM 的更新过程通常可以写为以下公式(所有权重矩阵用 W
和 U
表示,各门的偏置为 $b$):
\begin{aligned}
\textbf{遗忘门:}\quad f_t &= \sigma\Big(W_{xf}\, x_t + W_{hf}\, h_{t-1} + b_f\Big), \\
\textbf{输入门:}\quad i_t &= \sigma\Big(W_{xi}\, x_t + W_{hi}\, h_{t-1} + b_i\Big), \\
\textbf{输出门:}\quad o_t &= \sigma\Big(W_{xo}\, x_t + W_{ho}\, h_{t-1} + b_o\Big), \\
\\
\textbf{候选细胞状态:}\quad \tilde{c}_t &= \tanh\Big(W_{xc}\, x_t + W_{hc}\, h_{t-1} + b_c\Big), \\
\textbf{细胞状态更新:}\quad c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, \\
\textbf{隐藏状态:}\quad h_t &= o_t \odot \tanh(c_t).
\end{aligned}
直观理解
- 细胞状态 $c_t$:
细胞状态是贯穿整个序列的“记忆通道”,负责长期保存信息。它像一条传送带,在不同时间步中线性传递,避免信息被频繁修改,从而维持长期记忆。 - 遗忘门 $f_t$:
用于丢弃上一时刻不再需要的信息。如果遗忘门输出接近 0,说明遗忘了大部分过去的信息;如果接近 1,则保留大部分信息。 类比:若模型遇到新段落,遗忘门可能关闭(输出接近0),丢弃前一段的无关信息;若需要延续上下文(如故事主线),则保持开启(输出接近1)。 - 输入门
i_t
和候选细胞状态 $\tilde{c}_t$:
输入门控制有多少候选信息被写入细胞状态。候选细胞状态是基于当前输入和上一时刻隐藏状态生成的新信息。 类比:阅读时遇到关键情节,输入门打开,将新信息写入长期记忆(如角色关系),同时候选状态 $\tilde{c}_t$提供新信息的候选内容。 - 输出门 $o_t$:
控制从细胞状态中输出多少信息作为当前时间步的隐藏状态。隐藏状态h_t
通常用于后续计算(例如,生成输出、参与下一时刻计算)。 类比:根据当前任务(如预测下一个词),输出门决定暴露细胞状态的哪部分(如只关注时间、地点等关键信息)。