pos_weight
/ weight
用错场景:多标签二分类(医学/安全/广告审核都常见),正样本极少。训练后验证集 AUC 看着还行,但 F1/Recall 极低,模型几乎“全猜 0”。我一度以为是特征不行,但后面通过AI发现并不是。本章主要总结了常见的现象以及解决方案。
sigmoid(logits)
的分布严重左移(大多 < 0.05),阈值 0.5 几乎没人过线。pos_weight=tensor([class_weights])
,自以为在“强调正例”,实际把它压得更惨。import torch, torch.nn.functional as F
B, C = 32, 5
logits = torch.randn(B, C) # 模拟模型输出
labels = (torch.rand(B, C) < torch.tensor([0.01,0.02,0.05,0.1,0.2])).float() # 极不平衡
# ❌ 错 1:把 pos_weight 写成了“正/负比”(应为 负/正)
pos = labels.sum(dim=0) + 1e-6
neg = labels.numel()/C - pos
wrong_pos_weight = (pos / neg).clamp(min=1e-3) # ❌
# ❌ 错 2:还叠加了 element-wise 的 weight(双重压制负例/搞乱标度)
elem_weight = torch.ones_like(labels) * 0.5 # ❌ 随手设
loss = F.binary_cross_entropy_with_logits(
logits, labels,
pos_weight=wrong_pos_weight, # 形状 [C]
weight=elem_weight, # 形状 [B,C]
reduction='mean'
)
print('loss=', loss.item())
后果
pos_weight
只放大正样本项(−ylogσ(x)-y\log \sigma(x)−ylogσ(x) 部分),如果把数值算反(<1),等于进一步减小正样本的梯度;weight
一个 <1 的系数,会同时缩小正负项,整体标度混乱;logits
往负无穷推(更偏向全 0),阈值 0.5 完全“过不去”。1️⃣ 打印每类正例比例 & 估算 pos_weight
with torch.no_grad():
pos = labels.sum(dim=0)
neg = labels.size(0) - pos
print('pos ratio:', (pos / labels.size(0)).tolist())
print('expected pos_weight (neg/pos):', (neg / pos.clamp_min(1)).tolist())
发现我用的是 pos/neg 而非 neg/pos,难怪梯度“偏心”。
2️⃣ 看概率分布 & 阈值通过率
prob = torch.sigmoid(logits)
print('mean prob per class:', prob.mean(dim=0).tolist())
print('pass@0.5 per class:', (prob>0.5).float().mean(dim=0).tolist())
大量类别
pass@0.5 ≈ 0
,说明阈值不合适或模型已被“压扁”。
3️⃣ 检查 weig
ht 是否真的需要
绝大多数多标签不平衡场景只需
pos_weight
;weight
是逐元素权重,同时缩放正负项,少用。
1️⃣ 正确计算 pos_weight = N_neg / N_p
os(按类别)
def compute_pos_weight(train_labels: torch.Tensor) -> torch.Tensor:
"""
train_labels: [N, C] in {0,1}
return: [C] tensor on CPU,后续 to(device)
"""
pos = train_labels.sum(dim=0) # [C]
neg = train_labels.size(0) - pos
# 处理“0 正例”类别:给一个温和的上限,避免 inf
pw = (neg / pos.clamp_min(1)).clamp(max=100.0)
return pw
# 用法
pos_weight = compute_pos_weight(train_labels)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device), reduction='mean')
只传
pos_weight
,不要同时传weight
(除非你非常清楚要做按样本/位置加权)。 不要每个 batch 动态重算,用全数据统计(或滑动估计),否则训练不稳定。
2️⃣ 用 BCEWithLogitsLo
ss(不是 Sigmo
id + BCELo
ss)
logits = model(x) # [B,C]
loss = criterion(logits, labels) # labels float in {0,1}
数值更稳;如果你手搓
sigmoid
再BCELoss
,容易在极端概率处下溢/溢出。
3️⃣ 阈值不要固定 0.5:做验证集阈值搜索(全局或分类别)
import numpy as np
from sklearn.metrics import f1_score
def search_thresholds(probs, labels, per_class=True):
"""
probs/labels: numpy array [N,C]
return: thresholds [C] or scalar
"""
if not per_class:
grid = np.linspace(0.01, 0.99, 99)
f1 = [f1_score(labels.ravel(), (probs.ravel()>t).astype(int)) for t in grid]
return float(grid[int(np.argmax(f1))])
C = probs.shape[1]
th = np.zeros(C)
for c in range(C):
grid = np.linspace(0.01, 0.99, 99)
f1 = [f1_score(labels[:,c], (probs[:,c]>t).astype(int)) for t in grid]
th[c] = grid[int(np.argmax(f1))]
return th
也可以按 Youden J、F1、成本敏感(precision/recall 约束)定制目标;阈值每次在验证集重估,不泄漏测试集。
# 1) 准备损失
pos_weight = compute_pos_weight(train_labels).to(device)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='mean')
# 2) 训练
for x, y in train_loader:
x, y = x.to(device), y.to(device).float()
logits = model(x)
loss = criterion(logits, y)
optimizer.zero_grad(set_to_none=True)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
# 3) 验证(阈值搜索)
probs_val, labels_val = [], []
with torch.no_grad():
for x, y in val_loader:
p = torch.sigmoid(model(x.to(device))).cpu()
probs_val.append(p); labels_val.append(y)
probs_val = torch.cat(probs_val).numpy()
labels_val = torch.cat(labels_val).numpy()
thr = search_thresholds(probs_val, labels_val, per_class=True)
# 4) 报告指标
pred = (probs_val > thr).astype(int)
print_metrics(pred, labels_val) # 自己实现一个包含 per-class 的表
pos_weight
和 weight
区别?pos_weight [C]
:只放大正样本项;weight [B,C]
或 [C]
:正负项同时乘权,多用于样本级/位置级权重。pos_weight
怎么取值?N_neg / N_pos
;可加裁剪(max=100
)防止梯度爆炸。pos_weight
正确 + 阈值校准,通常已够。确认 labels.float()
∈ {0,1},没有“软标签”误用到 BCE(若有,用 BCE 也行,但要清楚语义)。
在高度不平衡的多标签任务里,损失定义 + 阈值选择决定了你到底在学什么。 把 pos_weight
算对、用对,把阈值在验证集上系统地校准,最后定位到三件事:pos_weight
与 weight
的语义弄反;pos_weight
数值算错(用“正/负比”而不是“负/正比”);阈值固死 0.5,在极度不平衡时**没校准。本文记录复盘过程、给出权重正确计算模板与阈值搜索脚本,并附带几个“看日志就能发现”的自检点。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。