Attention Is All You Need — Transformer 论文精读

February 13, 2026

Attention Is All You Need — Transformer 论文精读

Attention Is All You Need

Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, Illia Polosukhin — NIPS 2017 (Google Brain / Google Research)

Transformer 由 Google 团队于 2017 年在论文《Attention Is All You Need》中提出,完全摒弃了 RNN 和 CNN,仅基于注意力机制构建序列到序列模型。这一架构后来成为 BERT、GPT 等大型语言模型的基础,彻底改变了 NLP 领域的发展方向。

研究动机

Transformer 的提出源于对 RNN 和 CNN 在序列建模中局限性的思考。

RNN 的局限

  • 难以并行:RNN 按时间步逐个计算隐藏状态 ht=f(ht1,xt)h_t = f(h_{t-1}, x_t),天然的时序依赖使其无法并行处理。
  • 长距离依赖困难:历史信息通过隐藏状态逐步传递,序列越长,早期信息越容易丢失。增大隐藏状态维度虽可缓解,但会显著增加内存开销。

CNN 的局限

  • 感受野有限:卷积每次只关注局部窗口,两个距离较远的位置需要堆叠多层卷积才能建立联系。

Transformer 的注意力机制在单层内即可建模序列中任意两个位置的关系,同时通过多头机制模拟 CNN 的多通道特性,兼顾全局建模能力和多模式捕捉能力。

模型架构

编码器-解码器结构

Transformer 采用经典的 Encoder-Decoder 结构:

  • 编码器(Encoder):将输入序列 (x1,x2,,xn)(x_1, x_2, \ldots, x_n) 编码为连续表示序列 (z1,z2,,zn)(z_1, z_2, \ldots, z_n),其中每个 ztz_t 是对应词的向量表示。
  • 解码器(Decoder):接收编码器输出 zz自回归(auto-regressive)地逐个生成输出序列 (y1,y2,,ym)(y_1, y_2, \ldots, y_m)。输出长度 mm 与输入长度 nn 可以不同(例如翻译任务中,一个英语单词不一定对应一个中文字)。

Transformer 整体架构

编码器

编码器由 N=6N=6 个相同的层堆叠而成,每个层包含两个子层:

  1. 多头自注意力机制(Multi-Head Self-Attention)
  2. 前馈神经网络(Feed-Forward Network)

每个子层都使用残差连接层归一化LayerNorm(x+Sublayer(x))\mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))。为使残差连接生效,所有子层及 Embedding 层的输出维度统一为 dmodel=512d_{model} = 512

解码器

解码器同样由 N=6N=6 个相同的层堆叠,每个层包含三个子层:

编码器-解码器注意力连接

  1. 掩码多头自注意力(Masked Multi-Head Self-Attention):通过掩码确保位置 tt 的预测只能依赖 tt 之前的已知输出,防止信息泄露。
  2. 编码器-解码器注意力(Encoder-Decoder Attention):Query 来自解码器上一子层的输出,Key 和 Value 来自编码器的输出,使解码器能关注输入序列的所有位置。
  3. 前馈神经网络:与编码器相同,对每个位置独立进行非线性变换。

注意力机制

注意力函数的本质是将一个 Query 和一组 Key-Value 对映射到输出。输出是 Value 的加权和,权重由 Query 与对应 Key 的相似度决定。

缩放点积注意力(Scaled Dot-Product Attention)

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

缩放点积注意力计算流程

计算步骤:

  1. 点积:计算 Query 与所有 Key 的点积 QKTQK^T
  2. 缩放:除以 dk\sqrt{d_k}
  3. 掩码(可选):在解码器中,将未来位置的值设为 -\infty
  4. Softmax:将得分转化为概率分布
  5. 加权求和:用概率分布对 Value 加权

为什么要除以 dk\sqrt{d_k}dkd_k 较大时,点积结果的方差随之增大,导致 softmax 输出趋近于 one-hot 分布(某个值接近 1,其余接近 0)。在这种饱和区域梯度极小,模型难以训练。除以 dk\sqrt{d_k} 可以将方差稳定在合理范围。

Mask 的作用:训练阶段解码器的输入是完整的目标序列,但预测第 tt 个词时不应看到第 tt 个之后的内容(否则直接抄答案了)。具体实现是将 qtq_tkt+1,kt+2,k_{t+1}, k_{t+2}, \ldots 的注意力分数设为 -\infty,经过 softmax 后权重趋近于 0。

