在我们在训练一个 Transformer 小模型(中文分类 + 预训练继续训练都试过)。loss 能降一点但很快平台化,验证集准确率一直在 70% 左右“挪不动”。尝试了调 LR/批量/warmup 都不灵。
weight_decay
暂时关掉,或改用 AdamW
+ 合理分组后,曲线立刻恢复“阶梯式上升”import torch, torch.nn as nn
model = tiny_transformer() # 含 Embedding + LayerNorm
opt = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.1)
for step, batch in enumerate(loader):
loss = model(**batch)
opt.zero_grad(set_to_none=True)
loss.backward()
opt.step()
症状:很快平台,LayerNorm.weight 范数持续下降;embedding 也被明显压小。
1️⃣ 打印参数分组与范数/梯度范数
def stats(model):
import math
g = {}
for n,p in model.named_parameters():
if not p.requires_grad: continue
g[n] = dict(
shape=tuple(p.shape),
norm=float(p.norm().item()),
gnorm=float(p.grad.norm().item()) if p.grad is not None else float('nan')
)
print("LN.gamma example:", {k:v for k,v in g.items() if "LayerNorm.weight" in k and v['norm']})
观察几个 epoch 后 LN/bias 的
norm
是否持续走低、gnorm
是否异常大/小。
2️⃣ 快速 A/B
weight_decay=0
试 1–2k steps:若曲线明显回升,基本锁定 WD 相关。AdamW
(其他不变):若也回升,进一步确认“耦合 L2”的负面效应。1️⃣ AdamW + 参数分组
from torch.optim import AdamW
def build_param_groups(model, wd=0.01, layerwise_lr=None):
no_decay = []
decay = []
for n, p in model.named_parameters():
if not p.requires_grad:
continue
# 规则:bias / 归一化层 / 嵌入层 -> no_decay
if n.endswith(".bias") or "norm" in n.lower() or "layernorm" in n.lower() or "embedding" in n.lower():
no_decay.append(p)
else:
decay.append(p)
groups = [
{"params": decay, "weight_decay": wd},
{"params": no_decay, "weight_decay": 0.0},
]
# 可选:层别学习率(layer-wise LR decay)
if layerwise_lr:
for g in groups:
g["lr"] = layerwise_lr
return groups
groups = build_param_groups(model, wd=0.01)
optimizer = AdamW(groups, lr=3e-4, betas=(0.9, 0.999))
经验值:
wd
0.01~0.1 需视任务;若用了强数据增广/Dropout,可适当调低。 兼容性:DDP/FSDP/AMP/torch.compile 正常;FSDP 下尽量在 wrap 之后构建参数分组。
2️⃣ 仍用 Adam,但采用解耦式衰减
如果必须用 Adam(比如与历史实验对齐),可以手动实现“解耦式”:
opt = torch.optim.Adam(build_param_groups(model, wd=0.0), lr=3e-4)
for step, batch in enumerate(loader):
loss = model(**batch)
opt.zero_grad(set_to_none=True)
loss.backward()
opt.step()
# 额外做解耦衰减
with torch.no_grad():
for g in opt.param_groups:
wd = g.get("weight_decay", 0.0)
if wd > 0:
for p in g["params"]:
p.mul_(1 - g["lr"] * wd)
3️⃣ 避免“双重衰减”
AdamW(weight_decay=...)
+ 在 loss 里手动加 L2(λ * ||w||²
)LayerNorm.weight
/ bias
范数趋稳,不再被拉向 0。get_parameter_names(model, [nn.LayerNorm])
过滤;核心思路与本文一致。Adam(..., weight_decay=λ)
)是耦合正则:在 梯度里加 λ·w
,再做 Adam 的自适应缩放 → LN、bias 这类“尺度参数”被持续拉小,模型难以维持分布稳定。w
做 w ← w - lr·λ·w
,不进梯度统计,更稳定。bias
LayerNorm.weight
/ BatchNorm.weight
/ RMSNorm.weight
等)很多“怎么都学不动”的 Transformer 实际是权重衰减配置在作祟。最后定位是:把所有参数都做了 L2 正则(Adam(weight_decay=...)
),导致 LayerNorm/Embedding/bias 也被衰减;再叠加“用 Adam + L2(耦合)而非 AdamW(解耦)”,等于双重惩罚关键参数,表现成“怎么调都上不去”。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。