首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >【Debug日志 | device-side assert triggered】

【Debug日志 | device-side assert triggered】

原创
作者头像
九年义务漏网鲨鱼
发布2025-09-14 16:52:45
发布2025-09-14 16:52:45
12100
代码可运行
举报
文章被收录于专栏:tencent cloudtencent cloud
运行总次数:0
代码可运行

“device-side assert triggered” :含分割、多标签等易混场景

在数据量较大时,选用较大的batch通常会出现显存溢出的情况。但是,除此之外,PyTorch 训练途中在 GPU上突然报错:CUDA error: device-side assert triggered,接着所有 CUDA 调用连环报错;在CPU 上却一切正常。

❓ Bug 现象

  • 训练到某个 batch 突然爆:RuntimeError: CUDA error: device-side assert triggered
  • 从此之后所有 CUDA 调用都失败(哪怕换一行代码也报错),必须重启进程。
  • 同一数据在 CPU 上跑 CrossEntropyLoss 没报错(因为 CPU 实现会抛 Python 异常;GPU 端断言由内核触发更晚)。

📽️ 场景复现

1️⃣ 分类任务

代码语言:python
代码运行次数:0
运行
复制
import torch, torch.nn.functional as F
B, C = 4, 3
logits = torch.randn(B, C, device="cuda")         # [B,C]
target = torch.tensor([0, 1, 3, 2], device="cuda")# ❌ 类别“3”越界(C=3 → 0..2)
# 或 target = torch.tensor([[1,0,0],[0,1,0],[0,0,1],[0,0,1]], device="cuda", dtype=torch.float32)  # ❌ one-hot

loss = F.cross_entropy(logits, target)            # 💥 device-side assert triggered

2️⃣ 语义分割

代码语言:python
代码运行次数:0
运行
复制
B, C, H, W = 2, 4, 256, 256
logits = torch.randn(B, C, H, W, device="cuda")   # [B,C,H,W]
target = torch.randint(0, C, (B, H, W, 1), device="cuda")  # ❌ 多了个通道维
F.cross_entropy(logits, target.squeeze(-1).float())        # ❌ float + squeeze 乱搞
为什么 GPU 会炸?

CrossEntropyLoss 的核心是沿着类别维度 对 logitsgather:因此需要目标是整型类别索引(Long)同时,每个索引都必须落在 0, C-1。对于分割,目标形状必须是B,H,W(或更多空间维),不能带 C 维、不能 one-hot

GPU 内核遇到越界或非法 dtype 时会触发device-side assert,报错滞后且信息少。

Debug过程

1️⃣ 先把错误“同步到行”

加这个环境变量,让报错定位回 Python 栈(只用于排查):

代码语言:python
代码运行次数:0
运行
复制
export CUDA_LAUNCH_BLOCKING=1

这样可以拿到哪一行 cross_entropy 触发了异常。

2️⃣ 对 logits/target 做强约束断言

分类:

代码语言:python
代码运行次数:0
运行
复制
def guard_ce_shapes(logits, target):
    # logits: [B,C], target: [B](Long, 0..C-1)
    assert logits.dim()==2, f"logits shape {logits.shape}"
    assert target.dim()==1 and target.size(0)==logits.size(0), f"target shape {target.shape}"
    assert target.dtype==torch.long, f"target dtype={target.dtype}, need Long"
    C = logits.size(1)
    bad = ~((0 <= target) & (target < C))
    if bad.any():
        idx = bad.nonzero(as_tuple=True)[0][:8].tolist()
        raise ValueError(f"target out of range at idx {idx}, C={C}, vals={target[idx].tolist()}")

分割:

代码语言:python
代码运行次数:0
运行
复制
def guard_seg_shapes(logits, target):
    # logits: [B,C,H,W], target: [B,H,W] (Long, 0..C-1)
    assert logits.dim()==4, f"logits {logits.shape}"
    B,C,H,W = logits.shape
    assert target.shape==(B,H,W), f"target {target.shape} expect {(B,H,W)}"
    assert target.dtype==torch.long
    bad = ~((0 <= target) & (target < C))
    if bad.any(): raise ValueError("mask contains invalid class id")
定位可能原因
  • 在Dataset与collate_fn里打印 dtype/shape/min/max
  • 检查数据增强** 是否把整数 mask 变成了 float(例如 albumentations 某些插值把标签当图像处理);
  • 检查CutMix/Mixup:它们产生soft label(浮点),不能再用 CrossEntropyLoss,要么改用 BCEWithLogitsLoss(多标签思路),要么用 SoftTargetCrossEntropy(自定义)。

修改代码

1️⃣ 分类

代码语言:python
代码运行次数:0
运行
复制
# 正确:labels 为 [B] 的 Long,值域 [0,C-1]
logits = head(x)                 # [B,C]
labels = labels.long()           # 确保 Long
loss = F.cross_entropy(logits, labels)

2️⃣ 语义分割

代码语言:python
代码运行次数:0
运行
复制
# 正确:logits [B,C,H,W];mask [B,H,W] Long(0..C-1)
loss = F.cross_entropy(logits, mask.long(), ignore_index=255)  # 例:把 255 当无效像素

注意:对 mask 做 resize 时用最近邻

代码语言:python
代码运行次数:0
运行
复制
mask = F.interpolate(mask.unsqueeze(1).float(), size=(H,W), mode="nearest").squeeze(1).long()

3️⃣ CutMix / Mixup(软标签)

  • 方案 A:Soft CE
代码语言:python
代码运行次数:0
运行
复制
def soft_cross_entropy(logits, soft_targets):
    # logits: [B,C], soft_targets: [B,C] (sum=1)
    log_prob = F.log_softmax(logits, dim=-1)
    loss = -(soft_targets * log_prob).sum(dim=-1).mean()
    return loss
  • 方案 B:BCEWithLogitsLoss(把每类看作独立二分类;与多标签一致)

4️⃣ 多标签分类(C independent)

代码语言:python
代码运行次数:0
运行
复制
# 多标签:logits [B,C],targets [B,C] ∈ {0,1}(或 [0,1] 概率)
loss = F.binary_cross_entropy_with_logits(logits, targets.float())

验证

  • 修正 dtype/形状/取值后,训练不再触发 device assert;
  • 分割任务把 mask 插值改为最近邻后,mIoU 提升 2~5pp(原来软化了边界);
  • 使用 Soft CE 处理 CutMix/Mixup,收敛更平滑、精度不再异常抖动。

结语

“device-side assert triggered” 多半不是 CUDA 的锅,而是目标张量与损失约定不匹配。 把 dtype / 形状 / 取值域 三件事放在数据流入口就做硬断言,再配一套 soft label/segmentation 的专用路径,这类大杀器级别的“黑盒错误”就能在 5 分钟内被钉死。最终定位是 CrossEntropyLoss 的目标张量(target)不符合约定:要么 dtype 不是 Long,要么 取值不在 0, C-1,要么任务类型错用损失(把 one-hot / 多标签当成单标签)。本文给出复现 → 排查 → 修复的完整记录,并覆盖图像分割、CutMix/Mixup、ignore_index、类权重等高频坑。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • “device-side assert triggered” :含分割、多标签等易混场景
    • ❓ Bug 现象
    • 📽️ 场景复现
    • Debug过程
    • 修改代码
    • 验证
    • 结语
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档