PyTorch 实现:

def masked_softmax(scores, valid_lens):
    """对 scores 做 softmax,将 valid_lens 之外的位置 mask 为 -inf。"""
    if valid_lens is None:
        return torch.softmax(scores, dim=-1)
    shape = scores.shape
    if valid_lens.dim() == 1:
        valid_lens = valid_lens.unsqueeze(1).expand(-1, shape[1])
    mask = torch.arange(shape[-1], device=scores.device).unsqueeze(0).unsqueeze(0)
    mask = mask >= valid_lens.unsqueeze(-1)
    scores = scores.masked_fill(mask, -1e9)
    return torch.softmax(scores, dim=-1)
 
 
class DotProductAttention(nn.Module):
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
 
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(attention_weights), values)

多头注意力(Multi-Head Attention)

多头注意力将输入分别通过 hh 组不同的线性投影,独立计算注意力后拼接输出:

MultiHead(Q,K,V)=Concat(head1,,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

关键参数

参数含义论文取值
hh注意力头数8
dmodeld_{model}模型总维度512
dk=dv=dmodel/hd_k = d_v = d_{model} / h每个头的维度64

每个头的投影矩阵:WiQ,WiKRdmodel×dkW_i^Q, W_i^K \in \mathbb{R}^{d_{model} \times d_k}WiVRdmodel×dvW_i^V \in \mathbb{R}^{d_{model} \times d_v},输出投影矩阵 WORhdv×dmodelW^O \in \mathbb{R}^{hd_v \times d_{model}}

设计动机

  • 多模式捕捉:类比 CNN 的多输出通道,每个头可以学习不同的关注模式(如语法关系、语义依赖、上下文关联等)。
  • 总计算量不变:虽然有 hh 个头,但每个头的维度缩小为 dmodel/hd_{model} / h,总计算量与单头全维度注意力相当。

PyTorch 实现:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.attention = DotProductAttention(dropout)
 
    def _split_heads(self, X):
        """(batch, seq, d_model) → (batch*h, seq, d_k)"""
        batch, seq, _ = X.shape
        X = X.view(batch, seq, self.num_heads, self.d_k)
        X = X.transpose(1, 2)
        return X.reshape(batch * self.num_heads, seq, self.d_k)
 
    def _merge_heads(self, X, batch_size):
        """(batch*h, seq, d_k) → (batch, seq, d_model)"""
        X = X.view(batch_size, self.num_heads, -1, self.d_k)
        X = X.transpose(1, 2)
        return X.reshape(batch_size, -1, self.num_heads * self.d_k)
 
    def forward(self, queries, keys, values, valid_lens=None):
        batch_size = queries.shape[0]
        Q = self._split_heads(self.W_q(queries))
        K = self._split_heads(self.W_k(keys))
        V = self._split_heads(self.W_v(values))
        if valid_lens is not None:
            valid_lens = valid_lens.repeat_interleave(self.num_heads, dim=0)
        output = self.attention(Q, K, V, valid_lens)
        return self.W_o(self._merge_heads(output, batch_size))

维度验证:

MultiHeadAttention: torch.Size([2, 8, 32]) → torch.Size([2, 8, 32])

三种注意力的使用方式

Transformer 中注意力机制分为三种场景:

  1. 编码器自注意力:Q、K、V 均来自编码器上一层的输出,每个位置可关注输入序列的所有位置。
  2. 解码器掩码自注意力:Q、K、V 均来自解码器上一层的输出,通过掩码限制每个位置只能关注自身及之前的位置。
  3. 编码器-解码器注意力:Q 来自解码器,K 和 V 来自编码器输出,使解码器能够聚焦输入序列的相关部分。

Layer Normalization vs Batch Normalization

Transformer 使用 Layer Normalization(LN)而非 Batch Normalization(BN),原因与 NLP 数据的特性有关。

NLP 中输入数据是三维张量 (batch, seq, feature),其中 batch 是训练样本数,seq 是序列长度,feature 是特征维度。

归一化方式计算范围问题
BN沿 batch 维度,对每个特征位置计算均值和方差序列长度不等时需填充 0,导致统计量不稳定;推理时遇到超长序列,预训练的全局统计量不适用
LN沿 feature 维度,对每个样本独立计算均值和方差不受序列长度和 batch 内其他样本影响,统计量稳定

由于 NLP 任务中各样本的序列长度通常不一致,BN 在 batch 维度上计算统计量时会受到大量填充值的干扰,而 LN 在每个样本内部独立计算,天然适合变长序列场景。

前馈神经网络(FFN)

每个编码器/解码器层中都包含一个逐位置的前馈网络,对序列中的每个位置独立且相同地应用:

FFN(x)=max(0,  xW1+b1)W2+b2\text{FFN}(x) = \max(0,\; xW_1 + b_1)W_2 + b_2
  • 内层维度 dff=2048d_{ff} = 2048(将 512 维扩展至 2048 维)
  • 外层维度恢复为 dmodel=512d_{model} = 512
  • 激活函数:ReLU

Attention 与 FFN 的分工:Attention 负责在序列维度上汇聚全局信息,使每个位置的表示融合了全局上下文;FFN 随后对每个位置独立地进行语义空间转换。这与 RNN 形成对比——RNN 通过隐藏状态 ht1h_{t-1} 传递时序信息再经过 MLP 转换,而 Transformer 通过 Attention 一次性获取全局时序信息后再经过 FFN 转换。

Embedding 与位置编码

Embedding

Transformer 在三个位置使用 Embedding:

  • 编码器输入 Embedding:将输入词映射为 dmodeld_{model} 维向量
  • 解码器输入 Embedding:将已生成的输出词映射为 dmodeld_{model} 维向量
  • 输出层:解码器输出(1×dmodel1 \times d_{model})经线性层(dmodel×Vd_{model} \times V 参数矩阵)映射为词表大小的向量(1×V1 \times V),再经 Softmax 得到概率分布,取最大值对应的词作为预测结果

位置编码(Positional Encoding)

注意力机制本身是排列不变的——打乱输入顺序不会改变输出。为引入序列的顺序信息,Transformer 对输入 Embedding 加入位置编码:

PE(pos,2i)=sin(pos100002i/dmodel)PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i+1)=cos(pos100002i/dmodel)PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)

