[B,T,V]
× [B,T]
的“对齐/掩码”两连坑(未 shift
+ PAD 未屏蔽)排障实录场景:训练中文小 GPT(Causal LM)。表面现象是:loss 不怎么降、ppl 比基线还高,但 demo 里单句续写“似乎还行
loss
在 3.9~4.2 震荡、ppl
高企;学习率/优化器如何调都难起色。ppl
比“用开源权重直接评估”还差;可视化注意力看到PAD/EOS 也被关注。两个问题都在这段“看似正常”的代码里:没做 shift,以及未正确屏蔽 PAD。
import torch, torch.nn.functional as F
def lm_loss_wrong(logits, labels, pad_id):
# logits: [B,T,V], labels: [B,T]
# ❌ 错 1:没 shift,模型“预测自己”
# ❌ 错 2:直接展平 + mean,PAD 也参与了平均
B, T, V = logits.shape
loss = F.cross_entropy(
logits.reshape(B*T, V),
labels.reshape(B*T),
ignore_index=pad_id, # 以为这就够了
reduction="mean",
)
return loss
为什么 batch=1 看起来更“能学”?
1️⃣ 打印“对齐后”的头尾 token
# 取一个样本看看
print("x[:8] =", labels[0, :8].tolist())
print("y[:8] =", labels[0, :8].tolist()) # 如果 y 就是 x,自检警报
如果你的目标
y
和输入x
一模一样、没有右移(去掉最后一位、前面补 BOS),那基本就是没做 shift。
2️⃣ 检查“逐位置有效比例”
mask = (labels != pad_id) # True=有效
valid_ratio = mask.float().mean().item()
print("valid ratio:", valid_ratio)
低于 0.6 甚至更低时,未经处理的平均会严重稀释梯度。搭配
label_smoothing
时尤其明显。
3️⃣ 定位“PAD 是否被看到”
SDPA
:把 attn_mask
可视化;nn.MultiheadAttention
:确认 key_padding_mask=True 表示屏蔽
(别用反了)。见我之前那篇《注意力 Mask 用反》;这里不赘述。
1️⃣ 正确的teacher forcing 右移(shift)
def shift_labels(input_ids, pad_id, bos_id=None):
# 目标是“预测下一 token”:y_t = x_{t+1}
# x: [B,T] y: [B,T],其中 y[..., -1] = PAD(或 EOS 后全 PAD)
y = input_ids.clone()
y[..., :-1] = input_ids[..., 1:]
y[..., -1] = pad_id
if bos_id is not None:
# 可选:若你想显式在 x 开头放 BOS,让第一个预测以 BOS 为条件
input_ids = input_ids.clone()
input_ids[..., 0] = bos_id
return y
2️⃣ 布尔掩码 + 对分母做精确归一(避免被 PAD 稀释)
import torch, torch.nn.functional as F
def lm_loss_masked(logits, labels, pad_id, label_smoothing=0.0):
"""
logits: [B,T,V], labels: [B,T]
- 右移后的 labels
- 对 PAD 做显式布尔屏蔽
- 分母用“有效 token 数”而不是 B*T
"""
B, T, V = logits.shape
mask = (labels != pad_id) # [B,T] True=有效
# 展平仅为计算方便
logits = logits.reshape(B*T, V)
labels = labels.reshape(B*T)
mask = mask.reshape(B*T)
# reduction='none' 拿到逐位置 loss,再按“有效数”归一
loss_per = F.cross_entropy(
logits, labels,
ignore_index=pad_id, # 与 mask 重合;为稳妥仍保留
label_smoothing=label_smoothing,
reduction="none",
)
valid = mask.sum().clamp_min(1) # 防止全 PAD 的 batch 除零
loss = (loss_per * mask).sum() / valid
return loss
为什么不用默认
mean
? 默认mean
的分母是所有位置数(或 PyTorch 内部基于 ignore_index 的约定),在大量 PAD 的批次上容易稀释;手动按有效数归一可控性更好,也便于日志统计。
3️⃣ 保证注意力看不到 PAD/未来位
nn.MultiheadAttention
:key_padding_mask=True=PAD
(别用反);loss
数千步内明显下降,ppl
接近公开基线;def sanity_check_alignment(input_ids, labels, pad_id):
# labels 应该等于 input_ids 右移一位(末尾 PAD)
shifted = input_ids.clone()
shifted[..., :-1] = input_ids[..., 1:]
shifted[..., -1] = pad_id
assert torch.equal((labels == shifted).all(dim=-1), torch.ones_like(labels[...,0], dtype=torch.bool)), \
"labels 与 input_ids 未对齐(缺少 shift)"
def sanity_check_loss_shape(logits, labels):
B, T, V = logits.shape
assert labels.shape == (B, T), f"labels 形状不对:{labels.shape} vs {(B,T)}"
def log_valid_ratio(labels, pad_id):
mask = (labels != pad_id)
r = mask.float().mean().item()
print(f"[dbg] valid_ratio={r:.3f}")
语言模型训练里,“对齐”和“掩码”是两条看不见却致命的红线。 一旦忘记右移或把 PAD 混进 loss 分母,再强的优化器也帮不了你。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。