在深度学习的二分类/多标签项目,超参数的选择至关重要。通常来讲,越大的batch模型在训练的过程中可以看见更多的样本数据,从而达到越稳定的训练效果。但是,本人在做训练的时候,偶尔会出现 batch=1 一切正常,batch≥8 迅速发散。这种情况也并不少见,因此,为了能够方便初学者的学习,笔者将自己的debug过程以及修改结果写在这篇文章中。
loss≈0.69 → 0.1x
。loss
显著偏大、梯度爆炸或收敛极慢;grad_norm
随 batch 增大而非线性变大。loss(reduction='none')
发现形状是 B, B(或 B, B, …),而预期应为 B 或 B, C。import torch, torch.nn as nn, torch.nn.functional as F
B = 8
logits = torch.randn(B, 1) # ✅ 模型输出 [B, 1]
labels = torch.randint(0, 2, (B,)) # ❌ 标签 [B],与 [B,1] 不同形状
loss_fn = nn.BCEWithLogitsLoss(reduction='none')
loss_mat = loss_fn(logits, labels.float()) # ⛔ 实际得到 [B, B] 的矩阵
print(loss_mat.shape) # torch.Size([8, 8]) —— 被广播了
PyTorch 广播从尾维度对齐:
logits
:[B, 1]
labels
:[B]
→ 视作 [1, B]
(1 ↔ B)
、(B ↔ 1)
都可广播,得到 [B, B]
。损失被无意义地扩成矩阵,梯度混进“交叉样本”项里,学习直接跑偏。注:batch=1 时
[1,1] × [1]
广播不会放大,因而似乎看起来没问题,代码也没有报错,这也是这个坑极难被发现的原因。
1️⃣ 打印逐样本损失形状
loss_fn = nn.BCEWithLogitsLoss(reduction='none')
loss_per = loss_fn(logits, labels.float())
print(loss_per.shape) # 期望 [B] 或 [B, 1];若是 [B, B] 就错了
2️⃣ 开启异常/检查广播
torch.backends.cuda.matmul.allow_tf32 = False # 与本问题无关,但便于稳定复现
# 关键是看 loss_per.shape 和 logits/labels 的形状是否严格相等
assert logits.shape == labels.shape, f"shape mismatch: {logits.shape} vs {labels.shape}"
3️⃣ 查调用栈:谁在 squeeze/unsqueeze
labels = labels.squeeze()
把 [B,1]
挤成 [B]
;logits = logits.squeeze(-1)
导致与标签不匹配。collate_fn
也可能把 [B,1]
合成 [B]
。1️⃣ 二分类(单通道)——两者都用 B,1 或 B,但要严格一致
# 方案 A:都用 [B,1]
logits = model(x) # [B,1]
labels = labels.float().view(-1, 1) # [B,1]
loss = F.binary_cross_entropy_with_logits(logits, labels)
# 方案 B:都用 [B]
logits = model(x).squeeze(-1) # [B]
labels = labels.float().view(-1) # [B]
loss = F.binary_cross_entropy_with_logits(logits, labels)
2️⃣ 多标签(C 类独立二分类)
logits = model(x) # [B, C]
labels = labels.float() # [B, C](独热或 0/1 多热)
assert logits.shape == labels.shape
loss = F.binary_cross_entropy_with_logits(logits, labels)
3️⃣ 多类单选(softmax 分类)
不要用 BCE,而是:
logits = model(x) # [B, C]
targets = labels.long().view(-1) # [B],类别索引
loss = F.cross_entropy(logits, targets)
loss(reduction='none')
形状回到 [B]
(或 [B,C]
),grad_norm
随 batch 合理缩放;很多“玄学发散”并不是优化器/学习率的问题,而是广播在背后捣鬼,而广播机制也是时好时坏,一方面可以方便我们进行计算,但同时,广播也成为了最难发现的bug问题,因为广播机制的存在导致了原来有错误的代码也可以正常运行。这种bug最终定位为 BCEWithLogitsLoss
输入形状不匹配 导致隐式广播:logits
形状 [B,1]
与 labels
形状 [B]
在 PyTorch 中会广播成 [B,B]
再逐元素求损失,等价于把每个样本和所有标签两两配对,直接把损失放大且掺杂错误梯度。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。