Transformer架构深度解析:大模型的基石

前言

2017年,Google发表论文《Attention Is All You Need》,提出了Transformer架构。这篇论文彻底改变了NLP领域,成为GPT、BERT、LLaMA等大语言模型的共同基石。

本文将深入剖析Transformer的架构原理,从数学推导到代码实现,帮你彻底理解这个改变AI历史的设计。

为什么需要Transformer?

RNN/LSTM的局限

在Transformer之前,序列建模的主流方法是RNN和LSTM:

1
2
3
h₁ → h₂ → h₃ → ... → hₙ
↓ ↓ ↓ ↓
y₁ y₂ y₃ yₙ

主要问题

  1. 顺序计算瓶颈:无法并行,训练慢
  2. 长距离依赖:信息需要逐步传递,远距离token难以关联
  3. 梯度消失/爆炸:长序列训练困难

Transformer的解决方案

Transformer的核心思想:用Attention替代循环

1
2
3
4
5
6
┌─────────────────────────────────┐
│ 所有位置同时计算Self-Attention │
│ x₁ x₂ x₃ ... xₙ │
│ ↘ ↓ ↙ │
│ Parallel Computation │
└─────────────────────────────────┘

优势:

  • ✅ 完全并行化
  • ✅ 任意距离直接连接
  • ✅ 训练稳定

整体架构

Transformer采用Encoder-Decoder结构:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
                    Output Embedding

Positional Encoding

┌────────────────────────┴────────────────────────┐
│ DECODER (N layers) │
│ ┌─────────────────────────────────────────┐ │
│ │ Masked Multi-Head Attention │ │
│ │ ↓ │ │
│ │ Multi-Head Attention (Cross-Attention) │ │
│ │ ↓ │ │
│ │ Feed Forward │ │
│ └─────────────────────────────────────────┘ │
└────────────────────────┬────────────────────────┘

┌────────────────────────┴────────────────────────┐
│ ENCODER (N layers) │
│ ┌─────────────────────────────────────────┐ │
│ │ Multi-Head Attention │ │
│ │ ↓ │ │
│ │ Feed Forward │ │
│ └─────────────────────────────────────────┘ │
└────────────────────────┬────────────────────────┘

Input Embedding

Positional Encoding

现代LLM(如GPT)通常只用Decoder部分,我们重点分析核心组件。

核心组件详解

1. Self-Attention:自注意力机制

Self-Attention是Transformer的灵魂。它的核心思想是:让序列中的每个位置都能直接关注到其他所有位置

数学定义

给定输入序列 $X = [x_1, x_2, …, x_n]$,首先通过三个权重矩阵得到Query、Key、Value:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

Attention计算:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

其中 $d_k$ 是Key的维度,用于缩放以稳定梯度。

直觉理解

想象你在读一句话:”The animal didn’t cross the street because it was too tired”

当你读到”it”时,Self-Attention让模型能够”看向”前面的”animal”,建立关联:

1
2
3
Position:  The  animal  didn't  cross  the  street  because  it   was  tired

it attends to: [animal: 0.7, street: 0.1, tired: 0.15, ...]

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn as nn
import math

class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads

assert (self.head_dim * heads == embed_size), "Embed size must be divisible by heads"

self.W_q = nn.Linear(embed_size, embed_size, bias=False)
self.W_k = nn.Linear(embed_size, embed_size, bias=False)
self.W_v = nn.Linear(embed_size, embed_size, bias=False)
self.fc_out = nn.Linear(embed_size, embed_size)

def forward(self, query, key, value, mask=None):
N = query.shape[0] # batch size
seq_len = query.shape[1]

# Linear projections: (N, seq_len, embed_size)
Q = self.W_q(query)
K = self.W_k(key)
V = self.W_v(value)

# Reshape for multi-head: (N, seq_len, heads, head_dim) -> (N, heads, seq_len, head_dim)
Q = Q.view(N, seq_len, self.heads, self.head_dim).transpose(1, 2)
K = K.view(N, seq_len, self.heads, self.head_dim).transpose(1, 2)
V = V.view(N, seq_len, self.heads, self.head_dim).transpose(1, 2)

# Attention scores: (N, heads, seq_len, seq_len)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

# Apply mask (for decoder)
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-1e20"))

# Softmax
attention = torch.softmax(scores, dim=-1)

# Apply attention to values: (N, heads, seq_len, head_dim)
out = torch.matmul(attention, V)

# Concat heads: (N, seq_len, embed_size)
out = out.transpose(1, 2).contiguous().view(N, seq_len, self.embed_size)

return self.fc_out(out)

2. Multi-Head Attention:多头注意力

单头Attention只能学习一种关联模式。Multi-Head让模型同时从多个”视角”学习:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, …, \text{head}_h)W_O$$

