前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >实操教程|PyTorch实现断点继续训练

实操教程|PyTorch实现断点继续训练

作者头像
小白学视觉
发布2024-11-13 18:53:35
320
发布2024-11-13 18:53:35
举报
文章被收录于专栏:深度学习和计算机视觉

导读

本文整理了pytorch实现断电继续训练时需要注意的要点,附有代码详解。

最近在尝试用CIFAR10训练分类问题的时候,由于数据集体量比较大,训练的过程中时间比较长,有时候想给停下来,但是停下来了之后就得重新训练,之前师兄让我们学习断点继续训练及继续训练的时候注意epoch的改变等,今天上午给大致整理了一下,不全面仅供参考

代码语言:javascript
复制
Epoch:  9 | train loss: 0.3517 | test accuracy: 0.7184 | train time: 14215.1018  s
Epoch:  9 | train loss: 0.2471 | test accuracy: 0.7252 | train time: 14309.1216  s
Epoch:  9 | train loss: 0.4335 | test accuracy: 0.7201 | train time: 14403.2398  s
Epoch:  9 | train loss: 0.2186 | test accuracy: 0.7242 | train time: 14497.1921  s
Epoch:  9 | train loss: 0.2127 | test accuracy: 0.7196 | train time: 14591.4974  s
Epoch:  9 | train loss: 0.1624 | test accuracy: 0.7142 | train time: 14685.7034  s
Epoch:  9 | train loss: 0.1795 | test accuracy: 0.7170 | train time: 14780.2831  s
绝望!!!!!训练到了一定次数发现训练次数少了,或者中途断了又得重新开始训练

一、模型的保存与加载

PyTorch中的保存(序列化,从内存到硬盘)与反序列化(加载,从硬盘到内存)

torch.save主要参数:obj:对象 、f:输出路径

torch.load 主要参数 :f:文件路径 、map_location:指定存放位置、 cpu or gpu

模型的保存的两种方法:

1、保存整个Module

代码语言:javascript
复制
torch.save(net, path)

2、保存模型参数

代码语言:javascript
复制
state_dict = net.state_dict()
torch.save(state_dict , path)

二、模型的训练过程中保存

代码语言:javascript
复制
checkpoint = {
        "net": model.state_dict(),
        'optimizer':optimizer.state_dict(),
        "epoch": epoch
    }

将网络训练过程中的网络的权重,优化器的权重保存,以及epoch 保存,便于继续训练恢复

在训练过程中,可以根据自己的需要,每多少代,或者多少epoch保存一次网络参数,便于恢复,提高程序的鲁棒性。

代码语言:javascript
复制
checkpoint = {
        "net": model.state_dict(),
        'optimizer':optimizer.state_dict(),
        "epoch": epoch
    }
    if not os.path.isdir("./models/checkpoint"):
        os.mkdir("./models/checkpoint")
    torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' %(str(epoch)))
代码语言:javascript
复制
通过上述的过程可以在训练过程自动在指定位置创建文件夹,并保存断点文件

三、模型的断点继续训练

代码语言:javascript
复制
if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # 断点路径
    checkpoint = torch.load(path_checkpoint)  # 加载断点

    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数

    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
    start_epoch = checkpoint['epoch']  # 设置开始的epoch

指出这里的是否继续训练,及训练的checkpoint的文件位置等可以通过argparse从命令行直接读取,也可以通过log文件直接加载,也可以自己在代码中进行修改。关于argparse参照我的这一篇文章:

HUST小菜鸡:argparse 命令行选项、参数和子命令解析器

https://zhuanlan.zhihu.com/p/133285373

四、重点在于epoch的恢复

代码语言:javascript
复制
start_epoch = -1


if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # 断点路径
    checkpoint = torch.load(path_checkpoint)  # 加载断点

    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数

    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
    start_epoch = checkpoint['epoch']  # 设置开始的epoch



for epoch in  range(start_epoch + 1 ,EPOCH):
    # print('EPOCH:',epoch)
    for step, (b_img,b_label) in enumerate(train_loader):
        train_output = model(b_img)
        loss = loss_func(train_output,b_label)
        # losses.append(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

通过定义start_epoch变量来保证继续训练的时候epoch不会变化

断点继续训练

一、初始化随机数种子

代码语言:javascript
复制
import torch
import random
import numpy as np

def set_random_seed(seed = 10,deterministic=False,benchmark=False):
    random.seed(seed)
    np.random(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
    if benchmark:
        torch.backends.cudnn.benchmark = True

关于torch.backends.cudnn.deterministic和torch.backends.cudnn.benchmark详见

Pytorch学习0.01:cudnn.benchmark= True的设置

https://www.cnblogs.com/captain-dl/p/11938864.html

pytorch---之cudnn.benchmark和cudnn.deterministic_人工智能_zxyhhjs2017的博客

https://blog.csdn.net/zxyhhjs2017/article/details/91348108

benchmark用在输入尺寸一致,可以加速训练,deterministic用来固定内部随机性

二、多步长SGD继续训练

在简单的任务中,我们使用固定步长(也就是学习率LR)进行训练,但是如果学习率lr设置的过小的话,则会导致很难收敛,如果学习率很大的时候,就会导致在最小值附近,总会错过最小值,loss产生震荡,无法收敛。所以这要求我们要对于不同的训练阶段使用不同的学习率,一方面可以加快训练的过程,另一方面可以加快网络收敛。

采用多步长 torch.optim.lr_scheduler的多种步长设置方式来实现步长的控制,lr_scheduler的各种使用推荐参考如下教程:

【转载】 Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau

https://www.cnblogs.com/devilmaycry812839668/p/10630302.html

所以我们在保存网络中的训练的参数的过程中,还需要保存lr_scheduler的state_dict,然后断点继续训练的时候恢复

代码语言:javascript
复制
#这里我设置了不同的epoch对应不同的学习率衰减,在10->20->30,学习率依次衰减为原来的0.1,即一个数量级
lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[10,20,30,40,50],gamma=0.1)
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)

