在我们进行模型训练的过程中,可能会遇到这么一种情况:从头训练一切正常,但一旦中途断点续训,loss 开始抖、准确率掉、甚至直接发散。数据与代码未改,唯一不同是“加载了上次的模型权重继续训练”。本篇复盘可复现实验、定位方法与可直接落地的保存/恢复模板。
保存为 resume_bug_demo.py,按注释运行两段即可在 CPU 复现。
import argparse, math, os, torch, torch.nn as nn, torch.nn.functional as F
torch.manual_seed(0)
def make_loader(n=4096, bs=64):
X = torch.randn(n, 10)
y = (X[:, 0] + 0.6 * X[:, 1] > 0).long()
ds = torch.utils.data.TensorDataset(X, y)
return torch.utils.data.DataLoader(ds, batch_size=bs, shuffle=True, drop_last=True)
class TinyNet(nn.Module):
def __init__(self):
super().__init__()
self.m = nn.Sequential(nn.Linear(10,64), nn.ReLU(), nn.Linear(64,2))
def forward(self, x): return self.m(x)
def build_scheduler(optimizer, total_steps, warmup_steps=50):
def lr_lambda(step):
if step < warmup_steps: return (step + 1) / warmup_steps
t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
return 0.5 * (1 + math.cos(math.pi * min(1.0, t)))
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
def save_ckpt(path, model, optimizer, scheduler, scaler, global_step):
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save({
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict(),
"scaler": scaler.state_dict() if scaler is not None else None,
"global_step": global_step,
}, path)
def load_ckpt(path, model, optimizer=None, scheduler=None, scaler=None):
ckpt = torch.load(path, map_location="cpu")
model.load_state_dict(ckpt["model"])
if optimizer is not None and "optimizer" in ckpt: optimizer.load_state_dict(ckpt["optimizer"])
if scheduler is not None and "scheduler" in ckpt: scheduler.load_state_dict(ckpt["scheduler"])
if scaler is not None and ckpt.get("scaler") is not None: scaler.load_state_dict(ckpt["scaler"])
return ckpt.get("global_step", 0)
def train(phase, bug, steps):
device = "cpu"
model = TinyNet().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scaler = None # GPU+AMP 时可替换为 GradScaler()
total_steps = 400
scheduler = build_scheduler(optimizer, total_steps)
loader = make_loader()
it = iter(loader)
global_step = 0
ckpt_path = "ckpts/demo.pt"
if phase == "resume":
if bug:
_ = load_ckpt(ckpt_path, model) # 只加载模型,错误示范
global_step = 0
else:
global_step = load_ckpt(ckpt_path, model, optimizer, scheduler, scaler)
for _ in range(steps):
try: x, y = next(it)
except StopIteration:
it = iter(loader); x, y = next(it)
logits = model(x)
loss = F.cross_entropy(logits, y)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
scheduler.step()
global_step += 1
if (global_step % 25) == 0:
lr = optimizer.param_groups[0]["lr"]
acc = (logits.argmax(1) == y).float().mean().item()
tag = f"[{phase}|{'BUG' if bug else 'OK'}]"
print(f"{tag} step={global_step:04d} lr={lr:.4f} loss={loss.item():.3f} acc={acc:.3f}")
if phase == "pretrain":
save_ckpt(ckpt_path, model, optimizer, scheduler, scaler, global_step)
if __name__ == "__main__":
ap = argparse.ArgumentParser()
ap.add_argument("--phase", choices=["pretrain","resume"], required=True)
ap.add_argument("--bug", choices=["on","off"], default="on")
ap.add_argument("--steps", type=int, default=200)
args = ap.parse_args()
train(args.phase, args.bug=="on", args.steps)
运行方式
# 阶段1:从头训练并保存检查点
python resume_bug_demo.py --phase pretrain --steps 200
# 阶段2a:错误续训(只加载模型权重)
python resume_bug_demo.py --phase resume --bug on --steps 200
# 阶段2b:正确续训(加载模型+优化器+调度器+scaler+全局步)
python resume_bug_demo.py --phase resume --bug off --steps 200
1️⃣ 检查优化器与调度器状态
打印优化器 state 是否为空、调度器 last_epoch 是否连续。
print("optimizer_has_state:", any(len(s)>0 for s in optimizer.state.values()))
print("scheduler_last_epoch:", getattr(scheduler, "last_epoch", None))
print("global_step:", global_step)
2️⃣ 记录学习率和损失
在 resume 的前 100 步高频打印 lr 与 loss,若 lr 不连续或 loss 突升,优先怀疑未恢复状态与步数。
3️⃣ AMP 的数值检查
半精度训练中,恢复后打印 scaler.get_scale(),若恢复为默认初值且随即出现溢出/underflow,需要同步加载 scaler 状态。
1️⃣ 保存检查点时同步写入所有训练态
state = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"scheduler": scheduler.state_dict() if scheduler else None,
"scaler": scaler.state_dict() if scaler else None,
"global_step": global_step,
"epoch": epoch,
}
torch.save(state, ckpt_path)
2️⃣ 恢复时按照先构建后加载的顺序加载全部状态
state = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state["model"])
if optimizer and state.get("optimizer"): optimizer.load_state_dict(state["optimizer"])
if scheduler and state.get("scheduler"): scheduler.load_state_dict(state["scheduler"])
if scaler and state.get("scaler"): scaler.load_state_dict(state["scaler"])
global_step = state.get("global_step", 0)
epoch = state.get("epoch", 0)
3️⃣ 训练循环里以 global_step 为唯一驱动
将日志、评估、保存、调度等触发条件统一用 global_step,避免 resume 后 epoch 边界重复或跳过。
if (global_step % log_every) == 0: ...
if (global_step % eval_every) == 0: ...
if (global_step % save_every) == 0: ...
建议基于优化步数而非 batch 数或 epoch 数计算。若启用梯度累积,应使用 (len(dataloader) // accum_steps) × epochs。resume 时保持 total_steps 不变,并恢复 scheduler 的内部步数。
若采用每步调度,应在完成一次优化步时再 step 调度器,且 resume 后 last_epoch 与累计的优化步对齐。
强烈建议用 global_step 做所有触发条件,resume 后自然对齐;若用 ep och 边界,请保存 last_batch_idx 并在恢复时跳过已完成的 batch。
DDP/FSDP 下通常在 rank0 保存,恢复时先构建并 wra p 模型,再加载 state_dict。FSDP 建议使用库提供的全局一致性 checkpoint 接口。
断点续训不是“从当前 loss 继续”,而是“从当前优化动力学继续”。只加载模型权重相当于丢掉了动量、学习率位置与半精度缩放的全部历史信息,难免出现曲线回退与不稳定。把优化器、调度器、GradScaler 与 global_step 一并纳入检查点模板,并在恢复时做一次完整的自检,续训就能与从头训练保持一致的轨迹与表现。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。