all_gather
导致 label/pred 错位在本人的实践操作中,多卡训练时,验证 AUC/AP 时高时低,甚至比单卡差一截;换种 batch_size 或改
drop_last
后曲线又“起飞”。让本来就为黑盒模型的深度学习更加黑盒。因此,本章结合自己的debug经验来讲解。
drop_last
、batch_size
,曲线形态改变但仍不稳比较常见的错误写法:直接
all_gather
每步的pred
/label
,忽视尾批大小不同与拼接顺序。
import torch, torch.distributed as dist
def gather_step_wrong(pred, label):
# pred: [B, 1], label: [B]
ws = dist.get_world_size()
pred_list = [torch.zeros_like(pred) for _ in range(ws)]
label_list = [torch.zeros_like(label) for _ in range(ws)]
# ❌ 直接 all_gather:要求每个 rank 张量同形状,否则会截断/复用缓存
dist.all_gather(pred_list, pred) # <-- 尾批 B 不同 => 错位
dist.all_gather(label_list, label)
pred_all = torch.cat(pred_list, dim=0)
label_all = torch.cat(label_list, dim=0)
return pred_all, label_all
可能的触发条件
1️⃣ drop_last=Fal
se 且 len(datase
t) 不是 world_si
ze 的整数倍 → 各 rank 尾批
B 不同。
2️⃣ 验证集中过滤/采样导致每卡步数不同。
3️⃣ 步内先 shuff
le 再同步导致拼接顺序不一致(极端情况下)。
4️⃣ 用 torchmetrics.AUR
OC 时同时在 step 与 epoch 做同步,导致重复/错序聚合。
1️⃣ 二分——各自计算 vs. 全局计算
pred/label
收齐后再算,全局 AUC 异常 → 问题在汇总阶段。2️⃣ 长度与顺序核查(Cursor 自动标注)
在 gather
后打印:
print(rank, pred.shape[0], label.shape[0])
pred_all.size(0) != label_all.size(0)
(典型错位信号)。3️⃣ 还原“错位”机理(ChatGPT 解释)
all_gather
要求各 rank 张量形状一致;若尾批大小不同,常见“权宜之计”是预分配最大长度再 all_gather
,却忘了按各自真实长度截断,从而把padding也当成真实样本。4️⃣ 复现 MRE(Codex 生成)
Codex 生成了一个 2 卡、不同尾批的假数据脚本,一跑即现“全局 AUC 飘忽”的现象,锁定根因。
关键点:先同步各 rank 的真实长度 → 按 max_len 进行 padding →
all_gather
→ 按长度回切 → 在 rank0 统一拼接并计算。同时固定全局顺序(按global_offset + local_index
)。
# ddp_metric_gather.py —— 可直接复用
import torch, torch.distributed as dist
def gather_varlen_tensor(x: torch.Tensor, dim=0):
"""
变长安全 all_gather:返回 rank0 上拼接后的张量;其他 rank 返回 None
"""
assert x.is_cuda, "put tensors on CUDA for NCCL"
world = dist.get_world_size()
rank = dist.get_rank()
# 1) 同步各自真实长度
len_local = torch.tensor([x.size(dim)], device=x.device, dtype=torch.int64)
lens = [torch.zeros_like(len_local) for _ in range(world)]
dist.all_gather(lens, len_local)
lens = torch.stack(lens).squeeze(-1) # [world]
max_len = int(lens.max().item())
# 2) 按 max_len padding 到同形状
pad_shape = list(x.shape)
pad_shape[dim] = max_len - x.size(dim)
pad = torch.zeros(pad_shape, device=x.device, dtype=x.dtype)
x_pad = torch.cat([x, pad], dim=dim)
# 3) all_gather 到各自的缓冲区
gather_list = [torch.zeros_like(x_pad) for _ in range(world)]
dist.all_gather(gather_list, x_pad)
# 4) 仅在 rank0 回切并拼接(按 lens 截断)
if rank == 0:
parts = []
for r in range(world):
end = int(lens[r].item())
slc = [slice(None)] * x.dim()
slc[dim] = slice(0, end)
parts.append(gather_list[r][tuple(slc)])
return torch.cat(parts, dim=dim)
else:
return None
@torch.no_grad()
def gather_preds_labels(pred: torch.Tensor, label: torch.Tensor):
# pred [B,1] / [B,C];label [B] / [B,C]
pred_all = gather_varlen_tensor(pred, dim=0)
label_all = gather_varlen_tensor(label, dim=0)
# 仅 rank0 计算指标,其他 rank 返回 None
if dist.get_rank() == 0:
return pred_all.detach().cpu(), label_all.detach().cpu()
return None, None
使用方式(验证/评估阶段):
model.eval()
with torch.inference_mode():
preds_local, labels_local = [], []
for batch in val_loader:
x, y = batch["img"].cuda(non_blocking=True), batch["label"].cuda(non_blocking=True)
logits = model(x)
prob = torch.sigmoid(logits).squeeze(-1) # [B]
preds_local.append(prob)
labels_local.append(y.float())
pred = torch.cat(preds_local, dim=0)
lab = torch.cat(labels_local, dim=0)
# 变长安全汇总
pred_all, lab_all = gather_preds_labels(pred, lab)
if dist.get_rank() == 0:
from sklearn.metrics import roc_auc_score, average_precision_score
auc = roc_auc_score(lab_all.numpy(), pred_all.numpy())
ap = average_precision_score(lab_all.numpy(), pred_all.numpy())
print(f"[Global] AUC={auc:.4f} AP={ap:.4f}")
最后定位是 分布式汇总指标时,label 与 pred 在 all_gather
后发生错位(不同 rank 的尾批大小不同、或拼接顺序不一致),造成以错配数据计算的 AUC。本文完整复盘,并给出可直接复用的“变长安全汇总模板”
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。