其中 pospos 是位置索引,ii 是维度索引。位置编码与词嵌入维度相同(dmodel=512d_{model} = 512),两者直接相加。

PyTorch 实现:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div_term)  # 偶数维:sin
        pe[:, 1::2] = torch.cos(pos * div_term)  # 奇数维:cos
        self.register_buffer('pe', pe.unsqueeze(0))
 
    def forward(self, X):
        X = X + self.pe[:, :X.shape[1], :]
        return self.dropout(X)

复杂度比较

层类型每层计算复杂度顺序操作数最大路径长度
Self-AttentionO(n2d)O(n^2 \cdot d)O(1)O(1)O(1)O(1)
RecurrentO(nd2)O(n \cdot d^2)O(n)O(n)O(n)O(n)
ConvolutionalO(knd2)O(k \cdot n \cdot d^2)O(1)O(1)O(logk(n))O(\log_k(n))
Self-Attention (restricted)O(rnd)O(r \cdot n \cdot d)O(1)O(1)O(n/r)O(n/r)

其中 nn 为序列长度,dd 为表示维度,kk 为卷积核大小,rr 为受限自注意力的窗口大小。

Self-Attention 的计算复杂度推导:QKTQK^T(n×d)(d×n)=O(n2d)(n \times d)(d \times n) = O(n^2d)(QKT)V(QK^T)V(n×n)(n×d)=O(n2d)(n \times n)(n \times d) = O(n^2d),总复杂度为 O(n2d)O(n^2d)

关键优势:Self-Attention 的顺序操作数为 O(1)O(1)(完全可并行),且任意两个位置之间的最大路径长度为 O(1)O(1)(单层即可直接交互),这两点是 RNN 的 O(n)O(n) 和 CNN 的 O(logk(n))O(\log_k(n)) 无法比拟的。

总结

Transformer 的核心贡献在于证明了仅凭注意力机制即可构建强大的序列模型,无需 RNN 的循环结构或 CNN 的卷积操作。其并行计算能力、全局建模能力和灵活的架构设计使其成为后续 BERT、GPT、T5 等大型语言模型的基础架构,深刻改变了自然语言处理乃至整个深度学习领域的发展轨迹。

代码实战

完整的 Transformer 代码实现(英→法翻译任务),包含源代码实现与 nn.Transformer 简洁实现两种方式的对比:

Open In Colab

参考文献