首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >【Debug日志 | 断点训练异常】

【Debug日志 | 断点训练异常】

原创
作者头像
九年义务漏网鲨鱼
发布2025-09-18 15:31:37
发布2025-09-18 15:31:37
2100
代码可运行
举报
文章被收录于专栏:tencent cloudtencent cloud
运行总次数:0
代码可运行

断点续训越来越差?未恢复优化器/调度器/GradScaler 状态导致的收敛倒退

在我们进行模型训练的过程中,可能会遇到这么一种情况:从头训练一切正常,但一旦中途断点续训,loss 开始抖、准确率掉、甚至直接发散。数据与代码未改,唯一不同是“加载了上次的模型权重继续训练”。本篇复盘可复现实验、定位方法与可直接落地的保存/恢复模板。

❓ Bug 现象

  • 断点续训后,学习率曲线突然跳变(回到 warmup 高位或峰值附近)。
  • 同等步数下,loss 明显高于从头训练,准确率短时回退。
  • AMP 训练里偶发首批 NaN,或几步内 loss 急剧波动。
  • 打印优化器状态字典发现为空,或者调度器 last_epoch 与预期不符。

📽️ 场景复现

保存为 resume_bug_demo.py,按注释运行两段即可在 CPU 复现。

代码语言:python
代码运行次数:0
运行
复制
import argparse, math, os, torch, torch.nn as nn, torch.nn.functional as F
torch.manual_seed(0)

def make_loader(n=4096, bs=64):
    X = torch.randn(n, 10)
    y = (X[:, 0] + 0.6 * X[:, 1] > 0).long()
    ds = torch.utils.data.TensorDataset(X, y)
    return torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True, drop_last=True)

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = nn.Sequential(nn.Linear(10,64), nn.ReLU(), nn.Linear(64,2))
    def forward(self, x): return self.m(x)

def build_scheduler(optimizer, total_steps, warmup_steps=50):
    def lr_lambda(step):
        if step < warmup_steps: return (step + 1) / warmup_steps
        t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * min(1.0, t)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def save_ckpt(path, model, optimizer, scheduler, scaler, global_step):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "scaler": scaler.state_dict() if scaler is not None else None,
        "global_step": global_step,
    }, path)

def load_ckpt(path, model, optimizer=None, scheduler=None, scaler=None):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    if optimizer is not None and "optimizer" in ckpt: optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler is not None and "scheduler" in ckpt: scheduler.load_state_dict(ckpt["scheduler"])
    if scaler is not None and ckpt.get("scaler") is not None: scaler.load_state_dict(ckpt["scaler"])
    return ckpt.get("global_step", 0)