其中:
$$\text{head}_i = \text{Attention}(QW_Q^i, KW_K^i, VW_V^i)$$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
┌─────────────────────────────────────────────────┐
│ Multi-Head Attention │
│ │
│ Head 1 Head 2 Head 3 ... Head h │
│ ↓ ↓ ↓ ↓ │
│ Q₁K₁ᵀ Q₂K₂ᵀ Q₃K₃ᵀ ... QₕKₕᵀ │
│ ↓ ↓ ↓ ↓ │
│ softmax softmax softmax softmax │
│ ↓ ↓ ↓ ↓ │
│ V₁ V₂ V₃ Vₕ │
│ ↓ ↓ ↓ ↓ │
│ └───────────── Concat ─────────────────┘ │
│ ↓ │
│ Linear (W_O) │
└─────────────────────────────────────────────────┘

为什么需要多头?

不同Head可以学习不同的关联模式:

  • Head 1:关注语法结构
  • Head 2:关注语义相似性
  • Head 3:关注位置关系

3. Positional Encoding:位置编码

Self-Attention是位置无关的——打乱输入顺序,输出只是相应的打乱。为了引入位置信息,Transformer添加位置编码:

$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$$
$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$$

其中 $pos$ 是位置,$i$ 是维度索引。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class PositionalEncoding(nn.Module):
def __init__(self, embed_size, max_len=5000):
super(PositionalEncoding, self).__init__()

# Create matrix of (max_len, embed_size)
pe = torch.zeros(max_len, embed_size)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

# Compute the positional encodings
div_term = torch.exp(
torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size)
)

pe[:, 0::2] = torch.sin(position * div_term) # even indices
pe[:, 1::2] = torch.cos(position * div_term) # odd indices

# Add batch dimension: (1, max_len, embed_size)
pe = pe.unsqueeze(0)

self.register_buffer('pe', pe)

def forward(self, x):
# x: (N, seq_len, embed_size)
return x + self.pe[:, :x.shape[1], :]

为什么用sin/cos?

  • 可以泛化到任意长度
  • 相对位置可以通过线性变换得到:$PE_{pos+k}$ 可以表示为 $PE_{pos}$ 的函数

4. Feed-Forward Network

每个Transformer Block包含一个前馈网络:

$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$

通常是两层全连接,中间有ReLU激活,隐藏层维度是输入的4倍。

1
2
3
4
5
6
7
8
9
class FeedForward(nn.Module):
def __init__(self, embed_size, d_ff):
super(FeedForward, self).__init__()
self.linear1 = nn.Linear(embed_size, d_ff)
self.linear2 = nn.Linear(d_ff, embed_size)
self.relu = nn.ReLU()

def forward(self, x):
return self.linear2(self.relu(self.linear1(x)))

5. Layer Normalization & Residual Connection

每个子层后都有残差连接和LayerNorm:

$$\text{LayerNorm}(x + \text{Sublayer}(x))$$

1
2
3
4
5
6
7
8
9
10
11
class LayerNorm(nn.Module):
def __init__(self, embed_size, eps=1e-6):
super(LayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.ones(embed_size))
self.beta = nn.Parameter(torch.zeros(embed_size))
self.eps = eps

def forward(self, x):
mean = x.mean(-1, keepdim=True)
std = x.std(-1, keepdim=True)
return self.gamma * (x - mean) / (std + self.eps) + self.beta

残差连接的作用

  • 缓解梯度消失
  • 允许训练更深的网络
  • 保留原始信息流

完整的Transformer Block

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class TransformerBlock(nn.Module):
def __init__(self, embed_size, heads, d_ff, dropout=0.1):
super(TransformerBlock, self).__init__()

self.attention = SelfAttention(embed_size, heads)
self.norm1 = LayerNorm(embed_size)
self.norm2 = LayerNorm(embed_size)

self.feed_forward = FeedForward(embed_size, d_ff)

self.dropout = nn.Dropout(dropout)

def forward(self, x, mask=None):
# Self-Attention with residual + norm
attention = self.attention(x, x, x, mask)
x = self.norm1(x + self.dropout(attention))

# Feed-Forward with residual + norm
forward = self.feed_forward(x)
x = self.norm2(x + self.dropout(forward))

return x

Encoder完整实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Encoder(nn.Module):
def __init__(
self,
vocab_size,
embed_size,
num_layers,
heads,
d_ff,
max_len=5000,
dropout=0.1
):
super(Encoder, self).__init__()

self.embedding = nn.Embedding(vocab_size, embed_size)
self.positional_encoding = PositionalEncoding(embed_size, max_len)

self.layers = nn.ModuleList([
TransformerBlock(embed_size, heads, d_ff, dropout)
for _ in range(num_layers)
])

self.dropout = nn.Dropout(dropout)

def forward(self, x, mask=None):
N, seq_len = x.shape

# Embedding + Positional Encoding
x = self.embedding(x)
x = self.positional_encoding(x)
x = self.dropout(x)