for epoch in range(start_epoch+1,80):
    optimizer.zero_grad()
    optimizer.step()
    lr_schedule.step()

    if epoch %10 ==0:
        print('epoch:',epoch)
        print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])
代码语言:javascript
复制
lr的变化过程如下:
代码语言:javascript
复制
epoch: 10
learning rate: 0.1
epoch: 20
learning rate: 0.010000000000000002
epoch: 30
learning rate: 0.0010000000000000002
epoch: 40
learning rate: 0.00010000000000000003
epoch: 50
learning rate: 1.0000000000000004e-05
epoch: 60
learning rate: 1.0000000000000004e-06
epoch: 70
learning rate: 1.0000000000000004e-06

我们在保存的时候,也需要对lr_scheduler的state_dict进行保存,断点继续训练的时候也需要恢复lr_scheduler

代码语言:javascript
复制
#加载恢复
if RESUME:
    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # 断点路径
    checkpoint = torch.load(path_checkpoint)  # 加载断点

    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数

    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
    start_epoch = checkpoint['epoch']  # 设置开始的epoch
    lr_schedule.load_state_dict(checkpoint['lr_schedule'])#加载lr_scheduler



#保存
for epoch in range(start_epoch+1,80):

    optimizer.zero_grad()

    optimizer.step()
    lr_schedule.step()


    if epoch %10 ==0:
        print('epoch:',epoch)
        print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])
        checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch,
            'lr_schedule': lr_schedule.state_dict()
        }
        if not os.path.isdir("./model_parameter/test"):
            os.mkdir("./model_parameter/test")
        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))

三、保存最好的结果

每一个epoch中的每个step会有不同的结果,可以保存每一代最好的结果,用于后续的训练

第一次实验代码

代码语言:javascript
复制
RESUME = True

EPOCH = 40
LR = 0.0005


model = cifar10_cnn.CIFAR10_CNN()

print(model)
optimizer = torch.optim.Adam(model.parameters(),lr=LR)
loss_func = nn.CrossEntropyLoss()

start_epoch = -1


if RESUME:
    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # 断点路径
    checkpoint = torch.load(path_checkpoint)  # 加载断点

    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数

    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
    start_epoch = checkpoint['epoch']  # 设置开始的epoch



for epoch in  range(start_epoch + 1 ,EPOCH):
    # print('EPOCH:',epoch)
    for step, (b_img,b_label) in enumerate(train_loader):
        train_output = model(b_img)
        loss = loss_func(train_output,b_label)
        # losses.append(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            now = time.time()
            print('EPOCH:',epoch,'| step :',step,'| loss :',loss.data.numpy(),'| train time: %.4f'%(now-start_time))

    checkpoint = {
        "net": model.state_dict(),
        'optimizer':optimizer.state_dict(),
        "epoch": epoch
    }
    if not os.path.isdir("./models/checkpoint"):
        os.mkdir("./models/checkpoint")
    torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' %(str(epoch)))
代码语言:javascript
复制

更新实验代码

代码语言:javascript
复制
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[10,20,30,40,50],gamma=0.1)
start_epoch = 9
# print(schedule)


if RESUME:
    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # 断点路径
    checkpoint = torch.load(path_checkpoint)  # 加载断点

    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数

    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数
    start_epoch = checkpoint['epoch']  # 设置开始的epoch
    lr_schedule.load_state_dict(checkpoint['lr_schedule'])

for epoch in range(start_epoch+1,80):

    optimizer.zero_grad()

    optimizer.step()
    lr_schedule.step()


    if epoch %10 ==0:
        print('epoch:',epoch)
        print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])
        checkpoint = {
            "net": model.state_dict(),
            'optimizer': optimizer.state_dict(),
            "epoch": epoch,
            'lr_schedule': lr_schedule.state_dict()
        }
        if not os.path.isdir("./model_parameter/test"):
            os.mkdir("./model_parameter/test")
        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2024-11-13,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 小白学视觉 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、模型的保存与加载
    • 1、保存整个Module
      • 2、保存模型参数
      • 二、模型的训练过程中保存
      • 三、模型的断点继续训练
      • 四、重点在于epoch的恢复
      • 一、初始化随机数种子
        • 二、多步长SGD继续训练
        • 三、保存最好的结果
          • 第一次实验代码
          • 更新实验代码
          相关产品与服务
          人工智能与机器学习
          提供全球领先的人脸识别、文字识别、图像识别、语音技术、NLP、人工智能服务平台等多项人工智能技术,共享 AI 领域应用场景和解决方案。
          领券
          问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档