残差连接(Residual Connection)是Transformer中的一个关键设计,用于解决深层网络训练时的梯度消失问题,同时帮助模型保留原始输入信息。它的操作非常简单,但效果显著。以下是通俗易懂的解释:
1. 残差连接的核心操作
一句话总结:
把当前层的输入直接加到当前层的输出上,形成“输入 + 输出”的短路路径。
数学公式:
输出 = 输入 x + 当前层的变换 ( x ) \text{输出} = \text{输入} \ x \ + \ \text{当前层的变换}(x) 输出=输入 x + 当前层的变换(x)
(其中“当前层的变换”可能是自注意力、交叉注意力或前馈网络)
2. 具体步骤(以解码器的自注意力层为例)
假设输入是一个向量 ( x )(已包含词嵌入和位置编码),经过自注意力层后的输出为 SelfAttn ( x ) \text{SelfAttn}(x) SelfAttn(x)
- 保留原始输入:将输入 ( x ) 复制一份。
- 叠加变换结果:将自注意力的输出 SelfAttn ( x ) \text{SelfAttn}(x) SelfAttn(x)与原始输入 ( x ) 逐元素相加。
残差输出 = x + SelfAttn ( x ) \text{残差输出} = x + \text{SelfAttn}(x) 残差输出=x+SelfAttn(x) - 层归一化:对相加后的结果做归一化(LayerNorm)。
最终输出 = LayerNorm ( x + SelfAttn ( x ) ) \text{最终输出} = \text{LayerNorm}(x + \text{SelfAttn}(x)) 最终输出=LayerNorm(x+SelfAttn(x))
3. 直观类比
想象你正在修改一篇文章:
- 原始输入(x):初稿的文本。
- 当前层的变换(SelfAttn(x)):你写的修改建议(比如添加一些描述)。
- 残差连接:把修改建议直接“贴”到初稿上(初稿 + 修改),而不是完全重写。
- 层归一化:调整合并后的格式,使其更规范。
关键点:无论你修改多少遍,初稿的内容始终保留,避免彻底丢失原始信息。
4. 为什么需要残差连接?
解决的问题
- 梯度消失:深层网络中,反向传播时梯度可能逐层衰减,导致浅层参数无法更新。残差连接提供了直通路径,让梯度能直接回传。
- 信息丢失:传统网络可能过度修改输入,残差连接强制模型只学习“需要补充或调整的部分”(即残差)。
对比实验
- 不带残差连接:Transformer在6层以上时,训练损失难以收敛。
- 带残差连接:即使堆叠100层,模型仍能稳定训练(如GPT-3)。
5. 代码示例(PyTorch风格)
import torch
import torch.nn as nnclass DecoderLayer(nn.Module):def __init__(self, d_model):super().__init__()self.self_attn = MultiHeadAttention(d_model) # 自注意力层self.norm = nn.LayerNorm(d_model) # 层归一化def forward(self, x):# 残差连接:输入x + 自注意力输出residual = xx = self.self_attn(x)x = self.norm(residual + x) # 先相加,再归一化return x
6. 残差连接的变体
- 经典残差:
输出 = 输入 + 变换(输入)
(Transformer采用)。 - 预激活残差:先归一化再变换(如ResNet v2)。
- 自适应残差:动态调整残差权重(如门控机制)。
7. 总结
- 操作:输入与输出直接相加。
- 目的:保留原始信息,缓解梯度消失,稳定深层训练。
- 效果:让Transformer可以堆叠数十层甚至上百层,仍能高效学习。
类比记忆:
残差连接就像“写论文时保留初稿,每次修改只添加批注”——既避免推倒重来,又能逐步完善。