md_files/科研/循环神经网络.md
2025-03-18 12:46:59 +08:00

9.9 KiB
Raw Blame History

循环神经网络RNN

循环神经网络Recurrent Neural Network简称RNN是一类专门用于处理序列数据的神经网络模型。与传统的前馈神经网络不同RNN具有“记忆”功能能够捕捉数据序列中的时间依赖关系。

基本结构

RNN的核心在于它的循环结构这个结构使得信息可以沿着时间步流动。一个典型的RNN单元在时间步 t 接收输入向量 x_t 和前一时刻的隐藏状态 $h_{t-1}$,然后计算当前时刻的隐藏状态 $h_t$。这种循环过程允许模型利用之前的状态信息来影响当前的预测。

隐藏状态的更新

隐藏状态更新通常通过如下公式实现:


h_t = f(W_{xh} \cdot x_t + W_{hh} \cdot h_{t-1} + b_h)

其中:

  • h_t 表示时间步 t 的隐藏状态(所有隐藏层神经元激活值的集合。)。
  • x_t 是时间步 t 的输入向量。
  • W_{xh} 是输入到隐藏状态的权重矩阵。
  • W_{hh} 是隐藏状态之间的递归连接权重矩阵。
  • b_h 是偏置项。
  • f 是激活函数通常会选择非线性函数如tanh或ReLU以引入非线性变换。

在这种更新过程中,当前的隐藏状态 h_t 同时依赖于当前的输入 x_t 和之前的隐藏状态 $h_{t-1}$这使得RNN能够捕捉长时间序列中的上下文关系。

输出层

有时RNN还会在每个时间步产生输出输出计算方式通常为


y_t = g(W_{hy} \cdot h_t + b_y)

其中:

  • y_t 是时间步 t 的输出。
  • W_{hy} 是隐藏状态到输出的权重矩阵。
  • b_y 是输出层的偏置项。
  • g 是输出层激活函数例如softmax用于分类任务

困惑度

假设我们有一个测试序列,其中包含 3 个单词,模型对每个单词的预测概率分别为:

  • P(w_1) = 0.5
  • P(w_2|w_1) = 0.2
  • P(w_3|w_1, w_2) = 0.1

根据困惑度的公式:


\text{Perplexity} = \exp\left(-\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | \text{context})\right)

当模型对每个单词都能百分之百预测即概率为1则平均交叉熵为0困惑度为 $\exp(0)=1$。这表示模型没有任何不确定性,是理想状态。

我们这里 $N=3$。下面是具体的计算步骤:

  1. 计算每个单词的对数概率

\log P(w_1) = \log(0.5) \approx -0.6931

\log P(w_2|w_1) = \log(0.2) \approx -1.6094

\log P(w_3|w_1, w_2) = \log(0.1) \approx -2.3026
  1. 求和并求平均

将这些对数值相加:


\sum_{i=1}^{3} \log P(w_i|\text{context}) = -0.6931 - 1.6094 - 2.3026 \approx -4.6051

然后求平均:


\text{平均对数概率} = \frac{-4.6051}{3} \approx -1.5350
  1. 计算困惑度

取负值再求指数:


\text{Perplexity} = \exp\left(1.5350\right) \approx 4.64

训练过程与挑战

整体训练流程可以总结为下面几个步骤,每个 epoch 都会重复这些步骤:

  1. 前向传播
    • 对于一个完整的句子(或者一个批次中的多个句子),模型按顺序处理所有时间步,生成每个时间步的输出。
    • 比如,对于句子“我 爱 编程”,模型会依次处理“我”、“爱”、“编程”,得到对应的输出(例如每个时间步预测下一个词的概率分布)。
  2. 计算损失
    • 将模型在所有时间步的输出与真实目标序列(也就是每个时间步的正确答案)进行比较,计算整体损失。
    • 损失通常是所有时间步损失的总和或平均值,例如均方误差或交叉熵损失。
  3. 反向传播BPTT
    • 整个句子进行反向传播即通过时间Back Propagation Through TimeBPTT计算所有时间步的梯度。
    • 这一步会利用链式法则,把整个序列中各个时间步的梯度累积起来,形成每个参数的总梯度。
  4. 参数更新
    • 使用优化器(如 Adam、SGD 等)根据计算得到的梯度更新模型参数。
  5. 重复整个过程
    • 以上步骤构成了一个训练迭代周期(一个 epoch在一个 epoch 中,所有训练样本都会被送入模型进行训练。
    • 然后在下一个 epoch 中,再次重复整个流程,直到达到预设的 epoch 数或满足其他停止条件。

在训练过程中RNN通过反向传播算法具体为“反向传播通过时间”BPTT来更新参数。然而由于梯度在长序列上传播时可能出现梯度消失或梯度爆炸问题使得RNN在捕捉长程依赖关系时面临挑战。为此后来发展出了如长短时记忆网络LSTM和门控循环单元GRU等改进模型它们在结构上增加了门控机制有效缓解了这一问题。

门控循环单元GRU

GRUGated Recurrent Unit门控循环单元是一种常用的循环神经网络变种旨在解决标准 RNN 中梯度消失或梯度爆炸的问题,同时比 LSTM 结构更简单。

基本结构

GRU 通过两个门gate来控制信息的流动

  1. 更新门 $z_t$
    控制当前隐藏状态需要保留多少来自过去的信息以及引入多少新的信息。
  2. 重置门 $r_t$
    决定如何结合新输入和过去的记忆,尤其是在产生候选隐藏状态时。

另外GRU 计算一个候选隐藏状态 $\tilde{h}_t$,并结合更新门 z_t 的信息,更新最终的隐藏状态 $h_t$。

隐藏状态更新公式

对于每个时间步 $t$GRU 的计算过程通常包括以下步骤:

  1. 更新门 $z_t$

    
    z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z)
    

    其中:

    • x_t 是当前时间步的输入;
    • h_{t-1} 是上一时刻的隐藏状态;
    • W_zU_z 是权重矩阵;
    • b_z 是偏置向量;
    • \sigma(\cdot) 是 sigmoid 函数,用于将输出限制在 [0, 1] 区间。
  2. 重置门 $r_t$

    
    r_t = \sigma(W_{xr} x_t + W_{hr} h_{t-1} + b_r)
    

    其中参数意义与更新门类似,重置门决定忘记多少过去的信息。

  3. 候选隐藏状态 $\tilde{h}_t$

    
    \tilde{h}_t = \tanh(W_{xh} x_t + W_{hh} (r_t \odot h_{t-1}) + b_h)
    

    这里:

    • r_t \odot h_{t-1} 表示重置门 r_t 和上一时刻隐藏状态的逐元素相乘Hadamard 乘积),用以调制历史信息的影响;
    • \tanh(\cdot) 用来生成候选隐藏状态,将输出限制在 $[-1, 1]$。
  4. 最终隐藏状态 $h_t$
    GRU 结合更新门和候选隐藏状态更新最终隐藏状态:

    
    h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t.
    

    这表明更新门 z_t 决定了新信息 \tilde{h}_t 与旧信息 h_{t-1} 的比例。

