
✍ 上一篇我们把现代大模型的两件“基础设施”——GQA 注意力 和 RMSNorm + Pre-Norm 细讲了一遍,从多头注意力的演化一路讲到归一化的升级。这一篇,我们就顺势把剩下的两件标配武器补上:
在最早的 Transformer 里,模型本身对“顺序”是没有感觉的,它只看到一串向量 : x_1, x_2, \dots, x_L \in \mathbb{R}^{d_{\text{model}}} 。他并不像RNN、LSTM一样具备循环机制,因此对于位置信息是不敏感。为了让模型知道“谁在前谁在后”,Transformer 直接给每个位置加了一个位置向量 PE_{pos}
\tilde{x}_{pos} = x_{pos} + PE_{pos}
Transformer 原始论文里的做法采用了三角函数位置编码:
PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right),\quad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right))
三角函数编码是绝对位置编码其中一种经典实现方式,用固定的 sin/cos 函数给每个绝对位置生成向量,方便模型外推到更长序列。
直观理解: “我是谁”(token embedding) + “我在哪里”(position embedding) = 实际输入给 Transformer 的向量。
绝对 PE 的好处是实现很简单,但也有两点局限:
🧠 读到这里,读者可能会有点疑惑,为什么在前面说了 ① 三角函数编码方便模型外推到更长序列, 但是后面又说了② 对于超长上下文,learned pos embedding 很难直接外推,sin-cos 虽然能算,但模型未必学会用。
max_len = 2048,那 embedding table 里就只有 0~2047 这些 index;想用到 4096-long 序列时,根本没有 PE[3000] 这一行可用,得重新插值/扩表。但是,当我们采用了三角函数的位置编码时, 想算 pos=4096、pos=10000 都随时能算,从“函数定义”角度确实更“可外推”。因此这里根本不会自相矛盾,用一句土话讲就是“可以但没用的外推”。
import math
import torch
import torch.nn as nn
class PositionalEncoding(nn.Module):
"""
正弦 / 余弦绝对位置编码,接口风格跟 PyTorch Transformer 一致:
输入输出形状都是 [seq_len, batch_size, d_model]
"""
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) # [max_len, 1]
div_term = torch.exp(
torch.arange(0, d_model, 2, dtype=torch.float32) *
(-math.log(10000.0) / d_model)
)
pe = torch.zeros(max_len, d_model, dtype=torch.float32)
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维
pe = pe.unsqueeze(1) # [max_len, 1, d_model]
self.register_buffer("pe", pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
seq_len = x.size(0)
x = x + self.pe[:seq_len]
return self.dropout(x)🧠 RoPE(Rotary Positional Embedding)的核心思想可以一句话概括:
不再给 某个embedding 加位置向量,而是 直接在 Q/K 空间里做“按维度成对的旋转”,不同位置对应不同旋转角度。
先看一个二维的小例子。假设我们有一个 2D 向量$(x_1, x_2)$,在平面上旋转一个角度 $\theta$:
\begin{pmatrix} x'_1 \\ x'_2 \end{pmatrix} = \begin{pmatrix} \cos \theta & -\sin \theta \\ \sin \theta & \cos \theta \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \end{pmatrix}
RoPE 就是把 Q/K 的每两个维度看作一个二维坐标系上的点,然后:
如果 Q 的原始向量为Q_{pos} ,RoPE 输出的就是“带位置信息”的 Q'_{pos} ,同理对 K 也做同样的旋转。
直观理解: 不同位置的 token 被“旋转”到了不同方向上,注意力点积在比较两个向量时,自然就带上了相对位置差。
🧠 读者可能都知道旋转编码就是在Q K上进行旋转,但具体是怎么让模型知道了他们的相对位置信息呢?
假设 query 在位置 m ,key 在位置 n ,RoPE 分别对它们做旋转:
Q'_m = R_{\theta_m} Q_m,\quad K'_n = R_{\theta_n} K_n , 其中 $R_{\theta}$ 是旋转矩阵。
则注意力里用到的点积是:
Q'_m \cdot K'_n = (R_{\theta_m} Q_m)^\top (R_{\theta_n} K_n) = Q_m^\top R_{\theta_m}^\top R_{\theta_n} K_n
因为旋转矩阵可加角度:
R_{\theta_m}^\top R_{\theta_n} = R_{\theta_n - \theta_m}
所以:
Q'_m \cdot K'_n = Q_m^\top R_{\theta_n - \theta_m} K_n
也就是说: 注意力结果只和角度差 \theta_n - \theta_m 相关,也就是“相对位置”!
这就是 RoPE 在长上下文场景下比“纯绝对位置编码”更有优势的根本原因:模型更容易学到“离多远比较重要”,而不是死记“第 1234 个位置是什么样子”。
import torch
def rotate_half(x):
# x: [..., 2 * d_half]
x1, x2 = x.chunk(2, dim=-1) # 拆成两半
return torch.cat([-x2, x1], dim=-1) # (x1, x2) -> (-x2, x1)
def apply_rope(x, cos, sin):
"""
x: [B, L, H, D] # Q 或 K
cos: [L, 1, 1, D]
sin: [L, 1, 1, D]
"""
# 广播到同一形状
while cos.dim() < x.dim():
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
return x * cos + rotate_half(x) * sinQ = self.w_q(x) # [B, L, H*D]
K = self.w_k(x)
Q = Q.view(B, L, H, D)
K = K.view(B, L, H, D)
Q = apply_rope(Q, cos, sin) # 注入位置信息
K = apply_rope(K, cos, sin)对于每个位置的隐藏向量 $x \in \mathbb{R}^{d_{\text{model}}}$,Transformer中的FFN 本质就是一个逐位置的两层 MLP:
\text{FFN}(x) = W_2 \cdot \sigma(W_1 x + b_1) + b_2
其中:
自注意力解决“和谁交互”的问题,FFN 则在每个 token 自己的通道维度上,做一遍非线性变换,提升表达力。
SwiGLU 属于 GLU(Gated Linear Unit)家族。GLU 的典型形式是: \text{GLU}(x) = (W^v x) \odot \sigma(W^g x)
也就是用两条线性变换:
然后用 gate 去控制 value 的通过程度。
类比一下:
传统 FFN:所有维度统一通过一个激活函数; GLU/SwiGLU:每一维都有自己的“门”,可以决定这条通道要不要激活、激活多少
在 LLaMA 等模型中,用的是 SwiGLU 变体,大致可以写成:
\text{SwiGLU}(x) = \big(W^v x\big) \odot \text{SiLU}(W^g x)
其中 SiLU 激活为:
\text{SiLU}(z) = z \cdot \sigma(z)
import torch
import torch.nn as nn
import torch.nn.functional as F
class SwiGLUFFN(nn.Module):
def __init__(self, d_model, d_ff=4096, dropout=0.1):
super().__init__()
# 一次性投影到 2 * d_ff,然后一分为二:gate + value
self.w1 = nn.Linear(d_model, 2 * d_ff)
self.w2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""
x: [B, L, d_model]
"""
x_proj = self.w1(x) # [B, L, 2*d_ff]
gate, value = x_proj.chunk(2, dim=-1) # 各 [B, L, d_ff]
# SwiGLU:SiLU(gate) * value
x = F.silu(gate) * value
x = self.w2(self.dropout(x))
return x标准 FFN 只是一条 MLP 路径,所有通道共享同一个激活函数。而 SwiGLU 用两个投影产生 gate 和 value,再用 SiLU(gate) 做门控,让不同通道的信息流可以被独立控制,在同样的参数规模下提升表达能力。实验上,在 LLaMA / PaLM 等模型中,SwiGLU 相比简单的 GELU/ReLU 有更好的收敛和下游表现。
本章将四件套组合起来,编写一个Decoder代码:
import math
import torch
import torch.nn as nn
class RotaryEmbedding(nn.Module):
"""
RoPE 位置编码模块:
- 只负责根据 head_dim + seq_len 生成 cos/sin
- 不直接改 Q/K,在外面用 apply_rotary_pos_emb 处理
"""
def __init__(self, head_dim: int, max_position_embeddings: int = 4096, base: float = 10000.0):
super().__init__()
assert head_dim % 2 == 0, "head_dim 必须是偶数,才能两两配对旋转"
self.head_dim = head_dim
self.max_position_embeddings = max_position_embeddings
# inv_freq: [head_dim/2]
# 对应论文里的 1 / base^{2i/d}
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
self.register_buffer("inv_freq", inv_freq) # 不参与训练
# 预先算好最大长度的 cos/sin,后面按 seq_len 切片
self._build_cache(max_position_embeddings)
def _build_cache(self, max_seq_len: int):
# t: [max_seq_len]
t = torch.arange(max_seq_len, dtype=torch.float32, device=self.inv_freq.device)
# freqs: [max_seq_len, head_dim/2]
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# 扩成 [max_seq_len, head_dim]
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :]) # [1,1,L,D]
self.register_buffer("sin_cached", emb.sin()[None, None, :, :]) # [1,1,L,D]
def forward(self, seq_len: int, device=None):
"""
返回:
cos, sin: [1, 1, seq_len, head_dim]
"""
if seq_len > self.max_position_embeddings:
# 超过预设长度就重建缓存(简单写法,够用)
self.max_position_embeddings = seq_len
self._build_cache(seq_len)
cos = self.cos_cached[:, :, :seq_len, :] # [1,1,L,D]
sin = self.sin_cached[:, :, :seq_len, :] # [1,1,L,D]
if device is not None:
cos = cos.to(device)
sin = sin.to(device)
return cos, sin
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""
将最后一维两两配对做 (x1, x2) -> (-x2, x1)
x: [..., D] 且 D 为偶数
"""
x1, x2 = x.chunk(2, dim=-1)
return torch.cat([-x2, x1], dim=-1)
def apply_rotary_pos_emb(x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor) -> torch.Tensor:
"""
RoPE 旋转操作:
x: [B, H, L, D]
cos: [1, 1, L, D]
sin: [1, 1, L, D]
"""
# 广播到 [B,H,L,D]
return x * cos + rotate_half(x) * sin
class RoPEMultiHeadAttention(nn.Module):
"""
带 RoPE 的多头注意力:
- 输入 / 输出: [B, L, d_model]
- 内部: 拆成 [B, H, L, Dh],对 Q/K 应用 RoPE
"""
def __init__(self, d_model, num_heads, dropout=0.0,
max_position_embeddings: int = 4096):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
# RoPE 模块,专门生成 cos/sin
self.rotary_emb = RotaryEmbedding(
head_dim=self.head_dim,
max_position_embeddings=max_position_embeddings
)
def forward(self, x, attn_mask=None):
"""
x: [B, L, d_model]
attn_mask: [B, 1, L, L] 或 [B, L, L],为 0 的位置会被 mask 掉
"""
B, L, _ = x.size()
device = x.device
# 1. 线性投影
Q = self.w_q(x) # [B, L, d_model]
K = self.w_k(x)
V = self.w_v(x)
# 2. 拆成多头 [B, H, L, Dh]
def split_heads(t):
return t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
Q = split_heads(Q) # [B, H, L, Dh]
K = split_heads(K)
V = split_heads(V)
# 3. 生成 RoPE 的 cos/sin,并作用在 Q/K 上
cos, sin = self.rotary_emb(seq_len=L, device=device) # [1,1,L,Dh]
Q = apply_rotary_pos_emb(Q, cos, sin) # [B,H,L,Dh]
K = apply_rotary_pos_emb(K, cos, sin)
# 4. 缩放点积注意力
scores = Q @ K.transpose(-2, -1) / (self.head_dim ** 0.5) # [B,H,L,L]
if attn_mask is not None:
# 根据你项目里 attn_mask 的形状调整,这里假设 0 的地方是 mask 掉
if attn_mask.dim() == 3:
attn_mask = attn_mask.unsqueeze(1) # [B,1,L,L]
scores = scores.masked_fill(attn_mask == 0, float('-inf'))
attn = torch.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = attn @ V # [B,H,L,Dh]
# 5. 合并多头
out = out.transpose(1, 2).contiguous().view(B, L, self.d_model)
out = self.w_o(out) # [B,L,d_model]
return out
class SwiGLUFFN(nn.Module):
def __init__(self, d_model, d_ff=4096, dropout=0.1):
super().__init__()
self.w1 = nn.Linear(d_model, 2 * d_ff) # gate + value
self.w2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x_proj = self.w1(x) # [B, L, 2*d_ff]
gate, value = x_proj.chunk(2, dim=-1) # [B,L,d_ff] x2
x = torch.nn.functional.silu(gate) * value # SwiGLU
x = self.w2(self.dropout(x))
return x
class RMSNorm(nn.Module):
def __init__(self, d_model, eps=1e-8):
super().__init__()
self.weight = nn.Parameter(torch.ones(d_model))
self.eps = eps
def forward(self, x):
# x: [B,L,d_model]
rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
x_norm = x / rms
return self.weight * x_norm
class DecoderBlockWithRoPE(nn.Module):
"""
现代 LLM 风格的 Decoder Block:
- RoPE + MHA
- RMSNorm + Pre-Norm
- SwiGLU FFN
"""
def __init__(self, d_model, num_heads, d_ff=4096,
dropout=0.1, max_position_embeddings: int = 4096):
super().__init__()
self.self_attn = RoPEMultiHeadAttention(
d_model=d_model,
num_heads=num_heads,
dropout=dropout,
max_position_embeddings=max_position_embeddings,
)
self.ffn = SwiGLUFFN(d_model, d_ff, dropout)
self.norm1 = RMSNorm(d_model)
self.norm2 = RMSNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, attn_mask=None):
"""
x: [B, L, d_model]
"""
# 1) Pre-Norm + RoPE Self-Attention
h = self.norm1(x)
attn_out = self.self_attn(h, attn_mask=attn_mask)
x = x + self.dropout(attn_out)
# 2) Pre-Norm + SwiGLU FFN
h = self.norm2(x)
ffn_out = self.ffn(h)
x = x + self.dropout(ffn_out)
return x这一篇我们把另外两件标配武器补齐了:
到这里,已经把“现代 LLM 架构四件套:GQA / RoPE / SwiGLU / RMSNorm + Pre-Norm”串成一个整体故事了。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。