首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >【DEBUG 日志 | 分布式评估 AUC 乱飞】

【DEBUG 日志 | 分布式评估 AUC 乱飞】

原创
作者头像
九年义务漏网鲨鱼
发布2025-09-06 14:33:08
发布2025-09-06 14:33:08
24100
代码可运行
举报
运行总次数:0
代码可运行

分布式评估 AUC 乱飞:DDP all_gather 导致 label/pred 错位

在本人的实践操作中,多卡训练时,验证 AUC/AP 时高时低,甚至比单卡差一截;换种 batch_size 或改 drop_last 后曲线又“起飞”。让本来就为黑盒模型的深度学习更加黑盒。因此,本章结合自己的debug经验来讲解。

❓ Bug 现象

  • 单卡 GPU:AUC 稳定在 0.86±0.01
  • 双卡 DDP 分布式训练:AUC 在 0.62~0.91 抖动;改 drop_lastbatch_size,曲线形态改变但仍不稳
  • 虽然打印每卡本地 AUC 正常;但是做“全局汇总再算”就异常

📽️ 场景复现

比较常见的错误写法:直接 all_gather 每步的 pred/label,忽视尾批大小不同与拼接顺序。

代码语言:python
代码运行次数:0
运行
复制
import torch, torch.distributed as dist

def gather_step_wrong(pred, label):
    # pred: [B, 1], label: [B]
    ws = dist.get_world_size()
    pred_list  = [torch.zeros_like(pred)  for _ in range(ws)]
    label_list = [torch.zeros_like(label) for _ in range(ws)]

    # ❌ 直接 all_gather:要求每个 rank 张量同形状,否则会截断/复用缓存
    dist.all_gather(pred_list,  pred)     # <-- 尾批 B 不同 => 错位
    dist.all_gather(label_list, label)

    pred_all  = torch.cat(pred_list,  dim=0)
    label_all = torch.cat(label_list, dim=0)
    return pred_all, label_all

可能的触发条件

1️⃣ drop_last=False 且 len(dataset) 不是 world_size 的整数倍 → 各 rank 尾 B 不同。

2️⃣ 验证集中过滤/采样导致每卡步数不同。

3️⃣ 步内先 shuffle 再同步导致拼接顺序不一致(极端情况下)。

4️⃣ 用 torchmetrics.AUROC同时在 step 与 epoch 做同步,导致重复/错序聚合。

Debug过程

1️⃣ 二分——各自计算 vs. 全局计算

  • 在每个 rank 上各自算 AUC(仅用本地数据),数值正常。
  • 把各 rank 的 pred/label 收齐后再算,全局 AUC 异常 → 问题在汇总阶段

2️⃣ 长度与顺序核查(Cursor 自动标注)

gather 后打印:

代码语言:python
代码运行次数:0
运行
复制
print(rank, pred.shape[0], label.shape[0])
  • 发现拼接后样本数对不上;有时 pred_all.size(0) != label_all.size(0)(典型错位信号)。

3️⃣ 还原“错位”机理(ChatGPT 解释)

  • all_gather 要求各 rank 张量形状一致;若尾批大小不同,常见“权宜之计”是预分配最大长度all_gather却忘了按各自真实长度截断,从而把padding也当成真实样本。
  • 即使长度对上,顺序也可能不一致(例如 rank1 的第 i 个样本在全局排序后不在 rank0 对应位置)。

4️⃣ 复现 MRE(Codex 生成)

Codex 生成了一个 2 卡、不同尾批的假数据脚本,一跑即现“全局 AUC 飘忽”的现象,锁定根因。

调整代码

关键点:先同步各 rank 的真实长度按 max_len 进行 paddingall_gather按长度回切在 rank0 统一拼接并计算。同时固定全局顺序(按 global_offset + local_index)。

代码语言:python
代码运行次数:0
运行
复制
# ddp_metric_gather.py —— 可直接复用
import torch, torch.distributed as dist

def gather_varlen_tensor(x: torch.Tensor, dim=0):
    """
    变长安全 all_gather:返回 rank0 上拼接后的张量;其他 rank 返回 None
    """
    assert x.is_cuda, "put tensors on CUDA for NCCL"
    world = dist.get_world_size()
    rank  = dist.get_rank()

    # 1) 同步各自真实长度
    len_local = torch.tensor([x.size(dim)], device=x.device, dtype=torch.int64)
    lens = [torch.zeros_like(len_local) for _ in range(world)]
    dist.all_gather(lens, len_local)
    lens = torch.stack(lens).squeeze(-1)  # [world]
    max_len = int(lens.max().item())

    # 2) 按 max_len padding 到同形状
    pad_shape = list(x.shape)
    pad_shape[dim] = max_len - x.size(dim)
    pad = torch.zeros(pad_shape, device=x.device, dtype=x.dtype)
    x_pad = torch.cat([x, pad], dim=dim)

    # 3) all_gather 到各自的缓冲区
    gather_list = [torch.zeros_like(x_pad) for _ in range(world)]
    dist.all_gather(gather_list, x_pad)

    # 4) 仅在 rank0 回切并拼接(按 lens 截断)
    if rank == 0:
        parts = []
        for r in range(world):
            end = int(lens[r].item())
            slc = [slice(None)] * x.dim()
            slc[dim] = slice(0, end)
            parts.append(gather_list[r][tuple(slc)])
        return torch.cat(parts, dim=dim)
    else:
        return None

@torch.no_grad()
def gather_preds_labels(pred: torch.Tensor, label: torch.Tensor):
    # pred [B,1] / [B,C];label [B] / [B,C]
    pred_all  = gather_varlen_tensor(pred,  dim=0)
    label_all = gather_varlen_tensor(label, dim=0)
    # 仅 rank0 计算指标,其他 rank 返回 None
    if dist.get_rank() == 0:
        return pred_all.detach().cpu(), label_all.detach().cpu()
    return None, None

使用方式(验证/评估阶段)

代码语言:python
代码运行次数:0
运行
复制
model.eval()
with torch.inference_mode():
    preds_local, labels_local = [], []
    for batch in val_loader:
        x, y = batch["img"].cuda(non_blocking=True), batch["label"].cuda(non_blocking=True)
        logits = model(x)
        prob   = torch.sigmoid(logits).squeeze(-1)  # [B]
        preds_local.append(prob)
        labels_local.append(y.float())

    pred = torch.cat(preds_local, dim=0)
    lab  = torch.cat(labels_local, dim=0)

    # 变长安全汇总
    pred_all, lab_all = gather_preds_labels(pred, lab)

    if dist.get_rank() == 0:
        from sklearn.metrics import roc_auc_score, average_precision_score
        auc = roc_auc_score(lab_all.numpy(), pred_all.numpy())
        ap  = average_precision_score(lab_all.numpy(), pred_all.numpy())
        print(f"[Global] AUC={auc:.4f} AP={ap:.4f}")

经验总结

最后定位是 分布式汇总指标时,label 与 pred 在 all_gather 后发生错位(不同 rank 的尾批大小不同、或拼接顺序不一致),造成以错配数据计算的 AUC。本文完整复盘,并给出可直接复用的“变长安全汇总模板”

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 分布式评估 AUC 乱飞:DDP all_gather 导致 label/pred 错位
    • ❓ Bug 现象
    • 📽️ 场景复现
    • Debug过程
    • 调整代码
    • 经验总结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档