def train(phase, bug, steps):
    device = "cpu"
    model = TinyNet().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    scaler = None  # GPU+AMP 时可替换为 GradScaler()
    total_steps = 400
    scheduler = build_scheduler(optimizer, total_steps)

    loader = make_loader()
    it = iter(loader)
    global_step = 0
    ckpt_path = "ckpts/demo.pt"

    if phase == "resume":
        if bug:
            _ = load_ckpt(ckpt_path, model)   # 只加载模型,错误示范
            global_step = 0
        else:
            global_step = load_ckpt(ckpt_path, model, optimizer, scheduler, scaler)

    for _ in range(steps):
        try:   x, y = next(it)
        except StopIteration:
            it = iter(loader); x, y = next(it)
        logits = model(x)
        loss = F.cross_entropy(logits, y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()
        global_step += 1

        if (global_step % 25) == 0:
            lr = optimizer.param_groups[0]["lr"]
            acc = (logits.argmax(1) == y).float().mean().item()
            tag = f"[{phase}|{'BUG' if bug else 'OK'}]"
            print(f"{tag} step={global_step:04d} lr={lr:.4f} loss={loss.item():.3f} acc={acc:.3f}")

    if phase == "pretrain":
        save_ckpt(ckpt_path, model, optimizer, scheduler, scaler, global_step)

if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--phase", choices=["pretrain","resume"], required=True)
    ap.add_argument("--bug", choices=["on","off"], default="on")
    ap.add_argument("--steps", type=int, default=200)
    args = ap.parse_args()
    train(args.phase, args.bug=="on", args.steps)

运行方式

代码语言:python
代码运行次数:0
运行
复制
# 阶段1:从头训练并保存检查点
python resume_bug_demo.py --phase pretrain --steps 200

# 阶段2a:错误续训(只加载模型权重)
python resume_bug_demo.py --phase resume --bug on --steps 200

# 阶段2b:正确续训(加载模型+优化器+调度器+scaler+全局步)
python resume_bug_demo.py --phase resume --bug off --steps 200

后果

  • 错误续训时 lr 会回到 warmup 峰值或高位,loss 短时恶化,acc 降低。
  • 正确续训时 lr 连续,loss/acc 曲线平滑延续 pretrain 末尾的趋势。
  • 若使用 AMP 且未恢复 GradScaler,首批更容易出现溢出或梯度无效。

Debug 过程

1️⃣ 检查优化器与调度器状态

打印优化器 state 是否为空、调度器 last_epoch 是否连续。

代码语言:python
代码运行次数:0
运行
复制
print("optimizer_has_state:", any(len(s)>0 for s in optimizer.state.values()))
print("scheduler_last_epoch:", getattr(scheduler, "last_epoch", None))
print("global_step:", global_step)

2️⃣ 记录学习率和损失

在 resume 的前 100 步高频打印 lr 与 loss,若 lr 不连续或 loss 突升,优先怀疑未恢复状态与步数。

3️⃣ AMP 的数值检查

半精度训练中,恢复后打印 scaler.get_scale(),若恢复为默认初值且随即出现溢出/underflow,需要同步加载 scaler 状态。

代码修改

1️⃣ 保存检查点时同步写入所有训练态

代码语言:python
代码运行次数:0
运行
复制
state = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict() if scheduler else None,
    "scaler": scaler.state_dict() if scaler else None,
    "global_step": global_step,
    "epoch": epoch,
}
torch.save(state, ckpt_path)

2️⃣ 恢复时按照先构建后加载的顺序加载全部状态

代码语言:python
代码运行次数:0
运行
复制
state = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state["model"])
if optimizer and state.get("optimizer"): optimizer.load_state_dict(state["optimizer"])
if scheduler and state.get("scheduler"): scheduler.load_state_dict(state["scheduler"])
if scaler and state.get("scaler"): scaler.load_state_dict(state["scaler"])
global_step = state.get("global_step", 0)
epoch = state.get("epoch", 0)

3️⃣ 训练循环里以 global_step 为唯一驱动

将日志、评估、保存、调度等触发条件统一用 global_step,避免 resume 后 epoch 边界重复或跳过。

代码语言:python
代码运行次数:0
运行
复制
if (global_step % log_every) == 0: ...
if (global_step % eval_every) == 0: ...
if (global_step % save_every) == 0: ...

Q & A

  • OneCycleLR 或 CosineAnnealingLR 的 total_steps/T_max 该怎么设

建议基于优化步数而非 batch 数或 epoch 数计算。若启用梯度累积,应使用 (len(dataloader) // accum_steps) × epochs。resume 时保持 total_steps 不变,并恢复 scheduler 的内部步数。

  • 梯度累积是否影响调度步进

若采用每步调度,应在完成一次优化步时再 step 调度器,且 resume 后 last_epoch 与累计的优化步对齐。

  • 断点落点在 epoch 中间怎么办

强烈建议用 global_step 做所有触发条件,resume 后自然对齐;若用 ep och 边界,请保存 last_batch_idx 并在恢复时跳过已完成的 batch。

  • 分布式训练如何保存与恢复

DDP/FSDP 下通常在 rank0 保存,恢复时先构建并 wra p 模型,再加载 state_dict。FSDP 建议使用库提供的全局一致性 checkpoint 接口。

结语

断点续训不是“从当前 loss 继续”,而是“从当前优化动力学继续”。只加载模型权重相当于丢掉了动量、学习率位置与半精度缩放的全部历史信息,难免出现曲线回退与不稳定。把优化器、调度器、GradScaler 与 global_step 一并纳入检查点模板,并在恢复时做一次完整的自检,续训就能与从头训练保持一致的轨迹与表现。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 断点续训越来越差?未恢复优化器/调度器/GradScaler 状态导致的收敛倒退
    • ❓ Bug 现象
    • 📽️ 场景复现
    • 后果
    • Debug 过程
    • 代码修改
    • Q & A
    • 结语
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档