在数据量较大时,选用较大的batch通常会出现显存溢出的情况。但是,除此之外,PyTorch 训练途中在 GPU上突然报错:
CUDA error: device-side assert triggered
,接着所有 CUDA 调用连环报错;在CPU 上却一切正常。
RuntimeError: CUDA error: device-side assert triggered
。1️⃣ 分类任务
import torch, torch.nn.functional as F
B, C = 4, 3
logits = torch.randn(B, C, device="cuda") # [B,C]
target = torch.tensor([0, 1, 3, 2], device="cuda")# ❌ 类别“3”越界(C=3 → 0..2)
# 或 target = torch.tensor([[1,0,0],[0,1,0],[0,0,1],[0,0,1]], device="cuda", dtype=torch.float32) # ❌ one-hot
loss = F.cross_entropy(logits, target) # 💥 device-side assert triggered
2️⃣ 语义分割
B, C, H, W = 2, 4, 256, 256
logits = torch.randn(B, C, H, W, device="cuda") # [B,C,H,W]
target = torch.randint(0, C, (B, H, W, 1), device="cuda") # ❌ 多了个通道维
F.cross_entropy(logits, target.squeeze(-1).float()) # ❌ float + squeeze 乱搞
CrossEntropyLoss
的核心是沿着类别维度 对 logits
做 gather
:因此需要目标是整型类别索引(Long
)同时,每个索引都必须落在 0, C-1。对于分割,目标形状必须是B,H,W(或更多空间维),不能带 C 维、不能 one-hot
GPU 内核遇到越界或非法 dtype 时会触发device-side assert,报错滞后且信息少。
1️⃣ 先把错误“同步到行”
加这个环境变量,让报错定位回 Python 栈(只用于排查):
export CUDA_LAUNCH_BLOCKING=1
这样可以拿到哪一行
cross_entropy
触发了异常。
2️⃣ 对 logits/target 做强约束断言
分类:
def guard_ce_shapes(logits, target):
# logits: [B,C], target: [B](Long, 0..C-1)
assert logits.dim()==2, f"logits shape {logits.shape}"
assert target.dim()==1 and target.size(0)==logits.size(0), f"target shape {target.shape}"
assert target.dtype==torch.long, f"target dtype={target.dtype}, need Long"
C = logits.size(1)
bad = ~((0 <= target) & (target < C))
if bad.any():
idx = bad.nonzero(as_tuple=True)[0][:8].tolist()
raise ValueError(f"target out of range at idx {idx}, C={C}, vals={target[idx].tolist()}")
分割:
def guard_seg_shapes(logits, target):
# logits: [B,C,H,W], target: [B,H,W] (Long, 0..C-1)
assert logits.dim()==4, f"logits {logits.shape}"
B,C,H,W = logits.shape
assert target.shape==(B,H,W), f"target {target.shape} expect {(B,H,W)}"
assert target.dtype==torch.long
bad = ~((0 <= target) & (target < C))
if bad.any(): raise ValueError("mask contains invalid class id")
dtype/shape/min/max
;albumentations
某些插值把标签当图像处理);CrossEntropyLoss
,要么改用 BCEWithLogitsLoss
(多标签思路),要么用 SoftTargetCrossEntropy
(自定义)。1️⃣ 分类
# 正确:labels 为 [B] 的 Long,值域 [0,C-1]
logits = head(x) # [B,C]
labels = labels.long() # 确保 Long
loss = F.cross_entropy(logits, labels)
2️⃣ 语义分割
# 正确:logits [B,C,H,W];mask [B,H,W] Long(0..C-1)
loss = F.cross_entropy(logits, mask.long(), ignore_index=255) # 例:把 255 当无效像素
注意:对 mask 做 resize 时用最近邻:
mask = F.interpolate(mask.unsqueeze(1).float(), size=(H,W), mode="nearest").squeeze(1).long()
3️⃣ CutMix / Mixup(软标签)
def soft_cross_entropy(logits, soft_targets):
# logits: [B,C], soft_targets: [B,C] (sum=1)
log_prob = F.log_softmax(logits, dim=-1)
loss = -(soft_targets * log_prob).sum(dim=-1).mean()
return loss
4️⃣ 多标签分类(C independent)
# 多标签:logits [B,C],targets [B,C] ∈ {0,1}(或 [0,1] 概率)
loss = F.binary_cross_entropy_with_logits(logits, targets.float())
“device-side assert triggered” 多半不是 CUDA 的锅,而是目标张量与损失约定不匹配。 把 dtype / 形状 / 取值域 三件事放在数据流入口就做硬断言,再配一套 soft label/segmentation 的专用路径,这类大杀器级别的“黑盒错误”就能在 5 分钟内被钉死。最终定位是 CrossEntropyLoss 的目标张量(target)不符合约定:要么 dtype 不是 Long
,要么 取值不在 0, C-1,要么任务类型错用损失(把 one-hot / 多标签当成单标签)。本文给出复现 → 排查 → 修复的完整记录,并覆盖图像分割、CutMix/Mixup、ignore_index、类权重等高频坑。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。