首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >【Debug日志 | LM Loss下降问题】

【Debug日志 | LM Loss下降问题】

原创
作者头像
九年义务漏网鲨鱼
发布2025-09-12 16:07:38
发布2025-09-12 16:07:38
6100
代码可运行
举报
文章被收录于专栏:tencent cloudtencent cloud
运行总次数:0
代码可运行

语言模型 Loss 死活降不下去:[B,T,V] × [B,T] 的“对齐/掩码”两连坑(未 shift + PAD 未屏蔽)排障实录

场景:训练中文小 GPT(Causal LM)。表面现象是:loss 不怎么降ppl 比基线还高,但 demo 里单句续写“似乎还行

❓ Bug 现象

  • 训练数千 step:loss 在 3.9~4.2 震荡、ppl 高企;学习率/优化器如何调都难起色。
  • batch=1时似乎能慢慢降;batch 变大曲线更“平”。
  • 评估集 ppl 比“用开源权重直接评估”还差;可视化注意力看到PAD/EOS 也被关注。

📽️ 场景复现

两个问题都在这段“看似正常”的代码里:没做 shift,以及未正确屏蔽 PAD

代码语言:python
代码运行次数:0
运行
复制
import torch, torch.nn.functional as F

def lm_loss_wrong(logits, labels, pad_id):
    # logits: [B,T,V], labels: [B,T]
    # ❌ 错 1:没 shift,模型“预测自己”
    # ❌ 错 2:直接展平 + mean,PAD 也参与了平均
    B, T, V = logits.shape
    loss = F.cross_entropy(
        logits.reshape(B*T, V),
        labels.reshape(B*T),
        ignore_index=pad_id,        # 以为这就够了
        reduction="mean",
    )
    return loss

为什么 batch=1 看起来更“能学”?

  • 没做 shift 时,当前 token 与目标 token 完全相同,模型只需学到“把当前 token 概率抬高”就能拿到不差的 loss;
  • batch 大时,PAD/EOS 数量也增多,未正确屏蔽会让有效样本比例降低,梯度被更多“无意义位置”稀释,学习更慢更稳,但学的是错目标。

Debug过程

1️⃣ 打印“对齐后”的头尾 token

代码语言:python
代码运行次数:0
运行
复制
# 取一个样本看看
print("x[:8] =", labels[0, :8].tolist())
print("y[:8] =", labels[0, :8].tolist())  # 如果 y 就是 x,自检警报

如果你的目标 y 和输入 x 一模一样、没有右移(去掉最后一位、前面补 BOS),那基本就是没做 shift。

2️⃣ 检查“逐位置有效比例”

代码语言:python
代码运行次数:0
运行
复制
mask = (labels != pad_id)         # True=有效
valid_ratio = mask.float().mean().item()
print("valid ratio:", valid_ratio)

低于 0.6 甚至更低时,未经处理的平均会严重稀释梯度。搭配 label_smoothing 时尤其明显。

3️⃣ 定位“PAD 是否被看到”

  • 若使用SDPA:把 attn_mask 可视化;
  • 若用 nn.MultiheadAttention:确认 key_padding_mask=True 表示屏蔽(别用反了)。

见我之前那篇《注意力 Mask 用反》;这里不赘述。

修复方案

1️⃣ 正确的teacher forcing 右移(shift)

代码语言:python
代码运行次数:0
运行
复制
def shift_labels(input_ids, pad_id, bos_id=None):
    # 目标是“预测下一 token”:y_t = x_{t+1}
    # x: [B,T]     y: [B,T],其中 y[..., -1] = PAD(或 EOS 后全 PAD)
    y = input_ids.clone()
    y[..., :-1] = input_ids[..., 1:]
    y[..., -1]  = pad_id
    if bos_id is not None:
        # 可选:若你想显式在 x 开头放 BOS,让第一个预测以 BOS 为条件
        input_ids = input_ids.clone()
        input_ids[..., 0] = bos_id
    return y

2️⃣ 布尔掩码 + 对分母做精确归一(避免被 PAD 稀释)

代码语言:python
代码运行次数:0
运行
复制
import torch, torch.nn.functional as F

def lm_loss_masked(logits, labels, pad_id, label_smoothing=0.0):
    """
    logits: [B,T,V], labels: [B,T]
    - 右移后的 labels
    - 对 PAD 做显式布尔屏蔽
    - 分母用“有效 token 数”而不是 B*T
    """
    B, T, V = logits.shape
    mask = (labels != pad_id)                        # [B,T]  True=有效

    # 展平仅为计算方便
    logits = logits.reshape(B*T, V)
    labels = labels.reshape(B*T)
    mask   = mask.reshape(B*T)

    # reduction='none' 拿到逐位置 loss,再按“有效数”归一
    loss_per = F.cross_entropy(
        logits, labels,
        ignore_index=pad_id,            # 与 mask 重合;为稳妥仍保留
        label_smoothing=label_smoothing,
        reduction="none",
    )

    valid = mask.sum().clamp_min(1)     # 防止全 PAD 的 batch 除零
    loss = (loss_per * mask).sum() / valid
    return loss

为什么不用默认 mean 默认 mean 的分母是所有位置数(或 PyTorch 内部基于 ignore_index 的约定),在大量 PAD 的批次上容易稀释;手动按有效数归一可控性更好,也便于日志统计。

3️⃣ 保证注意力看不到 PAD/未来位

验证与结果

  • 修复后,loss 数千步内明显下降,ppl 接近公开基线;
  • Batch 从 1→64,曲线单调更稳,不再出现“batch 变大更难学”;
  • 可视化注意力:PAD 与未来位权重 ≈ 0;

即插即用debug脚本

代码语言:python
代码运行次数:0
运行
复制
def sanity_check_alignment(input_ids, labels, pad_id):
    # labels 应该等于 input_ids 右移一位(末尾 PAD)
    shifted = input_ids.clone()
    shifted[..., :-1] = input_ids[..., 1:]
    shifted[..., -1]  = pad_id
    assert torch.equal((labels == shifted).all(dim=-1), torch.ones_like(labels[...,0], dtype=torch.bool)), \
        "labels 与 input_ids 未对齐(缺少 shift)"

def sanity_check_loss_shape(logits, labels):
    B, T, V = logits.shape
    assert labels.shape == (B, T), f"labels 形状不对:{labels.shape} vs {(B,T)}"

def log_valid_ratio(labels, pad_id):
    mask = (labels != pad_id)
    r = mask.float().mean().item()
    print(f"[dbg] valid_ratio={r:.3f}")

结语

语言模型训练里,“对齐”和“掩码”是两条看不见却致命的红线。 一旦忘记右移或把 PAD 混进 loss 分母,再强的优化器也帮不了你。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 语言模型 Loss 死活降不下去:[B,T,V] × [B,T] 的“对齐/掩码”两连坑(未 shift + PAD 未屏蔽)排障实录
    • ❓ Bug 现象
    • 📽️ 场景复现
    • Debug过程
    • 修复方案
    • 验证与结果
    • 即插即用debug脚本
    • 结语
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档