公式

GRU 更新公式如下:


\begin{aligned}
z_t &= \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z), \\
r_t &= \sigma(W_{xr} x_t + W_{hr} h_{t-1} + b_r), \\
\tilde{h}_t &= \tanh(W_{xh} x_t + W_{hh}(r_t \odot h_{t-1}) + b_h), \\
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t.
\end{aligned}

长短时记忆网络LSTM

LSTM 是一种常用的循环神经网络变种,专门为解决标准 RNN 中的梯度消失问题而设计。它通过引入额外的“记忆单元”和多个门控机制,有效地控制信息的保存、遗忘和输出,从而捕捉长距离的依赖关系。


基本结构

LSTM 的核心在于其“细胞状态”cell state这是一个贯穿整个序列传递的信息流同时有三个主要的门gate来控制细胞状态的更新过程

  1. 遗忘门 $f_t$
    决定当前时间步需要遗忘多少之前的记忆信息。

  2. 输入门 $i_t$
    决定当前时间步有多少新的信息写入细胞状态。

  3. 输出门 $o_t$
    决定当前时间步从细胞状态中输出多少信息作为隐藏状态。

此外,还引入了一个候选细胞状态 \tilde{c}_t 用于更新细胞状态。


隐藏状态更新公式

对于每个时间步 $t$LSTM 的更新过程通常可以写为以下公式(所有权重矩阵用 WU 表示,各门的偏置为 $b$


\begin{aligned}
\textbf{遗忘门:}\quad f_t &= \sigma\Big(W_{xf}\, x_t + W_{hf}\, h_{t-1} + b_f\Big), \\
\textbf{输入门:}\quad i_t &= \sigma\Big(W_{xi}\, x_t + W_{hi}\, h_{t-1} + b_i\Big), \\
\textbf{输出门:}\quad o_t &= \sigma\Big(W_{xo}\, x_t + W_{ho}\, h_{t-1} + b_o\Big), \\
\\
\textbf{候选细胞状态:}\quad \tilde{c}_t &= \tanh\Big(W_{xc}\, x_t + W_{hc}\, h_{t-1} + b_c\Big), \\
\textbf{细胞状态更新:}\quad c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, \\

\textbf{隐藏状态:}\quad h_t &= o_t \odot \tanh(c_t).
\end{aligned}

直观理解

  • 细胞状态 $c_t$
    细胞状态是贯穿整个序列的“记忆通道”,负责长期保存信息。它像一条传送带,在不同时间步中线性传递,避免信息被频繁修改,从而维持长期记忆。
  • 遗忘门 $f_t$
    用于丢弃上一时刻不再需要的信息。如果遗忘门输出接近 0说明遗忘了大部分过去的信息如果接近 1则保留大部分信息。 类比若模型遇到新段落遗忘门可能关闭输出接近0丢弃前一段的无关信息若需要延续上下文如故事主线则保持开启输出接近1
  • 输入门 i_t 和候选细胞状态 $\tilde{c}_t$
    输入门控制有多少候选信息被写入细胞状态。候选细胞状态是基于当前输入和上一时刻隐藏状态生成的新信息。 类比:阅读时遇到关键情节,输入门打开,将新信息写入长期记忆(如角色关系),同时候选状态 $\tilde{c}_t$提供新信息的候选内容。
  • 输出门 $o_t$
    控制从细胞状态中输出多少信息作为当前时间步的隐藏状态。隐藏状态 h_t 通常用于后续计算(例如,生成输出、参与下一时刻计算)。 类比:根据当前任务(如预测下一个词),输出门决定暴露细胞状态的哪部分(如只关注时间、地点等关键信息)。