前言
2017年,Google发表论文《Attention Is All You Need》,提出了Transformer架构。这篇论文彻底改变了NLP领域,成为GPT、BERT、LLaMA等大语言模型的共同基石。
本文将深入剖析Transformer的架构原理,从数学推导到代码实现,帮你彻底理解这个改变AI历史的设计。
为什么需要Transformer?
RNN/LSTM的局限
在Transformer之前,序列建模的主流方法是RNN和LSTM:
1 | h₁ → h₂ → h₃ → ... → hₙ |
主要问题:
- 顺序计算瓶颈:无法并行,训练慢
- 长距离依赖:信息需要逐步传递,远距离token难以关联
- 梯度消失/爆炸:长序列训练困难
Transformer的解决方案
Transformer的核心思想:用Attention替代循环
1 | ┌─────────────────────────────────┐ |
优势:
- ✅ 完全并行化
- ✅ 任意距离直接连接
- ✅ 训练稳定
整体架构
Transformer采用Encoder-Decoder结构:
1 | Output Embedding |
现代LLM(如GPT)通常只用Decoder部分,我们重点分析核心组件。
核心组件详解
1. Self-Attention:自注意力机制
Self-Attention是Transformer的灵魂。它的核心思想是:让序列中的每个位置都能直接关注到其他所有位置。
数学定义
给定输入序列 $X = [x_1, x_2, …, x_n]$,首先通过三个权重矩阵得到Query、Key、Value:
$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$
Attention计算:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
其中 $d_k$ 是Key的维度,用于缩放以稳定梯度。
直觉理解
想象你在读一句话:”The animal didn’t cross the street because it was too tired”
当你读到”it”时,Self-Attention让模型能够”看向”前面的”animal”,建立关联:
1 | Position: The animal didn't cross the street because it was tired |
代码实现
1 | import torch |
2. Multi-Head Attention:多头注意力
单头Attention只能学习一种关联模式。Multi-Head让模型同时从多个”视角”学习:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, …, \text{head}_h)W_O$$
其中:
$$\text{head}_i = \text{Attention}(QW_Q^i, KW_K^i, VW_V^i)$$
1 | ┌─────────────────────────────────────────────────┐ |
为什么需要多头?
不同Head可以学习不同的关联模式:
- Head 1:关注语法结构
- Head 2:关注语义相似性
- Head 3:关注位置关系
- …
3. Positional Encoding:位置编码
Self-Attention是位置无关的——打乱输入顺序,输出只是相应的打乱。为了引入位置信息,Transformer添加位置编码:
$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$$
$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$$
其中 $pos$ 是位置,$i$ 是维度索引。
1 | class PositionalEncoding(nn.Module): |
为什么用sin/cos?
- 可以泛化到任意长度
- 相对位置可以通过线性变换得到:$PE_{pos+k}$ 可以表示为 $PE_{pos}$ 的函数
4. Feed-Forward Network
每个Transformer Block包含一个前馈网络:
$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$
通常是两层全连接,中间有ReLU激活,隐藏层维度是输入的4倍。
1 | class FeedForward(nn.Module): |
5. Layer Normalization & Residual Connection
每个子层后都有残差连接和LayerNorm:
$$\text{LayerNorm}(x + \text{Sublayer}(x))$$
1 | class LayerNorm(nn.Module): |
残差连接的作用:
- 缓解梯度消失
- 允许训练更深的网络
- 保留原始信息流
完整的Transformer Block
1 | class TransformerBlock(nn.Module): |
Encoder完整实现
1 | class Encoder(nn.Module): |
Decoder的特殊设计
Decoder有两个关键区别:
1. Masked Self-Attention
在训练时,Decoder不能”看到”未来的token:
1 | def create_mask(seq_len): |
2. Cross-Attention
Decoder的第二个Attention层使用Encoder的输出作为Key和Value:
1 | # In Decoder block |
从Transformer到GPT
GPT(Generative Pre-trained Transformer)只使用Decoder部分:
1 | ┌─────────────────────────────────────┐ |
训练目标:预测下一个token
$$\mathcal{L} = -\sum_{t=1}^{T} \log P(x_t | x_{<t})$$
从Transformer到BERT
BERT(Bidirectional Encoder Representations from Transformers)只使用Encoder部分:
1 | ┌─────────────────────────────────────┐ |
训练任务:
- Masked Language Model (MLM):随机mask 15%的token,预测原token
- Next Sentence Prediction (NSP):判断两句是否连续
现代LLM的架构演进
| 模型 | 架构 | 参数量 | 关键创新 |
|---|---|---|---|
| GPT-1 | Decoder | 117M | 无监督预训练 + 有监督微调 |
| BERT | Encoder | 340M | 双向编码 + MLM |
| GPT-2 | Decoder | 1.5B | 更大更多数据 |
| GPT-3 | Decoder | 175B | In-Context Learning |
| LLaMA | Decoder | 7B-65B | RMSNorm + SwiGLU + RoPE |
| GPT-4 | Decoder | 未公开 | 多模态 + MoE(推测) |
LLaMA的改进
1 | # 1. RMSNorm (替代LayerNorm) |
计算复杂度分析
| 操作 | 复杂度 | 说明 |
|---|---|---|
| Self-Attention | $O(n^2 \cdot d)$ | n是序列长度,d是维度 |
| Feed-Forward | $O(n \cdot d^2)$ | 线性于序列长度 |
| 总复杂度 | $O(n^2 \cdot d + n \cdot d^2)$ |
长序列问题:Attention的 $O(n^2)$ 是主要瓶颈。
解决方案:
- Sparse Attention(稀疏注意力)
- Linear Attention
- Flash Attention(IO优化)
- Sliding Window Attention
总结
Transformer的核心创新:
- Self-Attention:任意位置直接关联,无距离限制
- Multi-Head:多视角学习不同关联模式
- Positional Encoding:注入位置信息
- 残差连接 + LayerNorm:训练稳定,支持深层网络
- 完全并行:摆脱RNN的顺序计算
从Transformer到GPT/BERT,再到LLaMA/GPT-4,架构的演进始终围绕这些核心设计。理解Transformer,就是理解了大模型时代的基石。
参考资料:
- Attention Is All You Need - 原始论文
- The Illustrated Transformer - 可视化讲解
- Annotated Transformer - 带注释的实现
- LLaMA Paper - 现代架构改进