当我们在 4 卡 DDP 上训练一个图像分类模型,每张卡的显存几乎快溢出了,训练 loss 似乎在降,但 val acc 抖动剧烈、收敛很慢;切回单卡或把 batch 做大就好很多。
import torch, torch.nn as nn, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.models import resnet18
def main():
dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
model = resnet18(num_classes=10).to(device) # 自带 BN
model = DDP(model, device_ids=[device.index])
optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
loader = tiny_loader(bs_per_gpu=2) # 每卡只有 2
for epoch in range(5):
model.train()
for x,y in loader:
x, y = x.to(device), y.to(device)
out = model(x)
loss = nn.CrossEntropyLoss()(out, y)
optim.zero_grad(); loss.backward(); optim.step()
# eval:acc 大幅抖动
model.eval()
使用 running_mean / running_var(训练期累积的统计量)。这些统计量也被上面的小批噪声污染 → train/val 分布错位。1️⃣ 确认是 BN 问题而非优化器
def set_bn_eval(m):
if isinstance(m, nn.modules.batchnorm._BatchNorm):
m.eval()
model.apply(set_bn_eval)
2️⃣ 观察 BN 统计的“噪声”
running_mean/var
变化幅度,或与全局数据均值对比。running_mean
,发现彼此差异很大。3️⃣ 验证“同步BN”能否改善
1️⃣ 用 SyncBatchNorm 同步多卡统计(推荐)
# 在构建 DDP 之前转换
model = torchvision.models.resnet50(num_classes=...) # 或你自己的模型
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = model.to(device)
model = DDP(model, device_ids=[device.index], broadcast_buffers=True) # 保持默认就可
注意
2️⃣ 小批量改GroupNorm / LayerNorm(结构替代)
当每卡 batch 长期开很小(≤4)时,建议结构性替代 BN:
# 把 2D BN 换成 GN(如 32 组)
def bn_to_gn(module, num_groups=32):
for name, m in module.named_children():
if isinstance(m, nn.BatchNorm2d):
gn = nn.GroupNorm(num_groups, m.num_features, affine=True)
setattr(module, name, gn)
else:
bn_to_gn(m, num_groups)
bn_to_gn(model)
经验
3️⃣ PreciseBN:在更大/更多数据上重估 running stats(
当你必须用 BN,但每卡很小,可在每个 epoch 结束后跑一遍 统计校准。
@torch.no_grad()
def precise_bn(model, data_loader, num_batches=200, device="cuda"):
# 暂时切回 train,使 BN 更新 running stats,但不做反传
was_training = model.training
model.train()
# 清空累计
for m in model.modules():
if isinstance(m, nn.modules.batchnorm._BatchNorm):
m.running_mean.zero_(); m.running_var.fill_(1)
m.num_batches_tracked.zero_()
it = iter(data_loader)
for _ in range(num_batches):
try: x, _ = next(it)
except StopIteration: it = iter(data_loader); x,_ = next(it)
model(x.to(device))
model.train(was_training)
多卡小批训练时,BatchNorm 很容易成为“隐形噪声放大器”。把 SyncBN设为默认,把PreciseBN/GN当作可靠后手,再配一个小脚本长期体检,你的收敛曲线会从“地震图”变回“阶梯线”。最终定位为:BatchNorm 在小 batch + 多卡场景下统计量严重失真(每卡只看见 2–4 张图、各卡统计不一致),导致训练/验证分布错位。本文记录完整排障过程与修复方案,并给出可复用的检测与修复代码。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。