# Pass through transformer blocks
for layer in self.layers:
x = layer(x, mask)

return x

Decoder的特殊设计

Decoder有两个关键区别:

1. Masked Self-Attention

在训练时,Decoder不能”看到”未来的token:

1
2
3
4
5
6
7
8
9
10
11
def create_mask(seq_len):
"""Create upper triangular mask to prevent attending to future tokens"""
mask = torch.tril(torch.ones(seq_len, seq_len))
return mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)

# Example: seq_len = 5
# [[1, 0, 0, 0, 0],
# [1, 1, 0, 0, 0],
# [1, 1, 1, 0, 0],
# [1, 1, 1, 1, 0],
# [1, 1, 1, 1, 1]]

2. Cross-Attention

Decoder的第二个Attention层使用Encoder的输出作为Key和Value:

1
2
3
4
5
6
7
# In Decoder block
# Query from Decoder, Key/Value from Encoder
cross_attention = self.attention(
query=x, # Decoder state
key=encoder_out, # Encoder output
value=encoder_out # Encoder output
)

从Transformer到GPT

GPT(Generative Pre-trained Transformer)只使用Decoder部分:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
┌─────────────────────────────────────┐
│ GPT Architecture │
│ │
│ Token Embedding + Position Embed │
│ ↓ │
│ ┌─────────────────────────────┐ │
│ │ Masked Multi-Head Attention │ │ ← 只能看左边
│ │ ↓ │ │
│ │ Add & Norm │ │
│ │ ↓ │ │
│ │ Feed Forward │ │
│ │ ↓ │ │
│ │ Add & Norm │ │
│ └─────────────────────────────┘ │
│ × N layers │
│ ↓ │
│ Linear (vocab_size) │
│ ↓ │
│ Softmax │
│ ↓ │
│ Next Token Probability │
└─────────────────────────────────────┘

训练目标:预测下一个token

$$\mathcal{L} = -\sum_{t=1}^{T} \log P(x_t | x_{<t})$$

从Transformer到BERT

BERT(Bidirectional Encoder Representations from Transformers)只使用Encoder部分:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
┌─────────────────────────────────────┐
│ BERT Architecture │
│ │
│ [CLS] Token1 Token2 ... [SEP] │
│ ↓ │
│ Token + Position + Segment Embed │
│ ↓ │
│ ┌─────────────────────────────┐ │
│ │ Multi-Head Attention │ │ ← 双向!
│ │ ↓ │ │
│ │ Add & Norm │ │
│ │ ↓ │ │
│ │ Feed Forward │ │
│ │ ↓ │ │
│ │ Add & Norm │ │
│ └─────────────────────────────┘ │
│ × N layers │
│ ↓ │
│ [CLS] → Classification │
│ Tokens → Token Predictions │
└─────────────────────────────────────┘

训练任务

  1. Masked Language Model (MLM):随机mask 15%的token,预测原token
  2. Next Sentence Prediction (NSP):判断两句是否连续

现代LLM的架构演进

模型 架构 参数量 关键创新
GPT-1 Decoder 117M 无监督预训练 + 有监督微调
BERT Encoder 340M 双向编码 + MLM
GPT-2 Decoder 1.5B 更大更多数据
GPT-3 Decoder 175B In-Context Learning
LLaMA Decoder 7B-65B RMSNorm + SwiGLU + RoPE
GPT-4 Decoder 未公开 多模态 + MoE(推测)

LLaMA的改进

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# 1. RMSNorm (替代LayerNorm)
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def forward(self, x):
norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
return (x * norm).type_as(x) * self.weight

# 2. SwiGLU (替代ReLU FFN)
class SwiGLU(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)

def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))

# 3. RoPE (替代绝对位置编码)
def precompute_freqs_cis(dim, max_seq_len, theta=10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis

计算复杂度分析

操作 复杂度 说明
Self-Attention $O(n^2 \cdot d)$ n是序列长度,d是维度
Feed-Forward $O(n \cdot d^2)$ 线性于序列长度
总复杂度 $O(n^2 \cdot d + n \cdot d^2)$

长序列问题:Attention的 $O(n^2)$ 是主要瓶颈。

解决方案

  • Sparse Attention(稀疏注意力)
  • Linear Attention
  • Flash Attention(IO优化)
  • Sliding Window Attention

总结

Transformer的核心创新:

  1. Self-Attention:任意位置直接关联,无距离限制
  2. Multi-Head:多视角学习不同关联模式
  3. Positional Encoding:注入位置信息
  4. 残差连接 + LayerNorm:训练稳定,支持深层网络
  5. 完全并行:摆脱RNN的顺序计算

从Transformer到GPT/BERT,再到LLaMA/GPT-4,架构的演进始终围绕这些核心设计。理解Transformer,就是理解了大模型时代的基石。


参考资料

感谢你的阅读,如果文章对你有帮助,可以请作者喝杯茶!