GitHub 仓库地址:https://github.com/yanqiangmiffy/amp-pytorch
Pytorch自动混合精度训练模板
使用 pytorch 的自动混合精度教程。基于 PyTorch 1.6 Official Features (Automatic Mixed Precision) ,使用自定义数据集实现分类代码库
1.1 环境设置
- Pytorch>=1.6.0 支持CUDA
- 本地实验环境为:RTX 3090 24G
1.2 下载数据集 Kaggle 英特尔图像分类 数据集地址:Intel Image Classification
该数据包含大约 25k 张大小为 150x150 的图像,分布在 6 个类别下。 {‘建筑物’ -> 0, ‘森林’ -> 1, ‘冰川’ -> 2, ‘山’ -> 3, ‘海’ -> 4, ‘街道’ -> 5 }
数据集解压直接放在data目录下
data:
seg_pred
seg_test
seg_train
实验设置:
运行命令
python main.py --checkpoint_name baseline
未使用混合精度训练:
for batch_idx, (inputs, labels) in enumerate(data_loader):
self.optimizer.zero_grad()
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)
loss.backward()
self.optimizer.step()
使用混合精度训练
scaler = torch.cuda.amp.GradScaler()
for batch_idx, (inputs, labels) in enumerate(data_loader):
self.optimizer.zero_grad()
with torch.cuda.amp.autocast():
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)
# Scales the loss, and calls backward()
# to create scaled gradients
self.scaler.scale(loss).backward()
# Unscales gradients and calls
# or skips optimizer.step()
self.scaler.step(self.optimizer)
# Updates the scale for next iteration
self.scaler.update()
运行命令
python main.py --checkpoint_name baseline_amp --amp;
B : Baseline (FP32) AMP : Automatic Mixed Precision Training (AMP)
Algorithm | Test Accuracy | GPU Memory | Total Training Time |
---|---|---|---|
B - 3090 Ti | 94.17 | 13.0G | (44s*20epochs)~=15mins |
AMP - 3090 Ti | 94.23 | 10.6G | (33s*20eochs)~=11mins |
代码主要来自该仓库,作者实现以及项目代码很完善了,主要将其拆分出来,可以灵活应用